From 75af3926715a957551351a29cf933becf24f3de4 Mon Sep 17 00:00:00 2001
From: raver119
Date: Tue, 14 Apr 2020 18:47:46 +0300
Subject: [PATCH 01/18] - memcpy fix + validation for CUDA: skip memcpy if
 length < 1 (#375)

- Reset cached context after device affinity change

Signed-off-by: raver119
---
 .../nd4j/jita/concurrency/CudaAffinityManager.java   |  4 ++++
 .../java/org/nd4j/jita/handler/MemoryHandler.java    |  2 ++
 .../org/nd4j/jita/handler/impl/CudaZeroHandler.java  | 12 ++++++++++++
 3 files changed, 18 insertions(+)

diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/CudaAffinityManager.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/CudaAffinityManager.java
index 356df88e0..415fa487f 100644
--- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/CudaAffinityManager.java
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/CudaAffinityManager.java
@@ -294,7 +294,11 @@ public class CudaAffinityManager extends BasicAffinityManager {
 
     @Override
     public void unsafeSetDevice(Integer deviceId) {
+        // actually set device
         NativeOpsHolder.getInstance().getDeviceNativeOps().setDevice(deviceId);
+
+        // reset saved context, so it will be recreated on first call
+        AtomicAllocator.getInstance().getMemoryHandler().resetCachedContext();
     }
 
     @Override
diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/MemoryHandler.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/MemoryHandler.java
index abb919e8c..44d8e2042 100644
--- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/MemoryHandler.java
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/MemoryHandler.java
@@ -304,4 +304,6 @@ public interface MemoryHandler {
     boolean promoteObject(DataBuffer buffer);
 
     void relocateObject(DataBuffer buffer);
+
+    void resetCachedContext();
 }
diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java
index a8f3a0a3b..4c6e56bc9 100644
--- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/handler/impl/CudaZeroHandler.java
@@ -17,6 +17,8 @@
 package org.nd4j.jita.handler.impl;
 
 import lombok.var;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.nativeblas.OpaqueLaunchContext;
 import org.nd4j.shade.guava.collect.HashBasedTable;
 import org.nd4j.shade.guava.collect.Table;
@@ -325,6 +327,11 @@ public class CudaZeroHandler implements MemoryHandler {
      */
     @Override
     public void memcpyAsync(DataBuffer dstBuffer, Pointer srcPointer, long length, long dstOffset) {
+        if (length < 1)
+            return;
+
+        Preconditions.checkArgument(length <= (dstBuffer.length() * Nd4j.sizeOfDataType(dstBuffer.dataType())), "Length requested is bigger than target DataBuffer length");
+
         val point = ((BaseCudaDataBuffer) dstBuffer).getAllocationPoint();
         CudaContext tContext = null;
 
@@ -1041,6 +1048,11 @@ public class CudaZeroHandler implements MemoryHandler {
         return ctx;
     }
 
+    @Override
+    public void resetCachedContext() {
+        tlContext.remove();
+    }
+
     /**
      * This method returns if this
      * MemoryHandler instance is device-dependant (i.e. CUDA)
      *

From 217fa433d4250d44d09f03159b44af3219902c03 Mon Sep 17 00:00:00 2001
From: raver119
Date: Thu, 16 Apr 2020 07:39:45 +0300
Subject: [PATCH 02/18] mkldnn repo url updated

Signed-off-by: raver119
---
 libnd4j/CMakeLists.txt.mkldnn.in | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/libnd4j/CMakeLists.txt.mkldnn.in b/libnd4j/CMakeLists.txt.mkldnn.in
index 224f5d50d..bfa508eef 100644
--- a/libnd4j/CMakeLists.txt.mkldnn.in
+++ b/libnd4j/CMakeLists.txt.mkldnn.in
@@ -4,7 +4,7 @@ project(mkldnn-download NONE)
 include(ExternalProject)
 
 ExternalProject_Add(mkldnn
-    GIT_REPOSITORY https://github.com/intel/mkl-dnn.git
+    GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
     GIT_TAG v1.3
     SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src"
     BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build"

From 4247718f61ebb061085bc4263612426aa39ccaff Mon Sep 17 00:00:00 2001
From: Yurii Shyrma
Date: Thu, 16 Apr 2020 08:09:04 +0300
Subject: [PATCH 03/18] Shyrma gru bp (#377)

* - update gru ff op

Signed-off-by: Yurii

* - implementation and testing gru_bp op

Signed-off-by: Yurii

* - neglect dependencies between dLdh/dLdhLast/dLdcLast in lstmLayer backprop

Signed-off-by: Yurii
---
 libnd4j/include/helpers/GradCheck.h           |   3 +-
 libnd4j/include/helpers/impl/GradCheck.cpp    |  48 +-
 .../declarable/generic/nn/recurrent/gru.cpp   | 179 ++++--
 .../generic/nn/recurrent/gruCell.cpp          |   2 +-
 .../generic/nn/recurrent/lstmLayer.cpp        |  15 +-
 .../generic/nn/recurrent/lstmLayerCell.cpp    |   2 +-
 .../ops/declarable/headers/recurrent.h        |   4 +
 .../ops/declarable/helpers/cpu/gru.cpp        | 421 --------------
 .../ops/declarable/helpers/cuda/gru.cu        | 365 ------------
 libnd4j/include/ops/declarable/helpers/gru.h  |  18 +-
 .../ops/declarable/helpers/impl/gru.cpp       | 546 ++++++++++++++++++
 .../ops/declarable/helpers/impl/lstmLayer.cpp | 254 ++++----
 .../ops/declarable/helpers/lstmLayer.h        |   2 +-
 .../layers_tests/DeclarableOpsTests13.cpp     | 178 +++---
 .../layers_tests/DeclarableOpsTests15.cpp     |  66 +++
 15 files changed, 968 insertions(+), 1135 deletions(-)
 delete mode 100644 libnd4j/include/ops/declarable/helpers/cpu/gru.cpp
 delete mode 100644 libnd4j/include/ops/declarable/helpers/cuda/gru.cu
 create mode 100644 libnd4j/include/ops/declarable/helpers/impl/gru.cpp

diff --git a/libnd4j/include/helpers/GradCheck.h b/libnd4j/include/helpers/GradCheck.h
index f5fd1f3df..9ca18a82b 100644
--- a/libnd4j/include/helpers/GradCheck.h
+++ b/libnd4j/include/helpers/GradCheck.h
@@ -50,10 +50,9 @@ class ND4J_EXPORT GradCheck {
      *  whatArrsToCheck - specifies what output gradient arrays to check, for example {0, 1, 0} means that only second output gradient array will be checked, default value is empty std::vector which means to check all arrays
      *  IdxRange - specifies indexes range over which array elements will be checked, for example {0.2, 0.7} means range [0.2*array_length, 0.7*array_length), default value is {0., 1.}
      *  loss - type of scalar loss function, it specifies what elements values will be filled into input gradient arrays automatically, default value is SUM
-     *  outArrsFFIdx - contains indexes of ff output arrays which are independent from each other, default means all are independent
      */
     static bool checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, const OpArgsHolder& argsHolderFF, const OpArgsHolder& argsHolderBP,
-                          const std::vector<bool>& whatArrsToCheck = std::vector<bool>(), const std::vector<double>& IdxRange = {0., 1.}, const LossFunc loss = SUM, const std::vector<int>& outArrsFFIdx = {});
+                          const std::vector<bool>&
whatArrsToCheck = std::vector(), const std::vector& IdxRange = {0., 1.}, const LossFunc loss = SUM); }; diff --git a/libnd4j/include/helpers/impl/GradCheck.cpp b/libnd4j/include/helpers/impl/GradCheck.cpp index f3daa798c..12ecab75f 100644 --- a/libnd4j/include/helpers/impl/GradCheck.cpp +++ b/libnd4j/include/helpers/impl/GradCheck.cpp @@ -49,7 +49,7 @@ void GradCheck::fillGradArrays(const LossFunc loss, const std::vector& ////////////////////////////////////////////////////////////////////////// bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, const OpArgsHolder& argsHolderFF, const OpArgsHolder& argsHolderBP, - const std::vector& whatArrsToCheck, const std::vector& idxRange, const LossFunc loss, const std::vector& outArrsFFIdx) { + const std::vector& whatArrsToCheck, const std::vector& idxRange, const LossFunc loss) { const int numInArrsFF = argsHolderFF.getNumInArrs(); // at the same time numInArrsFF = number of output arrays in opBP const int numInGradArrsBP = argsHolderBP.getNumInArrs() - numInArrsFF; // because argsHolderBP.getNumInArrs() = numInArrsFF + numInGradArrsBP @@ -82,23 +82,12 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons int numOutArrs = outArrsFF.size(); double scorePlus = 0.; - if(!outArrsFFIdx.empty()) { - for(const auto& k : outArrsFFIdx) { // loop through independent output arrays - if(loss == SUM) - outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar); - else - outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar); - scorePlus += tmpScalar.e(0); - } - } - else { - for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays - if(loss == SUM) - outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar); - else - outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar); - scorePlus += tmpScalar.e(0); - } + for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays + if(loss == SUM) + outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar); + else + outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar); + scorePlus += tmpScalar.e(0); } // subtract epsilon, feed forward @@ -106,23 +95,12 @@ bool GradCheck::checkGrad(ops::DeclarableOp& opFF, ops::DeclarableOp& opBP, cons outArrsFF = opFF.execute(argsHolderFF); double scoreMinus = 0.; - if(!outArrsFFIdx.empty()) { - for(const auto& k : outArrsFFIdx) { // loop through independent output arrays - if(loss == SUM) - outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar); - else - outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar); - scoreMinus += tmpScalar.e(0); - } - } - else { - for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays - if(loss == SUM) - outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar); - else - outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar); - scoreMinus += tmpScalar.e(0); - } + for(int k = 0; k < numOutArrs; ++k) { // loop through output arrays + if(loss == SUM) + outArrsFF.at(k)->reduceNumber(reduce::Sum, tmpScalar); + else + outArrsFF.at(k)->reduceNumber(reduce::Mean, tmpScalar); + scoreMinus += tmpScalar.e(0); } // restore initial element value diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/gru.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/gru.cpp index c6cd2e8f1..dee9a7c88 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/gru.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/gru.cpp @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. 
+ * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -15,7 +16,7 @@ ******************************************************************************/ // -// created by Yurii Shyrma on 15.02.2018 +// @author Yurii Shyrma (iuriish@yahoo.com) // #include @@ -30,83 +31,157 @@ namespace ops { ////////////////////////////////////////////////////////////////////////// CUSTOM_OP_IMPL(gru, 5, 1, false, 0, 0) { - auto x = INPUT_VARIABLE(0); // input [time x bS x iS] - auto h0 = INPUT_VARIABLE(1); // initial cell output (at time step = 0) [bS x nU] + auto x = INPUT_VARIABLE(0); // input [time, bS, nIn] + auto hI = INPUT_VARIABLE(1); // initial cell output (at time step = 0) [bS, nOut] - auto Wx = INPUT_VARIABLE(2); // input-to-hidden weights, [iS x 3*nU] - auto Wh = INPUT_VARIABLE(3); // hidden-to-hidden weights, [nU x 3*nU] - auto b = INPUT_VARIABLE(4); // biases, [3*nU] + auto Wx = INPUT_VARIABLE(2); // input-to-hidden weights, [nIn, 3*nOut] + auto Wh = INPUT_VARIABLE(3); // hidden-to-hidden weights, [nOut, 3*nOut] + auto b = INPUT_VARIABLE(4); // biases, [3*nOut] - auto h = OUTPUT_VARIABLE(0); // cell outputs [time x bS x nU], that is per each time step + auto h = OUTPUT_VARIABLE(0); // cell outputs [time, bS, nOut], that is per each time step - const int rank = x->rankOf(); // = 3 - const int time = x->sizeAt(0); - const int bS = x->sizeAt(1); - const int iS = x->sizeAt(2); - const int nU = h0->sizeAt(1); + const int bS = x->sizeAt(1); + const int nIn = x->sizeAt(2); + const int nOut = hI->sizeAt(1); - const std::vector h0CorrectShape = {bS, nU}; - const std::vector wxCorrectShape = {iS, 3*nU}; - const std::vector whCorrectShape = {nU, 3*nU}; - const std::vector bCorrectShape = {3*nU}; + const std::vector h0CorrectShape = {bS, nOut}; + const std::vector wxCorrectShape = {nIn, 3*nOut}; + const std::vector whCorrectShape = {nOut, 3*nOut}; + const std::vector bCorrectShape = {3*nOut}; - REQUIRE_TRUE(h0->isSameShape(h0CorrectShape), 0, "GRU operation: wrong shape of previous cell output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(h0CorrectShape).c_str(), ShapeUtils::shapeAsString(h0).c_str()); + REQUIRE_TRUE(hI->isSameShape(h0CorrectShape), 0, "GRU operation: wrong shape of previous cell output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(h0CorrectShape).c_str(), ShapeUtils::shapeAsString(hI).c_str()); REQUIRE_TRUE(Wx->isSameShape(wxCorrectShape), 0, "GRU operation: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wxCorrectShape).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); REQUIRE_TRUE(Wh->isSameShape(whCorrectShape), 0, "GRU operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(whCorrectShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str()); REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, "GRU operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str()); - helpers::gruTimeLoop(block.launchContext(), x, h0, Wx, Wh, b, h); + helpers::gruTimeLoop(block.launchContext(), x, hI, Wx, Wh, b, h); return Status::OK(); } +////////////////////////////////////////////////////////////////////////// +DECLARE_TYPES(gru) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + 
->setAllowedOutputTypes({ALL_FLOATS}); +} - DECLARE_TYPES(gru) { - getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes({ALL_FLOATS}); - } - - +////////////////////////////////////////////////////////////////////////// DECLARE_SHAPE_FN(gru) { - const auto xShapeInfo = inputShape->at(0); // input [time x bS x inSize] - const auto h0ShapeInfo = inputShape->at(1); // initial cell output [bS x numUnits], that is at time step t=0 - const auto WxShapeInfo = inputShape->at(2); // input-to-hidden weights, [inSize x 3*numUnits] - const auto WhShapeInfo = inputShape->at(3); // hidden-to-hidden weights, [numUnits x 3*numUnits] - const auto bShapeInfo = inputShape->at(4); // biases, [3*numUnits] - const int rank = shape::rank(xShapeInfo); // = 3 - const auto time = xShapeInfo[1]; - const auto bS = xShapeInfo[2]; - const auto inSize = xShapeInfo[3]; - const auto numUnits = h0ShapeInfo[2]; + auto x = INPUT_VARIABLE(0); // input [time, bS, nIn] + auto hI = INPUT_VARIABLE(1); // initial cell output (at time step = 0) [bS, nOut] - const std::vector h0CorrectShape = {bS, numUnits}; - const std::vector wxCorrectShape = {inSize, 3*numUnits}; - const std::vector whCorrectShape = {numUnits, 3*numUnits}; - const std::vector bCorrectShape = {3*numUnits}; + auto Wx = INPUT_VARIABLE(2); // input-to-hidden weights, [nIn, 3*nOut] + auto Wh = INPUT_VARIABLE(3); // hidden-to-hidden weights, [nOut, 3*nOut] + auto b = INPUT_VARIABLE(4); // biases, [3*nOut] - REQUIRE_TRUE(ShapeUtils::areShapesEqual(h0ShapeInfo, h0CorrectShape), 0, "GRU operation: wrong shape of previous cell output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(h0CorrectShape).c_str(), ShapeUtils::shapeAsString(h0ShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(WxShapeInfo, wxCorrectShape), 0, "GRU operation: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wxCorrectShape).c_str(), ShapeUtils::shapeAsString(WxShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(WhShapeInfo, whCorrectShape), 0, "GRU operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(whCorrectShape).c_str(), ShapeUtils::shapeAsString(WhShapeInfo).c_str()); - REQUIRE_TRUE(ShapeUtils::areShapesEqual(bShapeInfo, bCorrectShape), 0, "GRU operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(bShapeInfo).c_str()); + const int time = x->sizeAt(0); + const int bS = x->sizeAt(1); + const int nIn = x->sizeAt(2); + const int nOut = hI->sizeAt(1); + const std::vector h0CorrectShape = {bS, nOut}; + const std::vector wxCorrectShape = {nIn, 3*nOut}; + const std::vector whCorrectShape = {nOut, 3*nOut}; + const std::vector bCorrectShape = {3*nOut}; - // evaluate output shapeInfo - Nd4jLong *hShapeInfo(nullptr); - ALLOCATE(hShapeInfo, block.getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong); + REQUIRE_TRUE(hI->isSameShape(h0CorrectShape), 0, "GRU operation: wrong shape of previous cell output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(h0CorrectShape).c_str(), ShapeUtils::shapeAsString(hI).c_str()); + REQUIRE_TRUE(Wx->isSameShape(wxCorrectShape), 0, "GRU operation: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wxCorrectShape).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); + 
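+    // For reference, these shape requirements follow from the GRU step computed by
+    // helpers::gruTimeLoop (cf. https://arxiv.org/abs/1406.1078): the reset (r),
+    // update (u) and cell (c) gate parameters are packed side by side, hence the
+    // 3*nOut factor in Wx, Wh and b. A sketch of one time step (the Wx_r/Wh_r/b_r
+    // sub-block names are illustrative slices, not identifiers from this patch;
+    // × is matrix multiplication, * is the Hadamard product):
+    //     r = sigmoid(x × Wx_r + hI × Wh_r + b_r)
+    //     u = sigmoid(x × Wx_u + hI × Wh_u + b_u)
+    //     c = tanh(x × Wx_c + (r * hI) × Wh_c + b_c)
+    //     h = u * hI + (1 - u) * c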
REQUIRE_TRUE(Wh->isSameShape(whCorrectShape), 0, "GRU operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(whCorrectShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str()); + REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, "GRU operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str()); - hShapeInfo[0] = rank; - hShapeInfo[1] = time; - hShapeInfo[2] = bS; - hShapeInfo[3] = numUnits; - - ShapeUtils::updateStridesAndType(hShapeInfo, xShapeInfo, shape::order(h0ShapeInfo)); + auto* hShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(hI->dataType(), hI->ordering(), {time, bS, nOut}); return SHAPELIST(hShapeInfo); } +////////////////////////////////////////////////////////////////////////// +CUSTOM_OP_IMPL(gru_bp, 6, 5, false, 0, 0) { + + auto x = INPUT_VARIABLE(0); // input [time, bS, nIn] + auto hI = INPUT_VARIABLE(1); // initial cell output (at time step = 0) [bS, nOut] + + auto Wx = INPUT_VARIABLE(2); // input-to-hidden weights, [nIn, 3*nOut] + auto Wh = INPUT_VARIABLE(3); // hidden-to-hidden weights, [nOut, 3*nOut] + auto b = INPUT_VARIABLE(4); // biases, [3*nOut] + + auto dLdh = INPUT_VARIABLE(5); // gradient vs. ff output, [time, bS, nOut] + + auto dLdx = OUTPUT_VARIABLE(0); // gradient vs. ff input, [time, bS, nIn] + auto dLdhI = OUTPUT_NULLIFIED(1); // gradient vs. initial cell output, [bS, nOut] + auto dLdWx = OUTPUT_NULLIFIED(2); // gradient vs. input-to-hidden weights, [nIn, 3*nOut] + auto dLdWh = OUTPUT_NULLIFIED(3); // gradient vs. hidden-to-hidden weights, [nOut, 3*nOut] + auto dLdb = OUTPUT_NULLIFIED(4); // gradient vs. biases [3*nOut] + + const int time = x->sizeAt(0); + const int bS = x->sizeAt(1); + const int nIn = x->sizeAt(2); + const int nOut = hI->sizeAt(1); + + const std::vector h0CorrectShape = {bS, nOut}; + const std::vector wxCorrectShape = {nIn, 3*nOut}; + const std::vector whCorrectShape = {nOut, 3*nOut}; + const std::vector bCorrectShape = {3*nOut}; + const std::vector hCorrectShape = {time, bS, nOut}; + + REQUIRE_TRUE(hI->isSameShape(h0CorrectShape), 0, "GRU_BP operation: wrong shape of previous cell output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(h0CorrectShape).c_str(), ShapeUtils::shapeAsString(hI).c_str()); + REQUIRE_TRUE(Wx->isSameShape(wxCorrectShape), 0, "GRU_BP operation: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wxCorrectShape).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); + REQUIRE_TRUE(Wh->isSameShape(whCorrectShape), 0, "GRU_BP operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(whCorrectShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str()); + REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, "GRU_BP operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str()); + REQUIRE_TRUE(dLdh->isSameShape(hCorrectShape),0, "GRU_BP operation: wrong shape of gradient vs. 
ff output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdh).c_str()); + + helpers::gruTimeLoopBp(block.launchContext(), x, hI, Wx, Wh, b, dLdh, dLdx, dLdhI, dLdWx, dLdWh, dLdb); + + return Status::OK(); +} + +////////////////////////////////////////////////////////////////////////// +DECLARE_TYPES(gru_bp) { + getOpDescriptor() + ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedOutputTypes({ALL_FLOATS}); +} + +////////////////////////////////////////////////////////////////////////// +DECLARE_SHAPE_FN(gru_bp) { + + auto x = INPUT_VARIABLE(0); // input [time, bS, nIn] + auto hI = INPUT_VARIABLE(1); // initial cell output (at time step = 0) [bS, nOut] + + auto Wx = INPUT_VARIABLE(2); // input-to-hidden weights, [nIn, 3*nOut] + auto Wh = INPUT_VARIABLE(3); // hidden-to-hidden weights, [nOut, 3*nOut] + auto b = INPUT_VARIABLE(4); // biases, [3*nOut] + + auto dLdh = INPUT_VARIABLE(5); // gradient vs. ff output, [time, bS, nOut] + + const int time = x->sizeAt(0); + const int bS = x->sizeAt(1); + const int nIn = x->sizeAt(2); + const int nOut = hI->sizeAt(1); + + const std::vector h0CorrectShape = {bS, nOut}; + const std::vector wxCorrectShape = {nIn, 3*nOut}; + const std::vector whCorrectShape = {nOut, 3*nOut}; + const std::vector bCorrectShape = {3*nOut}; + const std::vector hCorrectShape = {time, bS, nOut}; + + REQUIRE_TRUE(hI->isSameShape(h0CorrectShape), 0, "GRU_BP operation: wrong shape of previous cell output array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(h0CorrectShape).c_str(), ShapeUtils::shapeAsString(hI).c_str()); + REQUIRE_TRUE(Wx->isSameShape(wxCorrectShape), 0, "GRU_BP operation: wrong shape of input-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(wxCorrectShape).c_str(), ShapeUtils::shapeAsString(Wx).c_str()); + REQUIRE_TRUE(Wh->isSameShape(whCorrectShape), 0, "GRU_BP operation: wrong shape of hidden-to-hidden weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(whCorrectShape).c_str(), ShapeUtils::shapeAsString(Wh).c_str()); + REQUIRE_TRUE(b->isSameShape(bCorrectShape), 0, "GRU_BP operation: wrong shape of biases array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(bCorrectShape).c_str(), ShapeUtils::shapeAsString(b).c_str()); + REQUIRE_TRUE(dLdh->isSameShape(hCorrectShape),0, "GRU_BP operation: wrong shape of gradient vs. 
ff output, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdh).c_str()); + + Nd4jLong* dLdxShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(dLdh->dataType(), x->getShapeInfo()); + Nd4jLong* dLdhIShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(dLdh->dataType(), hI->getShapeInfo()); + Nd4jLong* dLdWxShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(dLdh->dataType(), Wx->getShapeInfo()); + Nd4jLong* dLdWhShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(dLdh->dataType(), Wh->getShapeInfo()); + Nd4jLong* dLdbShapeInfo = ConstantShapeHelper::getInstance()->createShapeInfo(dLdh->dataType(), b->getShapeInfo()); + + return SHAPELIST(dLdxShapeInfo, dLdhIShapeInfo, dLdWxShapeInfo, dLdWhShapeInfo, dLdbShapeInfo); +} + } } diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/gruCell.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/gruCell.cpp index 204a1ca63..037f09736 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/gruCell.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/gruCell.cpp @@ -161,7 +161,7 @@ CUSTOM_OP_IMPL(gruCell_bp, 10, 6, false, 0, 0) { REQUIRE_TRUE(dLdc->isSameShape(hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of dLdc array (gradient wrt cell state), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdc).c_str()); REQUIRE_TRUE(dLdh->isSameShape(hiCorrectShape), 0, "GRU_CELL_BP op: wrong shape of dLdh array (gradient wrt current cell output), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(hiCorrectShape).c_str(), ShapeUtils::shapeAsString(dLdh).c_str()); - helpers::gruCellBP(block.launchContext(), x, hi, W, Wc, b, bc, dLdr, dLdu, dLdc, dLdh, dLdx, dLdhi, dLdW, dLdWc, dLdb, dLdbc); + helpers::gruCellBp(block.launchContext(), x, hi, W, Wc, b, bc, dLdr, dLdu, dLdc, dLdh, dLdx, dLdhi, dLdW, dLdWc, dLdb, dLdbc); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayer.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayer.cpp index 8637fe990..871291165 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayer.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayer.cpp @@ -727,12 +727,10 @@ CUSTOM_OP_IMPL(lstmLayer_bp, 4, 1, false, 1, 5) { dLdcLBwd = new NDArray((*dLdcL)({1,2, 0,0, 0,0})); } - // FIXME looks like sum (directionMode == 2) is impossible for backprop if(dLdh) { if(directionMode == 2) { // sum - REQUIRE_TRUE(false, 0, "LSTM_LAYER_BP operation: mode for bidirectional sum and dLdh being present has no sense for backpropagation !"); - // dLdhFwd = dLdh; - // dLdhBwd = new NDArray(dLdh->ordering(), dLdh->getShapeAsVector(), dLdh->dataType(), dLdh->getContext()); // automatically nullifies content + dLdhFwd = dLdh; + dLdhBwd = dLdh; } else if(directionMode == 3) { // concat dLdhFwd = new NDArray(dataFormat <= 1 ? 
(*dLdh)({0,0, 0,0, 0,nOut}) : (*dLdh)({0,0, 0,nOut, 0,0})); @@ -744,21 +742,20 @@ CUSTOM_OP_IMPL(lstmLayer_bp, 4, 1, false, 1, 5) { } } + NDArray dLdxBwd = dLdx->ulike(); - + // FIXME - following two calls are independent and may run in different streams helpers::lstmLayerTimeLoopBp(x, &WxFwd, &WrFwd, bFwd, seqLen, hIFwd, cIFwd, WpFwd, dLdhFwd, dLdhLFwd, dLdcLFwd, params, true, dLdx, &dLdWxFwd, &dLdWrFwd, dLdbFwd, dLdhIFwd, dLdcIFwd, dLdWpFwd); - NDArray dLdxBwd = dLdx->ulike(); helpers::lstmLayerTimeLoopBp(x, &WxBwd, &WrBwd, bBwd, seqLen, hIBwd, cIBwd, WpBwd, dLdhBwd, dLdhLBwd, dLdcLBwd, params, false, &dLdxBwd, &dLdWxBwd, &dLdWrBwd, dLdbBwd, dLdhIBwd, dLdcIBwd, dLdWpBwd); *dLdx += dLdxBwd; delete WpFwd; delete WpBwd; delete bFwd; delete bBwd; delete hIFwd; delete hIBwd; delete cIFwd; delete cIBwd; - delete dLdhBwd; delete dLdhLFwd; delete dLdhLBwd; delete dLdcLFwd; delete dLdcLBwd; + delete dLdhLFwd; delete dLdhLBwd; delete dLdcLFwd; delete dLdcLBwd; delete dLdWpFwd; delete dLdWpBwd; delete dLdbFwd; delete dLdbBwd; delete dLdhIFwd; delete dLdhIBwd; delete dLdcIFwd; delete dLdcIBwd; - if(dLdhFwd != dLdh) - delete dLdhFwd; + if(!(dLdh && directionMode == 2)) { delete dLdhFwd; delete dLdhBwd; } } return Status::OK(); diff --git a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayerCell.cpp b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayerCell.cpp index 46f32e399..4f24219bd 100644 --- a/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayerCell.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/recurrent/lstmLayerCell.cpp @@ -293,7 +293,7 @@ CUSTOM_OP_IMPL(lstmLayerCellBp, 7, 5, false, 1, 3) { helpers::lstmLayerCell(x,Wx, Wr, b, hI, cI, Wp, params, &z, &a, &h, &c); - helpers::lstmLayerCellBp(x, Wx, Wr, b, hI, cI, Wp, dLdh, nullptr, &z, &a, &c, params, dLdx, dLdWx, dLdWr, dLdhI, dLdcI, dLdb, dLdWp); + helpers::lstmLayerCellBp(x, Wx, Wr, b, hI, cI, Wp, dLdh, nullptr, nullptr, &z, &a, &c, params, dLdx, dLdWx, dLdWr, dLdhI, dLdcI, dLdb, dLdWp); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/headers/recurrent.h b/libnd4j/include/ops/declarable/headers/recurrent.h index dd219867f..aeeae24c4 100644 --- a/libnd4j/include/ops/declarable/headers/recurrent.h +++ b/libnd4j/include/ops/declarable/headers/recurrent.h @@ -345,6 +345,10 @@ namespace ops { DECLARE_CUSTOM_OP(gru, 5, 1, false, 0, 0); #endif + #if NOT_EXCLUDED(OP_gru) + DECLARE_CUSTOM_OP(gru_bp, 6, 5, false, 0, 0); + #endif + ////////////////////////////////////////////////////////////////////////// /** * Implementation of operation "static RNN time sequences" with peep hole connections: diff --git a/libnd4j/include/ops/declarable/helpers/cpu/gru.cpp b/libnd4j/include/ops/declarable/helpers/cpu/gru.cpp deleted file mode 100644 index b00036b81..000000000 --- a/libnd4j/include/ops/declarable/helpers/cpu/gru.cpp +++ /dev/null @@ -1,421 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author Yurii Shyrma (iuriish@yahoo.com), created on 15.02.2018, Alex Black -// - -// implementation of gated Recurrent Unit cell -// (cf. https://arxiv.org/abs/1406.1078). -// Kyunghyun Cho, Bart van Merrienboer, Caglar Gulcehre, Dzmitry Bahdanau, Fethi Bougares, Holger Schwenk, Yoshua Bengio -// "Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation" - - -#include -#include -#include -#include - -namespace sd { -namespace ops { -namespace helpers { - - -////////////////////////////////////////////////////////////////////////// -void gruCell(sd::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* W, const NDArray* Wc, - const NDArray* b, const NDArray* bc, - NDArray* r, NDArray* u, NDArray* c, NDArray* h) { - - //Inputs: - // x input [bS, iS], iS - input size - // hLast previous cell output [bS, nU], that is at previous time step t-1, nU - number of units - // W RU weights - [iS+nU, 2*nU] - reset and update gates - // Wc C weights - [iS+nU, nU] - cell gate - // b r and u biases, [2*nU] - reset and update gates - // bc c biases, [nU] - cell gate - - //Outputs: - // r Reset gate output [bS, nU] - // u Update gate output [bS, nU] - // c Cell gate output [bS, nU] - // h current cell output [bS, nU] - - /***************************************************************************************/ - /************************ THIS IS NOT OPTIMAZED CODE ***********************************/ - /** however it is more math-friendly and convenient for backprop formulas derivation) **/ - - const int bS = x->sizeAt(0); - const int iS = x->sizeAt(1); - const int nU = hLast->sizeAt(1); - - NDArray Wrx = (*W)({0,iS, 0,nU}); // [iS, nU] - NDArray Wux = (*W)({0,iS, nU,2*nU}); // [iS, nU] - NDArray Wrh = (*W)({iS,iS+nU, 0,nU}); // [nU, nU] - NDArray Wuh = (*W)({iS,iS+nU, nU,2*nU}); // [nU, nU] - - NDArray Wcx = (*Wc)({0,iS, 0,0}); // reset cell weights [iS, nU] - NDArray Wch = (*Wc)({iS,iS+nU, 0,0}); // updates cell weights [nU, nU] - - NDArray br = (*b)({0, nU}); // [nU] - NDArray bu = (*b)({nU, 2*nU}); // [nU] - - // × means matrix multipication - // * means element-wise product or so called Hadamard product - - // reset gate - r->assign(mmul(*x, Wrx) + mmul(*hLast, Wrh) + br); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - r->applyTransform(transform::Sigmoid, *r); - - // update gate - u->assign(mmul(*x, Wux) + mmul(*hLast, Wuh) + bu); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - u->applyTransform(transform::Sigmoid, *u); - - // cell gate c = activation(x × Wcx + (r * hlast) × Wch + bc) - c->assign(mmul(*x, Wcx) + mmul(*r * *hLast, Wch) + *bc); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - c->applyTransform(transform::Tanh, *c); - - NDArray temp = 1.f - *c * *c; - - // cell output - h->assign(*u * *hLast + (1.f - *u) * *c); - - - /***************************************************************************************/ - /*************** THIS IS MORE OPTIMAZED CODE (should think about concat) ***************/ - /***************************************************************************************/ -/* - //Concat inputs: x + hLast : [bs, iS + nU] - NDArray xhConcat(x->ordering(), {bS, iS + nU}, x->dataType(), context); // concat([bs, iS], [bs, nU]) -> [bs, iS + nU] - helpers::concat(context, {const_cast(x), const_cast(hLast)}, xhConcat, {1}); - - //mmul 
for reset and update gates: (x × weight_ux + hLast × weight_xr + b_u) - auto m = mmul(xhConcat, *W) + *b ; // [bs, iS+nU] * [iS+nU, 2*nU] = [bs, 2*nU] - // m += *bru; - - m.applyTransform(transform::Sigmoid); //sigmoid(rz) and sigmoid(uz) - - r->assign(m({0,0, 0, nU})); - u->assign(m({0,0, nU, 2*nU})); - - // hLast = hLast * r - xhConcat({0,0, iS, iS+nU}) *= *r; - - //c = tanh(x × weight_cx + (hLast * r) × weight_cr + b_c) - MmulHelper::mmul(&xhConcat, Wc, c, 1.0, 0.0); //c = 1.0 * xhConcat * Wc + 0.0 * c - *c += *bc; - c->applyTransform(transform::Tanh); - - //Output: h = (1-u).*c + u .* hPrev - //auto hResult = (*u) * (*hLast) + (1.0f - *u) * (*c); const_cast(h)->assign(&hResult); - u->applyPairwiseTransform(pairwise::Multiply, hLast, h, nullptr); //h = u * hLast - auto temp = (1.0f - *u); - temp *= (*c); - (*h) += temp; -*/ -} - -////////////////////////////////////////////////////////////////////////// -void gruTimeLoop(sd::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h) { - - // x input [time, bS, iS] - // hLast initial cell output (at time step = 0) [bS, nU] - // Wx input-to-hidden weights, [iS, 3*nU] - // Wh hidden-to-hidden weights, [nU, 3*nU] - // b biases, [3*nU] - - // h is cell outputs at each time step [time, bS, nU] - - const int time = x->sizeAt(0); - - NDArray ht_1(*hLast); - - // loop through time steps - for (int t = 0; t < time; ++t) { - - auto xt = (*x)({t,t+1, 0,0, 0,0}); - auto ht = (*h)({t,t+1, 0,0, 0,0}); - - // helpers::gruCell(&xt, &ht_1, Wx, Wh, b, &ht); - // ht_1.assign(ht); - } -} - -////////////////////////////////////////////////////////////////////////// -void gruCellBP(sd::LaunchContext* context, - const NDArray* x, const NDArray* hLast, - const NDArray* W, const NDArray* Wc, const NDArray* b, const NDArray* bc, - const NDArray* dLdr, const NDArray* dLdu, const NDArray* dLdc, const NDArray* dLdh, - NDArray* dLdx, NDArray* dLdhLast, - NDArray* dLdW, NDArray* dLdWc, - NDArray* dLdb, NDArray* dLdbc) { - - //Inputs: - // x input [bS, iS] - // hLast previous cell output [bS, nU], that is at previous time step t-1 - // W weights - [iS+nU, 2*nU] - reset and update gates - // Wc C weights - [iS+nU, nU] - cell gate - // b r and u biases, [2*nU] - reset and update gates - // bc c biases, [nU] - cell gate - // dLdr gradient wrt reset gate, [bS, nU] - // dLdu gradient wrt update gate, [bS, nU] - // dLdc gradient wrt cell state, [bS, nU] - // dLdh gradient wrt current cell output, [bS, nU] - - //Outputs: - // dLdx gradient wrt x, [bS, iS], - // dLdhLast gradient wrt hLast, [bS, nU] - // dLdW gradient wrt W, [iS+nU, 2*nU] - // dLdWc gradient wrt Wc, [iS+nU, nU] - // dLdb gradient wrt bru [2*nU] - // dLdbc gradient wrt bc [nU] - - // * means element-wise product or so called Hadamard product - // × means matrix multiplication - - /************************************************************************************************/ - /******************************* THIS IS NOT OPTIMAZED CODE *************************************/ - /*** aim is to have math-readable code in order to keep track of backprop formulas derivation ***/ - - const int bS = x->sizeAt(0); - const int iS = x->sizeAt(1); - const int nU = hLast->sizeAt(1); - - NDArray xT = x->transpose(); // [iS, bS] - NDArray hLastT = hLast->transpose(); // [nU, bS] - - NDArray Wrx = (*W)({0,iS, 0,nU}); // [iS, nU] - NDArray Wux = (*W)({0,iS, nU,2*nU}); // [iS, nU] - NDArray Wrh = (*W)({iS,iS+nU, 0,nU}); // [nU, nU] - NDArray Wuh = 
(*W)({iS,iS+nU, nU,2*nU}); // [nU, nU] - - NDArray Wcx = (*Wc)({0,iS, 0,0}); // reset cell weights [iS, nU] - NDArray Wch = (*Wc)({iS,iS+nU, 0,0}); // updates cell weights [nU, nU] - - NDArray br = (*b)({0, nU}); // [nU] - NDArray bu = (*b)({nU, 2*nU}); // [nU] - - NDArray WrxT = Wrx.transpose(); // [nU, iS] - NDArray WuxT = Wux.transpose(); // [nU, iS] - NDArray WrhT = Wrh.transpose(); // [nU, nU] - NDArray WuhT = Wuh.transpose(); // [nU, nU] - - NDArray WcxT = Wcx.transpose(); // [nU, iS] - NDArray WchT = Wch.transpose(); // [nU, nU] - - NDArray dLdWrx = (*dLdW)({0,iS, 0,nU}); // [iS, nU] - NDArray dLdWux = (*dLdW)({0,iS, nU,2*nU}); // [iS, nU] - NDArray dLdWrh = (*dLdW)({iS,iS+nU, 0,nU}); // [nU, nU] - NDArray dLdWuh = (*dLdW)({iS,iS+nU, nU,2*nU}); // [nU, nU] - - NDArray dLdWcx = (*dLdWc)({0,iS, 0,0}); // [iS, nU] - NDArray dLdWch = (*dLdWc)({iS,iS+nU, 0,0}); // [nU, nU] - - NDArray dLdbr = (*dLdb)({0, nU}); // [nU] - NDArray dLdbu = (*dLdb)({nU, 2*nU}); // [nU] - - - // ***** feed forward step ***** // - - // reset gate - NDArray r = mmul(*x, Wrx) + mmul(*hLast, Wrh) + br; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - r.applyTransform(transform::Sigmoid, r); - - // update gate - NDArray u = mmul(*x, Wux) + mmul(*hLast, Wuh) + bu; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - u.applyTransform(transform::Sigmoid, u); - - // cell gate c = activation(x×Wcx + (r*hlast)×Wcu + bc) - NDArray c = mmul(*x, Wcx) + mmul(r * *hLast, Wch) + *bc; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - c.applyTransform(transform::Tanh, c); - - // h = (1 - u) * c + u * hPrev - - - // ***** back prop step ***** // - - // notations: - // Zr = x × Wrx + hLast × Wrh + br - // Zu = x × Wux + hLast × Wuh + bu - // Sr = sigmoid(Zr) - // Su = sigmoid(Zu) - // Zc = x × Wcx + (r * hlast) × Wch + bc - - - // dLdx = dLdh * dhdx = dLdh * (dhdu * dudx + dhdc * dcdx) = (dLdh * dhdu) * dudx + (dLdh * dhdc) * dcdx = dLdu * dudx + dLdc * dcdx - // = dLdx_u + dLdx_c - // dLdx_u = dLdu * dudx = dLdu * dudZu * dZudx = |dZudx = ... × WuxT| = (dLdu * dudZu) × WuxT - // dLdx_c = dLdc * dcdx = dLdc * dcdZc * (dZcdx + dZcdr * drdx) = dLdc * dcdZc * dZcdx + dLdc * dcdZc * dZcdr * drdx = dLdx_c0 + dLdx_c1 - // dLdx_c0 = dLdc * dcdZc * dZcdx = |dZcdx = ... × WcxT| = (dLdc * dcdZc) × WcxT - // dZcdr = (... * hLast) × WchT - // dLdc * dcdZc * dZcdr = dLdr = (dLdc * dcdZc * hLast) × WchT - // drdx = drdZr * dZrdx - // dZrdx = ... × WrxT - // dLdx_c1 = dLdc * dcdZc * dZcdr * drdx = dLdr * drdx = (dLdr * drdZr) × WrxT - // finally dLdx = dLdx_u + dLdx_c0 + dLdx_c1 = (dLdu * dudZu) × WuxT + (dLdc * dcdZc) × WcxT + (dLdr * drdZr) × WrxT - - - // dLdhLast = dLdh * (dhdhLast + dhdu * dudhLast + dhdc * dcdhLast) = dLdh * dhdhLast + dLdu * dudhLast + dLdc * dcdhLast - // = dLdhLast_h + dLdhLast_u + dLdhLast_c - // dLdhLast_h = dLdh * dhdhLas = dLdh * u - // dLdhLast_u = dLdu * dudhLast = |dudhLast = dudZu * dZudhLast , dZudhLast = ... × WuhT| = (dLdu * dudZu) × WuhT - // dLdhLast_c = dLdc * dcdhLast = dLdc * (dcdZc * dZcdhLast + dcdZc * dZcdr * drdhLast) = - // = dLdc * dcdZc * dZcdhLast + dLdc * dcdZc * dZcdr * drdhLast = - // = dLdc * dcdZc * dZcdhLast + dLdr * drdhLast = dLdhLast_c0 + dLdhLast_c1 - // dLdhLast_c0 = dLdc * dcdZc * dZcdhLast = |dZcdhLast = (... * r) × WchT| = (dLdc * dcdZc * r) × WchT - // dLdhLast_c1 = dLdr * drdhLast = |drdhLast = drdZr * dZrdhLast, dZrdhLast = ... 
× WrhT| = (dLdr * drdZr) × WrhT - // finally dLdhLast = dLdhLast_h + dLdhLast_u + dLdhLast_c0 + dLdhLast_c1 = - // = dLdh * u + (dLdu * dudZu) × WuhT + (dLdc * dcdZc * r) × WchT + (dLdr * drdZr) × WrhT - - - // dLdWrx = dLdh * dhdWrx = (dLdh * dhdc) * dcdWrx = dLdc * dcdZc * dZcdWrx = dLdc * dcdZc * dZcdr * drdWrx = - // = dLdc * dcdZc * dZcdr * drdZr * dZrdWrx = dLdr * drdZr * dZrdWrx - // dZrdWrx = xT × ... - // finally dLdWrx = xT × (dLdr * drdZr) - - - // dLdWrh = dLdh * dhdWrh = (dLdh * dhdc) * dcdWrh = dLdc * dcdZc * dZcdWrh = dLdc * dcdZc * dZcdr * drdWrh = - // = dLdc * dcdZc * dZcdr * drdZr * dZrdWrh = dLdr * drdZr * dZrdWrh - // dZrdWrh = hLastT × ... - // finally dLdWrh = hLastT × (dLdr * drdZr) - - - // dLdWux = dLdh * dhdWux = (dLdh * dhdu) * dudWux = dLdu * dudZu * dZudWux - // dZudWux = xT × ... - // dLdu * dudZu * dZudWux = xT × (dLdu * dudZu) - - - // dLdWuh = dLdh * dhdWuh = (dLdh * dhdu) * dudWuh = dLdh * dhdu * dudZu * dZudWuh = dLdu * dudZu * dZudWuh - // dZudWuh = hLastT × ... - // finally dLdWuh = hLastT × (dLdu * dudZu) - - - // dLdWcx = dLdh * dhdWcx = dLdh * dhdc * dcdWcx = (dLdh * dhdc) * dcdZc * dZcdWcx = dLdc * dcdZc * dZcdWcx - // dZcdWcx = xT × ... - // finally dLdWcx = xT × (dLdc * dcdZc) - - - // dLdWch = dLdh * dhdWch = dLdh * dhdc * dcdWch = (dLdh * dhdc) * dcdZc * dZcdWch = dLdc * dcdZc * dZcdWch - // dZcdWch = (r*hLast)^T × ... - // finally dLdWch = (r*hLast)^T × (dLdc * dcdZc) - - - // dLdbr = dLdh * dhdbr = (dLdh * dhdc) * dcdbr = dLdc * dcdbr = dLdc * dcdZc * dZcdbr = dLdc * dcdZc * dZcdr * drdbr = - // = dLdr * drdZr * dZrdbr - // dZrdbr = 1 - // finally dLdbr = dLdr * drdZr - - - // dLdbu = dLdh * dhdbu = (dLdh * dhdu) * dudbu = dLdu * dudZu * dZudbu - // dZudbu = 1 - // finally dLdbu = dLdu * dudZu - - - // dLdbc = dLdh * dhdbc = (dLdh * dhdc) * dcdbc = dLdc * dcdZc * dZcdbc - // dZcdbc = 1 - // finally dLdbc = dLdc * dcdZc - - NDArray dhdc = 1.f - u; // [bS, nU] - NDArray dhdu = *hLast - c; // [bS, nU] - NDArray dudZu = u * dhdc; // [bS, nU] - NDArray drdZr = r * (1.f - r); // [bS, nU] - NDArray dcdZc = 1.f - c * c; // [bS, nU] - NDArray dLdZc = *dLdc * dcdZc; // [bS, nU] - NDArray dLdZu = *dLdu * dudZu; // [bS, nU] - NDArray dLdZr = *dLdr * drdZr; // [bS, nU] - - // NDArray dLdc = *dLdh * dhdc; // [bS, nU] - // NDArray dLdu = *dLdh * dhdu; // [bS, nU] - // NDArray dLdr = mmul(dLdc * dcdZc * *hLast, WchT); // [bS, nU] - - dLdx->assign(mmul(dLdZu, WuxT) + mmul(dLdZc, WcxT) + mmul(dLdZr, WrxT)); // [bS, iS] - - dLdhLast->assign(*dLdh * u + mmul(dLdZu, WuhT) + mmul(dLdZc * r, WchT) + mmul(dLdZr, WrhT)); // [bS, nU] - - dLdWrx.assign(mmul(xT, dLdZr)); // [iS, bS] × [bS, nU] = [iS, nU] - dLdWrh.assign(mmul(hLastT, dLdZr)); // [nU, bS] × [bS, nU] = [nU, nU] - dLdWux.assign(mmul(xT, dLdZu)); // [iS, bS] × [bS, nU] = [iS, nU] - dLdWuh.assign(mmul(hLastT, dLdZu)); // [nU, bS] × [bS, nU] = [nU, nU] - - dLdWcx.assign(mmul(xT, dLdZc)); // [iS, bS] × [bS, nU] = [iS, nU] - dLdWch.assign(mmul((r * *hLast).transpose(), dLdZc)); // [nU, bS] × [bS, nU] = [nU, nU] - - dLdbr.assign(dLdZr.reduceAlongDimension(reduce::Sum, {0})); // [nU] - dLdbu.assign(dLdZu.reduceAlongDimension(reduce::Sum, {0})); // [nU] - - dLdbc->assign(dLdZc.reduceAlongDimension(reduce::Sum, {0})); // [nU] -} - -// ////////////////////////////////////////////////////////////////////////// -// FIXME - gruTimeLoopBP is not correct -// template -// void gruTimeLoopBP(const std::vector*>& inArrs, const std::vector*>& outArrs) { - -// NDArray* x = inArrs[0]; // input [time, bS, iS] -// NDArray* hi = 
inArrs[1]; // previous/initial cell output [bS, nU], that is at previous time step t-1 -// NDArray* Wx = inArrs[2]; // input-to-hidden weights, [iS, 3*nU] -// NDArray* Wh = inArrs[3]; // hidden-to-hidden weights, [nU, 3*nU] -// NDArray* b = inArrs[4]; // biases, [3*nU] -// NDArray* dLdh = inArrs[5]; // gradient wrt output, [time, bS, nU], that is epsilon_next - -// NDArray* dLdx = outArrs[0]; // gradient wrt x, [time, bS, iS], that is epsilon -// NDArray* dLdhi = outArrs[1]; // gradient wrt hi, [bS, nU] -// NDArray* dLdWx = outArrs[2]; // gradient wrt Wx, [iS, 3*nU] -// NDArray* dLdWh = outArrs[3]; // gradient wrt Wh, [nU, 3*nU] -// NDArray* dLdb = outArrs[4]; // gradient wrt b, [3*nU] - -// const Nd4jLong time = x->sizeAt(0); -// const Nd4jLong bS = x->sizeAt(1); -// const Nd4jLong iS = x->sizeAt(2); -// const Nd4jLong nU = hi->sizeAt(1); - -// NDArray h(hi->ordering(), {time, bS, nU}); // feed forward output - -// // first step, time = 0, feed forward -// NDArray x0 = (*x)({{0,1}, {}, {}}); -// NDArray hLast = h({{0,1}, {}, {}}); -// helpers::gruCell({&x0, hi, Wx, Wh, b}, &hLast); - -// // first step, time = 0, back prop -// NDArray dLdx0 = (*dLdx)({{0,1}, {}, {}}); -// NDArray dLdhLast = (*dLdh)({{0,1}, {}, {}}); -// helpers::gruCellBP({&x0, hi, Wx, Wh, b, &dLdhLast, nullptr, nullptr, nullptr}, {&dLdx0, dLdhi, dLdWx, dLdWh, dLdb}); - -// // loop through the rest time steps -// for (Nd4jLong t = time-1; t > 0; --t) { -// for (Nd4jLong t = 1; t < time; ++t) { - -// NDArray xt = (*x)({{t,t+1}, {}, {}}); -// NDArray ht = h({{t,t+1}, {}, {}}); -// NDArray ht_1 = h({{t-1,t}, {}, {}}); -// NDArray dLdxt = (*dLdx)({{t,t+1}, {}, {}}); -// NDArray dLdht = (*dLdh)({{t,t+1}, {}, {}}); - -// NDArray dLdWxt_1 = dLdWx; -// NDArray dLdWht_1 = dLdWh; -// NDArray dLdbt_1 = dLdb; - -// // feed forward, calculation of ht -// helpers::gruCell({&xt, &ht_1, Wx, Wh, b}, &ht); - -// // back prop -// helpers::gruCellBP({&xt, &ht_1, Wx, Wh, b, &dLdht, &dLdWxt_1, &dLdWht_1, &dLdbt_1}, {&dLdxt, nullptr, dLdWx, dLdWh, dLdb}); -// } -// } - - -} -} -} diff --git a/libnd4j/include/ops/declarable/helpers/cuda/gru.cu b/libnd4j/include/ops/declarable/helpers/cuda/gru.cu deleted file mode 100644 index bd4e878e3..000000000 --- a/libnd4j/include/ops/declarable/helpers/cuda/gru.cu +++ /dev/null @@ -1,365 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author Yurii Shyrma (iuriish@yahoo.com), created on 15.02.2018 -// - -// implementation of gated Recurrent Unit cell -// (cf. https://arxiv.org/abs/1406.1078). 
-// Kyunghyun Cho, Bart van Merrienboer, Caglar Gulcehre, Dzmitry Bahdanau, Fethi Bougares, Holger Schwenk, Yoshua Bengio -// "Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation" - - -#include -#include -#include -#include - -namespace sd { -namespace ops { -namespace helpers { - - -////////////////////////////////////////////////////////////////////////// -void gruCell(sd::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* W, const NDArray* Wc, - const NDArray* b, const NDArray* bc, - NDArray* r, NDArray* u, NDArray* c, NDArray* h) { - - //Inputs: - // x input [bS, iS], iS - input size - // hLast previous cell output [bS, nU], that is at previous time step t-1, nU - number of units - // W RU weights - [iS+nU, 2*nU] - reset and update gates - // Wc C weights - [iS+nU, nU] - cell gate - // b r and u biases, [2*nU] - reset and update gates - // bc c biases, [nU] - cell gate - - //Outputs: - // r Reset gate output [bS, nU] - // u Update gate output [bS, nU] - // c Cell gate output [bS, nU] - // h current cell output [bS, nU] - - /***************************************************************************************/ - /************************ THIS IS NOT OPTIMAZED CODE ***********************************/ - /** however it is more math-friendly and convenient for backprop formulas derivation) **/ - - const int bS = x->sizeAt(0); - const int iS = x->sizeAt(1); - const int nU = hLast->sizeAt(1); - - NDArray Wrx = (*W)({0,iS, 0,nU}); // [iS, nU] - NDArray Wux = (*W)({0,iS, nU,2*nU}); // [iS, nU] - NDArray Wrh = (*W)({iS,iS+nU, 0,nU}); // [nU, nU] - NDArray Wuh = (*W)({iS,iS+nU, nU,2*nU}); // [nU, nU] - - NDArray Wcx = (*Wc)({0,iS, 0,0}); // reset cell weights [iS, nU] - NDArray Wch = (*Wc)({iS,iS+nU, 0,0}); // updates cell weights [nU, nU] - - NDArray br = (*b)({0, nU}); // [nU] - NDArray bu = (*b)({nU, 2*nU}); // [nU] - - // × means matrix multipication - // * means element-wise product or so called Hadamard product - - // reset gate - r->assign(mmul(*x, Wrx) + mmul(*hLast, Wrh) + br); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - r->applyTransform(transform::Sigmoid, *r); - - // update gate - u->assign(mmul(*x, Wux) + mmul(*hLast, Wuh) + bu); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - u->applyTransform(transform::Sigmoid, *u); - - // cell gate c = activation(x × Wcx + (r * hlast) × Wch + bc) - c->assign(mmul(*x, Wcx) + mmul(*r * *hLast, Wch) + *bc); // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - c->applyTransform(transform::Tanh, *c); - - NDArray temp = 1.f - *c * *c; - - // cell output - h->assign(*u * *hLast + (1.f - *u) * *c); - - - /***************************************************************************************/ - /*************** THIS IS MORE OPTIMAZED CODE (should think about concat) ***************/ - /***************************************************************************************/ -/* - //Concat inputs: x + hLast : [bs, iS + nU] - NDArray xhConcat(x->ordering(), {bS, iS + nU}, x->dataType(), context); // concat([bs, iS], [bs, nU]) -> [bs, iS + nU] - helpers::concat(context, {const_cast(x), const_cast(hLast)}, xhConcat, {1}); - - //mmul for reset and update gates: (x × weight_ux + hLast × weight_xr + b_u) - auto m = mmul(xhConcat, *W) + *b ; // [bs, iS+nU] * [iS+nU, 2*nU] = [bs, 2*nU] - // m += *bru; - - m.applyTransform(transform::Sigmoid); //sigmoid(rz) and sigmoid(uz) - - r->assign(m({0,0, 0, nU})); - u->assign(m({0,0, nU, 2*nU})); - - // 
hLast = hLast * r - xhConcat({0,0, iS, iS+nU}) *= *r; - - //c = tanh(x × weight_cx + (hLast * r) × weight_cr + b_c) - MmulHelper::mmul(&xhConcat, Wc, c, 1.0, 0.0); //c = 1.0 * xhConcat * Wc + 0.0 * c - *c += *bc; - c->applyTransform(transform::Tanh); - - //Output: h = (1-u).*c + u .* hPrev - //auto hResult = (*u) * (*hLast) + (1.0f - *u) * (*c); const_cast(h)->assign(&hResult); - u->applyPairwiseTransform(pairwise::Multiply, hLast, h, nullptr); //h = u * hLast - auto temp = (1.0f - *u); - temp *= (*c); - (*h) += temp; -*/ -} - -////////////////////////////////////////////////////////////////////////// -void gruTimeLoop(sd::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h) { - - // x input [time, bS, iS] - // hLast initial cell output (at time step = 0) [bS, nU] - // Wx input-to-hidden weights, [iS, 3*nU] - // Wh hidden-to-hidden weights, [nU, 3*nU] - // b biases, [3*nU] - - // h is cell outputs at each time step [time, bS, nU] - - const int time = x->sizeAt(0); - - NDArray ht_1(*hLast); - - // loop through time steps - for (int t = 0; t < time; ++t) { - - auto xt = (*x)({t,t+1, 0,0, 0,0}); - auto ht = (*h)({t,t+1, 0,0, 0,0}); - - // helpers::gruCell(&xt, &ht_1, Wx, Wh, b, &ht); - // ht_1.assign(ht); - } -} - -////////////////////////////////////////////////////////////////////////// -void gruCellBP(sd::LaunchContext* context, - const NDArray* x, const NDArray* hLast, - const NDArray* W, const NDArray* Wc, const NDArray* b, const NDArray* bc, - const NDArray* dLdr, const NDArray* dLdu, const NDArray* dLdc, const NDArray* dLdh, - NDArray* dLdx, NDArray* dLdhLast, - NDArray* dLdW, NDArray* dLdWc, - NDArray* dLdb, NDArray* dLdbc) { - - //Inputs: - // x input [bS, iS] - // hLast previous cell output [bS, nU], that is at previous time step t-1 - // W weights - [iS+nU, 2*nU] - reset and update gates - // Wc C weights - [iS+nU, nU] - cell gate - // b r and u biases, [2*nU] - reset and update gates - // bc c biases, [nU] - cell gate - // dLdr gradient wrt reset gate, [bS, nU] - // dLdu gradient wrt update gate, [bS, nU] - // dLdc gradient wrt cell state, [bS, nU] - // dLdh gradient wrt current cell output, [bS, nU] - - //Outputs: - // dLdx gradient wrt x, [bS, iS], - // dLdhLast gradient wrt hLast, [bS, nU] - // dLdW gradient wrt W, [iS+nU, 2*nU] - // dLdWc gradient wrt Wc, [iS+nU, nU] - // dLdb gradient wrt bru [2*nU] - // dLdbc gradient wrt bc [nU] - - // * means element-wise product or so called Hadamard product - // × means matrix multiplication - - /************************************************************************************************/ - /******************************* THIS IS NOT OPTIMAZED CODE *************************************/ - /*** aim is to have math-readable code in order to keep track of backprop formulas derivation ***/ - - const int bS = x->sizeAt(0); - const int iS = x->sizeAt(1); - const int nU = hLast->sizeAt(1); - - NDArray xT = x->transpose(); // [iS, bS] - NDArray hLastT = hLast->transpose(); // [nU, bS] - - NDArray Wrx = (*W)({0,iS, 0,nU}); // [iS, nU] - NDArray Wux = (*W)({0,iS, nU,2*nU}); // [iS, nU] - NDArray Wrh = (*W)({iS,iS+nU, 0,nU}); // [nU, nU] - NDArray Wuh = (*W)({iS,iS+nU, nU,2*nU}); // [nU, nU] - - NDArray Wcx = (*Wc)({0,iS, 0,0}); // reset cell weights [iS, nU] - NDArray Wch = (*Wc)({iS,iS+nU, 0,0}); // updates cell weights [nU, nU] - - NDArray br = (*b)({0, nU}); // [nU] - NDArray bu = (*b)({nU, 2*nU}); // [nU] - - NDArray WrxT = Wrx.transpose(); // [nU, iS] - 
NDArray WuxT = Wux.transpose(); // [nU, iS] - NDArray WrhT = Wrh.transpose(); // [nU, nU] - NDArray WuhT = Wuh.transpose(); // [nU, nU] - - NDArray WcxT = Wcx.transpose(); // [nU, iS] - NDArray WchT = Wch.transpose(); // [nU, nU] - - NDArray dLdWrx = (*dLdW)({0,iS, 0,nU}); // [iS, nU] - NDArray dLdWux = (*dLdW)({0,iS, nU,2*nU}); // [iS, nU] - NDArray dLdWrh = (*dLdW)({iS,iS+nU, 0,nU}); // [nU, nU] - NDArray dLdWuh = (*dLdW)({iS,iS+nU, nU,2*nU}); // [nU, nU] - - NDArray dLdWcx = (*dLdWc)({0,iS, 0,0}); // [iS, nU] - NDArray dLdWch = (*dLdWc)({iS,iS+nU, 0,0}); // [nU, nU] - - NDArray dLdbr = (*dLdb)({0, nU}); // [nU] - NDArray dLdbu = (*dLdb)({nU, 2*nU}); // [nU] - - - // ***** feed forward step ***** // - - // reset gate - NDArray r = mmul(*x, Wrx) + mmul(*hLast, Wrh) + br; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - r.applyTransform(transform::Sigmoid, r); - - // update gate - NDArray u = mmul(*x, Wux) + mmul(*hLast, Wuh) + bu; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - u.applyTransform(transform::Sigmoid, u); - - // cell gate c = activation(x×Wcx + (r*hlast)×Wcu + bc) - NDArray c = mmul(*x, Wcx) + mmul(r * *hLast, Wch) + *bc; // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU] - c.applyTransform(transform::Tanh, c); - - // h = (1 - u) * c + u * hPrev - - - // ***** back prop step ***** // - - // notations: - // Zr = x × Wrx + hLast × Wrh + br - // Zu = x × Wux + hLast × Wuh + bu - // Sr = sigmoid(Zr) - // Su = sigmoid(Zu) - // Zc = x × Wcx + (r * hlast) × Wch + bc - - - // dLdx = dLdh * dhdx = dLdh * (dhdu * dudx + dhdc * dcdx) = (dLdh * dhdu) * dudx + (dLdh * dhdc) * dcdx = dLdu * dudx + dLdc * dcdx - // = dLdx_u + dLdx_c - // dLdx_u = dLdu * dudx = dLdu * dudZu * dZudx = |dZudx = ... × WuxT| = (dLdu * dudZu) × WuxT - // dLdx_c = dLdc * dcdx = dLdc * dcdZc * (dZcdx + dZcdr * drdx) = dLdc * dcdZc * dZcdx + dLdc * dcdZc * dZcdr * drdx = dLdx_c0 + dLdx_c1 - // dLdx_c0 = dLdc * dcdZc * dZcdx = |dZcdx = ... × WcxT| = (dLdc * dcdZc) × WcxT - // dZcdr = (... * hLast) × WchT - // dLdc * dcdZc * dZcdr = dLdr = (dLdc * dcdZc * hLast) × WchT - // drdx = drdZr * dZrdx - // dZrdx = ... × WrxT - // dLdx_c1 = dLdc * dcdZc * dZcdr * drdx = dLdr * drdx = (dLdr * drdZr) × WrxT - // finally dLdx = dLdx_u + dLdx_c0 + dLdx_c1 = (dLdu * dudZu) × WuxT + (dLdc * dcdZc) × WcxT + (dLdr * drdZr) × WrxT - - - // dLdhLast = dLdh * (dhdhLast + dhdu * dudhLast + dhdc * dcdhLast) = dLdh * dhdhLast + dLdu * dudhLast + dLdc * dcdhLast - // = dLdhLast_h + dLdhLast_u + dLdhLast_c - // dLdhLast_h = dLdh * dhdhLas = dLdh * u - // dLdhLast_u = dLdu * dudhLast = |dudhLast = dudZu * dZudhLast , dZudhLast = ... × WuhT| = (dLdu * dudZu) × WuhT - // dLdhLast_c = dLdc * dcdhLast = dLdc * (dcdZc * dZcdhLast + dcdZc * dZcdr * drdhLast) = - // = dLdc * dcdZc * dZcdhLast + dLdc * dcdZc * dZcdr * drdhLast = - // = dLdc * dcdZc * dZcdhLast + dLdr * drdhLast = dLdhLast_c0 + dLdhLast_c1 - // dLdhLast_c0 = dLdc * dcdZc * dZcdhLast = |dZcdhLast = (... * r) × WchT| = (dLdc * dcdZc * r) × WchT - // dLdhLast_c1 = dLdr * drdhLast = |drdhLast = drdZr * dZrdhLast, dZrdhLast = ... × WrhT| = (dLdr * drdZr) × WrhT - // finally dLdhLast = dLdhLast_h + dLdhLast_u + dLdhLast_c0 + dLdhLast_c1 = - // = dLdh * u + (dLdu * dudZu) × WuhT + (dLdc * dcdZc * r) × WchT + (dLdr * drdZr) × WrhT - - - // dLdWrx = dLdh * dhdWrx = (dLdh * dhdc) * dcdWrx = dLdc * dcdZc * dZcdWrx = dLdc * dcdZc * dZcdr * drdWrx = - // = dLdc * dcdZc * dZcdr * drdZr * dZrdWrx = dLdr * drdZr * dZrdWrx - // dZrdWrx = xT × ... 
- // finally dLdWrx = xT × (dLdr * drdZr) - - - // dLdWrh = dLdh * dhdWrh = (dLdh * dhdc) * dcdWrh = dLdc * dcdZc * dZcdWrh = dLdc * dcdZc * dZcdr * drdWrh = - // = dLdc * dcdZc * dZcdr * drdZr * dZrdWrh = dLdr * drdZr * dZrdWrh - // dZrdWrh = hLastT × ... - // finally dLdWrh = hLastT × (dLdr * drdZr) - - - // dLdWux = dLdh * dhdWux = (dLdh * dhdu) * dudWux = dLdu * dudZu * dZudWux - // dZudWux = xT × ... - // dLdu * dudZu * dZudWux = xT × (dLdu * dudZu) - - - // dLdWuh = dLdh * dhdWuh = (dLdh * dhdu) * dudWuh = dLdh * dhdu * dudZu * dZudWuh = dLdu * dudZu * dZudWuh - // dZudWuh = hLastT × ... - // finally dLdWuh = hLastT × (dLdu * dudZu) - - - // dLdWcx = dLdh * dhdWcx = dLdh * dhdc * dcdWcx = (dLdh * dhdc) * dcdZc * dZcdWcx = dLdc * dcdZc * dZcdWcx - // dZcdWcx = xT × ... - // finally dLdWcx = xT × (dLdc * dcdZc) - - - // dLdWch = dLdh * dhdWch = dLdh * dhdc * dcdWch = (dLdh * dhdc) * dcdZc * dZcdWch = dLdc * dcdZc * dZcdWch - // dZcdWch = (r*hLast)^T × ... - // finally dLdWch = (r*hLast)^T × (dLdc * dcdZc) - - - // dLdbr = dLdh * dhdbr = (dLdh * dhdc) * dcdbr = dLdc * dcdbr = dLdc * dcdZc * dZcdbr = dLdc * dcdZc * dZcdr * drdbr = - // = dLdr * drdZr * dZrdbr - // dZrdbr = 1 - // finally dLdbr = dLdr * drdZr - - - // dLdbu = dLdh * dhdbu = (dLdh * dhdu) * dudbu = dLdu * dudZu * dZudbu - // dZudbu = 1 - // finally dLdbu = dLdu * dudZu - - - // dLdbc = dLdh * dhdbc = (dLdh * dhdc) * dcdbc = dLdc * dcdZc * dZcdbc - // dZcdbc = 1 - // finally dLdbc = dLdc * dcdZc - - NDArray dhdc = 1.f - u; // [bS, nU] - NDArray dhdu = *hLast - c; // [bS, nU] - NDArray dudZu = u * dhdc; // [bS, nU] - NDArray drdZr = r * (1.f - r); // [bS, nU] - NDArray dcdZc = 1.f - c * c; // [bS, nU] - NDArray dLdZc = *dLdc * dcdZc; // [bS, nU] - NDArray dLdZu = *dLdu * dudZu; // [bS, nU] - NDArray dLdZr = *dLdr * drdZr; // [bS, nU] - - // NDArray dLdc = *dLdh * dhdc; // [bS, nU] - // NDArray dLdu = *dLdh * dhdu; // [bS, nU] - // NDArray dLdr = mmul(dLdc * dcdZc * *hLast, WchT); // [bS, nU] - - dLdx->assign(mmul(dLdZu, WuxT) + mmul(dLdZc, WcxT) + mmul(dLdZr, WrxT)); // [bS, iS] - - dLdhLast->assign(*dLdh * u + mmul(dLdZu, WuhT) + mmul(dLdZc * r, WchT) + mmul(dLdZr, WrhT)); // [bS, nU] - - dLdWrx.assign(mmul(xT, dLdZr)); // [iS, bS] × [bS, nU] = [iS, nU] - dLdWrh.assign(mmul(hLastT, dLdZr)); // [nU, bS] × [bS, nU] = [nU, nU] - dLdWux.assign(mmul(xT, dLdZu)); // [iS, bS] × [bS, nU] = [iS, nU] - dLdWuh.assign(mmul(hLastT, dLdZu)); // [nU, bS] × [bS, nU] = [nU, nU] - - dLdWcx.assign(mmul(xT, dLdZc)); // [iS, bS] × [bS, nU] = [iS, nU] - dLdWch.assign(mmul((r * *hLast).transpose(), dLdZc)); // [nU, bS] × [bS, nU] = [nU, nU] - - dLdbr.assign(dLdZr.reduceAlongDimension(reduce::Sum, {0})); // [nU] - dLdbu.assign(dLdZu.reduceAlongDimension(reduce::Sum, {0})); // [nU] - - dLdbc->assign(dLdZc.reduceAlongDimension(reduce::Sum, {0})); // [nU] -} - - -} -} -} - diff --git a/libnd4j/include/ops/declarable/helpers/gru.h b/libnd4j/include/ops/declarable/helpers/gru.h index 3fecfa71b..9e98e4046 100644 --- a/libnd4j/include/ops/declarable/helpers/gru.h +++ b/libnd4j/include/ops/declarable/helpers/gru.h @@ -31,10 +31,26 @@ namespace helpers { const NDArray* bru, const NDArray* bc, NDArray* r, NDArray* u, NDArray* c, NDArray* h); + void gruCell(sd::LaunchContext * context, const NDArray* x, const NDArray* hLast, const NDArray* Wru, const NDArray* Wc, const NDArray* b, + NDArray* gates, NDArray* h); + void gruTimeLoop(sd::LaunchContext * context, const NDArray* x, const NDArray* h0, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* 
h);
-    void gruCellBP(sd::LaunchContext* context, const NDArray* x, const NDArray* hLast, const NDArray* W, const NDArray* Wc, const NDArray* b, const NDArray* bc, const NDArray* dLdr, const NDArray* dLdu, const NDArray* dLdc, const NDArray* dLdh, NDArray* dLdx, NDArray* dLdhLast, NDArray* dLdW, NDArray* dLdWc, NDArray* dLdb, NDArray* dLdbc);
+    void gruCellBp(sd::LaunchContext* context,
+                   const NDArray* x, const NDArray* hLast,
+                   const NDArray* W, const NDArray* Wc, const NDArray* b, const NDArray* bc,
+                   const NDArray* dLdr, const NDArray* dLdu, const NDArray* dLdc, const NDArray* dLdh,
+                   NDArray* dLdx, NDArray* dLdhLast,
+                   NDArray* dLdW, NDArray* dLdWc,
+                   NDArray* dLdb, NDArray* dLdbc);
+    void gruCellBp(sd::LaunchContext* context,
+                   const NDArray* x, const NDArray* hI, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* dLdh, const NDArray* gates,
+                   NDArray* dLdx, NDArray* dLdhI, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb);
+
+    void gruTimeLoopBp(sd::LaunchContext * context,
+                       const NDArray* x, const NDArray* hI, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* dLdh,
+                       NDArray* dLdx, NDArray* dLdhI, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb);
 }
 }
 }
diff --git a/libnd4j/include/ops/declarable/helpers/impl/gru.cpp b/libnd4j/include/ops/declarable/helpers/impl/gru.cpp
new file mode 100644
index 000000000..277188428
--- /dev/null
+++ b/libnd4j/include/ops/declarable/helpers/impl/gru.cpp
@@ -0,0 +1,546 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author Yurii Shyrma (iuriish@yahoo.com), created on 15.02.2018, Alex Black
+//
+
+// implementation of gated Recurrent Unit cell
+// (cf. https://arxiv.org/abs/1406.1078).
+// Kyunghyun Cho, Bart van Merrienboer, Caglar Gulcehre, Dzmitry Bahdanau, Fethi Bougares, Holger Schwenk, Yoshua Bengio
+// "Learning Phrase Representations using RNN Encoder-Decoder for Statistical Machine Translation"
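+//
+// for intuition, one GRU unit written out in scalar form (the variable names in this sketch are
+// purely illustrative, they are not part of this file's API):
+//   r = sigmoid(x*Wxr + h*Whr + br)      // reset gate
+//   u = sigmoid(x*Wxu + h*Whu + bu)      // update gate
+//   c = tanh(x*Wxc + (r*h)*Whc + bc)     // cell (candidate) gate
+//   h = u*h + (1 - u)*c                  // new cell output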
+
+
+#include
+#include
+#include
+#include
+
+namespace sd {
+namespace ops {
+namespace helpers {
+
+//////////////////////////////////////////////////////////////////////////
+void gruCell(sd::LaunchContext * context, const NDArray* x, const NDArray* hI, const NDArray* W, const NDArray* Wc,
+             const NDArray* b, const NDArray* bc,
+             NDArray* r, NDArray* u, NDArray* c, NDArray* h) {
+
+    //Inputs:
+    // x        input [bS, nIn], nIn - input size
+    // hI       previous cell output [bS, nOut], that is at previous time step t-1, nOut - number of units
+    // W        RU weights - [nIn+nOut, 2*nOut] - reset and update gates
+    // Wc       C weights - [nIn+nOut, nOut] - cell gate
+    // b        r and u biases, [2*nOut] - reset and update gates
+    // bc       c biases, [nOut] - cell gate
+
+    //Outputs:
+    // r        Reset gate output [bS, nOut]
+    // u        Update gate output [bS, nOut]
+    // c        Cell gate output [bS, nOut]
+    // h        current cell output [bS, nOut]
+
+    /***************************************************************************************/
+    /************************ THIS IS NOT OPTIMIZED CODE ***********************************/
+    /** however it is more math-friendly and convenient for backprop formulas derivation **/
+
+    const int bS   = x->sizeAt(0);
+    const int nIn  = x->sizeAt(1);
+    const int nOut = hI->sizeAt(1);
+
+    NDArray Wrx = (*W)({0,nIn,        0,nOut});       // [nIn, nOut]
+    NDArray Wux = (*W)({0,nIn,        nOut,2*nOut});  // [nIn, nOut]
+    NDArray Wrh = (*W)({nIn,nIn+nOut, 0,nOut});       // [nOut, nOut]
+    NDArray Wuh = (*W)({nIn,nIn+nOut, nOut,2*nOut});  // [nOut, nOut]
+
+    NDArray Wcx = (*Wc)({0,nIn,        0,0});         // reset cell weights   [nIn, nOut]
+    NDArray Wch = (*Wc)({nIn,nIn+nOut, 0,0});         // updates cell weights [nOut, nOut]
+
+    NDArray br = (*b)({0,    nOut});                  // [nOut]
+    NDArray bu = (*b)({nOut, 2*nOut});                // [nOut]
+
+    // × means matrix multiplication
+    // * means element-wise product or so called Hadamard product
+
+    // reset gate
+    r->assign(mmul(*x, Wrx) + mmul(*hI, Wrh) + br);         // [bS, nIn] × [nIn, nOut] + [bS, nOut] × [nOut, nOut] + [nOut] = [bS, nOut]
+    r->applyTransform(transform::Sigmoid, *r);
+
+    // update gate
+    u->assign(mmul(*x, Wux) + mmul(*hI, Wuh) + bu);         // [bS, nIn] × [nIn, nOut] + [bS, nOut] × [nOut, nOut] + [nOut] = [bS, nOut]
+    u->applyTransform(transform::Sigmoid, *u);
+
+    // cell gate c = activation(x × Wcx + (r * hI) × Wch + bc)
+    c->assign(mmul(*x, Wcx) + mmul(*r * *hI, Wch) + *bc);   // [bS, nIn] × [nIn, nOut] + [bS, nOut] × [nOut, nOut] + [nOut] = [bS, nOut]
+    c->applyTransform(transform::Tanh, *c);
+
+    // cell output
+    h->assign(*u * *hI + (1.f - *u) * *c);
+}
+
+//////////////////////////////////////////////////////////////////////////
+void gruCell(sd::LaunchContext * context, const NDArray* x, const NDArray* hI, const NDArray* Wx, const NDArray* Wh, const NDArray* b,
+             NDArray* gates, NDArray* h) {
+
+    //Inputs:
+    // x        input [bS, nIn]
+    // hI       previous cell output [bS, nOut], that is at previous time step t-1
+    // Wx       weights for x - [nIn, 3*nOut]
+    // Wh       weights for h - [nOut, 3*nOut]
+    // b        biases [3*nOut]
+
+    // 3*nOut means following sequence: reset, update, cell
+
+    //Outputs:
+    // gates    [bS, 3*nOut] = reset gate [bS, nOut] + update gate [bS, nOut] + cell gate [bS, nOut]
+    // h        current cell output [bS, nOut]
+
+    // formulas:
+    // zr = x × Wxr + hI × Whr + br
+    // zu = x × Wxu + hI × Whu + bu
+    // r = sigmoid(zr)
+ // u = sigmoid(zu) + // zc = x × Wxc + (r * hI) × Whc + bc + // c = tanh(zc) + // h = (1-u)*c + u*hI + + const int bS = x->sizeAt(0); + const int nIn = x->sizeAt(1); + const int nOut = hI->sizeAt(1); + + NDArray temp = gates->ulike(); + MmulHelper::mmul(x, Wx, &temp); // [bS, nIn] × [nIn, 3*nOut] = [bS, 3*nOut] + temp += *b; + + MmulHelper::mmul(hI, Wh, gates); // [bS, nOut] × [nOut, 3*nOut] = [bS, 3*nOut] + + NDArray ru = (*gates)({0,0, 0,2*nOut}); // [bS, 2*nOut] + + NDArray r = (*gates)({0,0, 0,nOut}); // [bS, nOut] + NDArray u = (*gates)({0,0, nOut,2*nOut}); // [bS, nOut] + NDArray c = (*gates)({0,0, 2*nOut,3*nOut}); // [bS, nOut] + + // reset and update gates + ru += temp({0,0, 0,2*nOut}); + ru.applyTransform(transform::Sigmoid, ru); + + // cell gate + c.assign(c*r + temp({0,0, 2*nOut, 3*nOut})); + c.applyTransform(transform::Tanh, c); + + // cell output + h->assign(u * *hI + (1.f - u) * c); +} + +////////////////////////////////////////////////////////////////////////// +void gruTimeLoop(sd::LaunchContext * context, const NDArray* x, const NDArray* hI, const NDArray* Wx, const NDArray* Wh, const NDArray* b, NDArray* h) { + + // sL means time steps + + // x input [sL, bS, nIn] + // hI initial cell output (at time step = 0) [bS, nOut] + // Wx input-to-hidden weights, [nIn, 3*nOut] + // Wh hidden-to-hidden weights, [nOut, 3*nOut] + // b biases, [3*nOut] + + // h cell outputs at each time step [sL, bS, nOut] + + const int sL = x->sizeAt(0); + const int bS = x->sizeAt(1); + const int nOut = hI->sizeAt(1); + + NDArray gates(h->ordering(), {bS, 3*nOut}, h->dataType(), context); + + auto xSet = x->allTensorsAlongDimension({1,2}); // sub-arrays with shape [bS, nIn] + auto hSet = h->allTensorsAlongDimension({1,2}); // sub-arrays with shape [bS, nOut] + + // time loop + for (int t = 0; t < sL; ++t) + gruCell(context, xSet.at(t), t == 0 ? 
hI : hSet.at(t-1), Wx, Wh, b, &gates, hSet.at(t));
+}
+
+//////////////////////////////////////////////////////////////////////////
+void gruCellBp(sd::LaunchContext* context,
+               const NDArray* x, const NDArray* hLast,
+               const NDArray* W, const NDArray* Wc, const NDArray* b, const NDArray* bc,
+               const NDArray* dLdr, const NDArray* dLdu, const NDArray* dLdc, const NDArray* dLdh,
+               NDArray* dLdx, NDArray* dLdhLast,
+               NDArray* dLdW, NDArray* dLdWc,
+               NDArray* dLdb, NDArray* dLdbc) {
+
+    //Inputs:
+    // x         input [bS, iS]
+    // hLast     previous cell output [bS, nU], that is at previous time step t-1
+    // W         weights - [iS+nU, 2*nU] - reset and update gates
+    // Wc        C weights - [iS+nU, nU] - cell gate
+    // b         r and u biases, [2*nU] - reset and update gates
+    // bc        c biases, [nU] - cell gate
+    // dLdr      gradient wrt reset gate, [bS, nU]
+    // dLdu      gradient wrt update gate, [bS, nU]
+    // dLdc      gradient wrt cell state, [bS, nU]
+    // dLdh      gradient wrt current cell output, [bS, nU]
+
+    //Outputs:
+    // dLdx      gradient wrt x, [bS, iS],
+    // dLdhLast  gradient wrt hLast, [bS, nU]
+    // dLdW      gradient wrt W, [iS+nU, 2*nU]
+    // dLdWc     gradient wrt Wc, [iS+nU, nU]
+    // dLdb      gradient wrt bru [2*nU]
+    // dLdbc     gradient wrt bc [nU]
+
+    // * means element-wise product or so called Hadamard product
+    // × means matrix multiplication
+
+    /************************************************************************************************/
+    /******************************* THIS IS NOT OPTIMIZED CODE *************************************/
+    /*** aim is to have math-readable code in order to keep track of backprop formulas derivation ***/
+
+    const int bS = x->sizeAt(0);
+    const int iS = x->sizeAt(1);
+    const int nU = hLast->sizeAt(1);
+
+    NDArray xT     = x->transpose();                // [iS, bS]
+    NDArray hLastT = hLast->transpose();            // [nU, bS]
+
+    NDArray Wrx = (*W)({0,iS,     0,nU});           // [iS, nU]
+    NDArray Wux = (*W)({0,iS,     nU,2*nU});        // [iS, nU]
+    NDArray Wrh = (*W)({iS,iS+nU, 0,nU});           // [nU, nU]
+    NDArray Wuh = (*W)({iS,iS+nU, nU,2*nU});        // [nU, nU]
+
+    NDArray Wcx = (*Wc)({0,iS,     0,0});           // reset cell weights   [iS, nU]
+    NDArray Wch = (*Wc)({iS,iS+nU, 0,0});           // updates cell weights [nU, nU]
+
+    NDArray br = (*b)({0,  nU});                    // [nU]
+    NDArray bu = (*b)({nU, 2*nU});                  // [nU]
+
+    NDArray WrxT = Wrx.transpose();                 // [nU, iS]
+    NDArray WuxT = Wux.transpose();                 // [nU, iS]
+    NDArray WrhT = Wrh.transpose();                 // [nU, nU]
+    NDArray WuhT = Wuh.transpose();                 // [nU, nU]
+
+    NDArray WcxT = Wcx.transpose();                 // [nU, iS]
+    NDArray WchT = Wch.transpose();                 // [nU, nU]
+
+    NDArray dLdWrx = (*dLdW)({0,iS,     0,nU});     // [iS, nU]
+    NDArray dLdWux = (*dLdW)({0,iS,     nU,2*nU});  // [iS, nU]
+    NDArray dLdWrh = (*dLdW)({iS,iS+nU, 0,nU});     // [nU, nU]
+    NDArray dLdWuh = (*dLdW)({iS,iS+nU, nU,2*nU});  // [nU, nU]
+
+    NDArray dLdWcx = (*dLdWc)({0,iS,     0,0});     // [iS, nU]
+    NDArray dLdWch = (*dLdWc)({iS,iS+nU, 0,0});     // [nU, nU]
+
+    NDArray dLdbr = (*dLdb)({0,  nU});              // [nU]
+    NDArray dLdbu = (*dLdb)({nU, 2*nU});            // [nU]
+
+
+    // ***** feed forward step ***** //
+
+    // reset gate
+    NDArray r = mmul(*x, Wrx) + mmul(*hLast, Wrh) + br;         // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
+    r.applyTransform(transform::Sigmoid, r);
+
+    // update gate
+    NDArray u = mmul(*x, Wux) + mmul(*hLast, Wuh) + bu;         // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
+    u.applyTransform(transform::Sigmoid, u);
+
+    // cell gate c = activation(x × Wcx + (r * hLast) × Wch + bc)
+    NDArray c = mmul(*x, Wcx) + mmul(r * *hLast, Wch) + *bc;    // [bS, iS] × [iS, nU] + [bS, nU] × [nU, nU] + [nU] = [bS, nU]
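+    // note: the reset, update and cell gates are recomputed here (feed forward step) because all of
+    // the backprop formulas below are evaluated at exactly these activations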
+    c.applyTransform(transform::Tanh, c);
+
+    // h = (1 - u) * c + u * hLast
+
+
+    // ***** back prop step ***** //
+
+    // notations:
+    // Zr = x × Wrx + hLast × Wrh + br
+    // Zu = x × Wux + hLast × Wuh + bu
+    // Sr = sigmoid(Zr)
+    // Su = sigmoid(Zu)
+    // Zc = x × Wcx + (r * hLast) × Wch + bc
+
+
+    // dLdx = dLdh * dhdx = dLdh * (dhdu * dudx + dhdc * dcdx) = (dLdh * dhdu) * dudx + (dLdh * dhdc) * dcdx = dLdu * dudx + dLdc * dcdx
+    //      = dLdx_u + dLdx_c
+    // dLdx_u = dLdu * dudx = dLdu * dudZu * dZudx = |dZudx = ... × WuxT| = (dLdu * dudZu) × WuxT
+    // dLdx_c = dLdc * dcdx = dLdc * dcdZc * (dZcdx + dZcdr * drdx) = dLdc * dcdZc * dZcdx + dLdc * dcdZc * dZcdr * drdx = dLdx_c0 + dLdx_c1
+    // dLdx_c0 = dLdc * dcdZc * dZcdx = |dZcdx = ... × WcxT| = (dLdc * dcdZc) × WcxT
+    // dZcdr = (... * hLast) × WchT
+    // dLdc * dcdZc * dZcdr = dLdr = (dLdc * dcdZc * hLast) × WchT
+    // drdx = drdZr * dZrdx
+    // dZrdx = ... × WrxT
+    // dLdx_c1 = dLdc * dcdZc * dZcdr * drdx = dLdr * drdx = (dLdr * drdZr) × WrxT
+    // finally dLdx = dLdx_u + dLdx_c0 + dLdx_c1 = (dLdu * dudZu) × WuxT + (dLdc * dcdZc) × WcxT + (dLdr * drdZr) × WrxT
+
+
+    // dLdhLast = dLdh * (dhdhLast + dhdu * dudhLast + dhdc * dcdhLast) = dLdh * dhdhLast + dLdu * dudhLast + dLdc * dcdhLast
+    //          = dLdhLast_h + dLdhLast_u + dLdhLast_c
+    // dLdhLast_h = dLdh * dhdhLast = dLdh * u
+    // dLdhLast_u = dLdu * dudhLast = |dudhLast = dudZu * dZudhLast , dZudhLast = ... × WuhT| = (dLdu * dudZu) × WuhT
+    // dLdhLast_c = dLdc * dcdhLast = dLdc * (dcdZc * dZcdhLast + dcdZc * dZcdr * drdhLast) =
+    //            = dLdc * dcdZc * dZcdhLast + dLdc * dcdZc * dZcdr * drdhLast =
+    //            = dLdc * dcdZc * dZcdhLast + dLdr * drdhLast = dLdhLast_c0 + dLdhLast_c1
+    // dLdhLast_c0 = dLdc * dcdZc * dZcdhLast = |dZcdhLast = (... * r) × WchT| = (dLdc * dcdZc * r) × WchT
+    // dLdhLast_c1 = dLdr * drdhLast = |drdhLast = drdZr * dZrdhLast, dZrdhLast = ... × WrhT| = (dLdr * drdZr) × WrhT
+    // finally dLdhLast = dLdhLast_h + dLdhLast_u + dLdhLast_c0 + dLdhLast_c1 =
+    //                  = dLdh * u + (dLdu * dudZu) × WuhT + (dLdc * dcdZc * r) × WchT + (dLdr * drdZr) × WrhT
+
+
+    // dLdWrx = dLdh * dhdWrx = (dLdh * dhdc) * dcdWrx = dLdc * dcdZc * dZcdWrx = dLdc * dcdZc * dZcdr * drdWrx =
+    //        = dLdc * dcdZc * dZcdr * drdZr * dZrdWrx = dLdr * drdZr * dZrdWrx
+    // dZrdWrx = xT × ...
+    // finally dLdWrx = xT × (dLdr * drdZr)
+
+
+    // dLdWrh = dLdh * dhdWrh = (dLdh * dhdc) * dcdWrh = dLdc * dcdZc * dZcdWrh = dLdc * dcdZc * dZcdr * drdWrh =
+    //        = dLdc * dcdZc * dZcdr * drdZr * dZrdWrh = dLdr * drdZr * dZrdWrh
+    // dZrdWrh = hLastT × ...
+    // finally dLdWrh = hLastT × (dLdr * drdZr)
+
+
+    // dLdWux = dLdh * dhdWux = (dLdh * dhdu) * dudWux = dLdu * dudZu * dZudWux
+    // dZudWux = xT × ...
+    // finally dLdWux = xT × (dLdu * dudZu)
+
+
+    // dLdWuh = dLdh * dhdWuh = (dLdh * dhdu) * dudWuh = dLdh * dhdu * dudZu * dZudWuh = dLdu * dudZu * dZudWuh
+    // dZudWuh = hLastT × ...
+    // finally dLdWuh = hLastT × (dLdu * dudZu)
+
+
+    // dLdWcx = dLdh * dhdWcx = dLdh * dhdc * dcdWcx = (dLdh * dhdc) * dcdZc * dZcdWcx = dLdc * dcdZc * dZcdWcx
+    // dZcdWcx = xT × ...
+    // finally dLdWcx = xT × (dLdc * dcdZc)
+
+
+    // dLdWch = dLdh * dhdWch = dLdh * dhdc * dcdWch = (dLdh * dhdc) * dcdZc * dZcdWch = dLdc * dcdZc * dZcdWch
+    // dZcdWch = (r*hLast)^T × ...
+    // finally dLdWch = (r*hLast)^T × (dLdc * dcdZc)
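+
+    // (note the pattern throughout: every weight gradient is leftFactorT × dLdZ of its gate, where
+    //  leftFactor is whatever multiplies that weight matrix in the feed forward step, and every bias
+    //  gradient is the corresponding dLdZ summed over the batch axis)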
+
+
+    // dLdbr = dLdh * dhdbr = (dLdh * dhdc) * dcdbr = dLdc * dcdbr = dLdc * dcdZc * dZcdbr = dLdc * dcdZc * dZcdr * drdbr =
+    //       = dLdr * drdZr * dZrdbr
+    // dZrdbr = 1
+    // finally dLdbr = dLdr * drdZr
+
+
+    // dLdbu = dLdh * dhdbu = (dLdh * dhdu) * dudbu = dLdu * dudZu * dZudbu
+    // dZudbu = 1
+    // finally dLdbu = dLdu * dudZu
+
+
+    // dLdbc = dLdh * dhdbc = (dLdh * dhdc) * dcdbc = dLdc * dcdZc * dZcdbc
+    // dZcdbc = 1
+    // finally dLdbc = dLdc * dcdZc
+
+    NDArray dhdc  = 1.f - u;               // [bS, nU]
+    NDArray dhdu  = *hLast - c;            // [bS, nU]
+    NDArray dudZu = u * dhdc;              // [bS, nU]
+    NDArray drdZr = r * (1.f - r);         // [bS, nU]
+    NDArray dcdZc = 1.f - c * c;           // [bS, nU]
+    NDArray dLdZc = *dLdc * dcdZc;         // [bS, nU]
+    NDArray dLdZu = *dLdu * dudZu;         // [bS, nU]
+    NDArray dLdZr = *dLdr * drdZr;         // [bS, nU]
+
+    // NDArray dLdc = *dLdh * dhdc;                      // [bS, nU]
+    // NDArray dLdu = *dLdh * dhdu;                      // [bS, nU]
+    // NDArray dLdr = mmul(dLdc * dcdZc * *hLast, WchT); // [bS, nU]
+
+    dLdx->assign(mmul(dLdZu, WuxT) + mmul(dLdZc, WcxT) + mmul(dLdZr, WrxT));                      // [bS, iS]
+
+    dLdhLast->assign(*dLdh * u + mmul(dLdZu, WuhT) + mmul(dLdZc * r, WchT) + mmul(dLdZr, WrhT));  // [bS, nU]
+
+    dLdWrx.assign(mmul(xT, dLdZr));        // [iS, bS] × [bS, nU] = [iS, nU]
+    dLdWrh.assign(mmul(hLastT, dLdZr));    // [nU, bS] × [bS, nU] = [nU, nU]
+    dLdWux.assign(mmul(xT, dLdZu));        // [iS, bS] × [bS, nU] = [iS, nU]
+    dLdWuh.assign(mmul(hLastT, dLdZu));    // [nU, bS] × [bS, nU] = [nU, nU]
+
+    dLdWcx.assign(mmul(xT, dLdZc));                        // [iS, bS] × [bS, nU] = [iS, nU]
+    dLdWch.assign(mmul((r * *hLast).transpose(), dLdZc));  // [nU, bS] × [bS, nU] = [nU, nU]
+
+    dLdbr.assign(dLdZr.reduceAlongDimension(reduce::Sum, {0}));   // [nU]
+    dLdbu.assign(dLdZu.reduceAlongDimension(reduce::Sum, {0}));   // [nU]
+
+    dLdbc->assign(dLdZc.reduceAlongDimension(reduce::Sum, {0}));  // [nU]
+}
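+
+// a quick way to validate these hand-derived gradients (both the cell backprop above and the fused
+// variant below) is a central finite difference on any scalar loss L:
+//   dL/dw ≈ (L(w + eps) - L(w - eps)) / (2 * eps)
+// the GradCheck::checkGrad helper used by the gru/lstmLayer tests in this patch automates exactly
+// this comparison against the analytic gradients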
+
+
+//////////////////////////////////////////////////////////////////////////
+void gruCellBp(sd::LaunchContext* context,
+               const NDArray* x, const NDArray* hI, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* dLdh, const NDArray* gates,
+               NDArray* dLdx, NDArray* dLdhI, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb) {
+
+    //Inputs:
+    // x      input [bS, nIn]
+    // hI     previous cell output [bS, nOut], that is at previous time step t-1
+    // Wx     input-to-hidden weights - [nIn, 3*nOut]
+    // Wh     hidden-to-hidden weights - [nOut, 3*nOut]
+    // b      biases, [3*nOut] - reset, update and cell gates
+    // dLdh   gradient vs. ff output, [bS, nOut]
+    // gates  reset/update/cell gate activations saved from the feed forward step, [bS, 3*nOut]
+
+    //Outputs:
+    // dLdx   gradient vs. x, [bS, nIn],
+    // dLdhI  gradient vs. hI, [bS, nOut]
+    // dLdWx  gradient vs. Wx, [nIn, 3*nOut]
+    // dLdWh  gradient vs. Wh, [nOut, 3*nOut]
+    // dLdb   gradient vs. b [3*nOut]
+
+    // 3*nOut means following sequence: reset, update, cell
+
+    // * means element-wise product or so called Hadamard product
+    // × means matrix multiplication
+
+    // formulas:
+    // zr = x × Wxr + hI × Whr + br
+    // zu = x × Wxu + hI × Whu + bu
+    // r = sigmoid(zr)
+    // u = sigmoid(zu)
+    // zc = x × Wxc + (r * hI) × Whc + bc
+    // c = tanh(zc)
+    // h = (1-u)*c + u*hI
+
+    // dLdhI += dLdh;                       [bS, nOut]
+
+
+    // dhdc  = 1 - u                        [bS, nOut]
+    // dhdu  = -c + hI                      [bS, nOut]
+
+    // dcdzc = 1 - c*c;                     [bS, nOut]
+    // dudzu = u*(1-u)                      [bS, nOut]
+    // drdzr = r(1-r)                       [bS, nOut]
+
+    // dzcdr = (...*hI × WhcT)              [bS, nOut]
+
+    // dLdzr = dLdh*dhdc*dcdzc*dzcdr*drdzr = (dLdzc*hI*r(1-r) × WhcT);   [bS, nOut]
+    // dLdzu = dLdh*dhdu*dudzu = dLdh*(hI-c)*u*(1-u)                     [bS, nOut]
+    // dLdzc = dLdh*dhdc*dcdzc = dLdh*(1-u)*(1-c*c)                      [bS, nOut]
+
+    // dLdx  = dLdzr × WxrT + dLdzu × WxuT + dLdzc × WxcT,   [bS, nOut] × [nOut, nIn] + ... = [bS, nIn]
+
+    // dLdhI = dLdzr × WhrT + dLdzu × WhuT + dLdzc × WhcT,   [bS, nOut] × [nOut, nOut] + ... = [bS, nOut]
+
+    // dLdWxr = xT × dLdzr          [nIn, bS] x [bS, nOut] = [nIn, nOut]
+    // dLdWxu = xT × dLdzu          [nIn, bS] x [bS, nOut] = [nIn, nOut]
+    // dLdWxc = xT × dLdzc          [nIn, bS] x [bS, nOut] = [nIn, nOut]
+
+    // dLdWhr = hIT × dLdzr         [nOut, bS] x [bS, nOut] = [nOut, nOut]
+    // dLdWhu = hIT × dLdzu         [nOut, bS] x [bS, nOut] = [nOut, nOut]
+    // dLdWhc = (r*hI)T × dLdzc     [nOut, bS] x [bS, nOut] = [nOut, nOut]
+
+    // dLdbr = dLdzr.reduce_sum_along_0_axis    [bS, nOut] -> reduce -> [nOut]
+    // dLdbu = dLdzu.reduce_sum_along_0_axis    [bS, nOut] -> reduce -> [nOut]
+    // dLdbc = dLdzc.reduce_sum_along_0_axis    [bS, nOut] -> reduce -> [nOut]
+
+    const int nOut = hI->sizeAt(1);
+
+    NDArray dLdz = gates->ulike();                   // [bS, 3*nOut]
+
+    NDArray dLdzru = dLdz({0,0, 0,2*nOut});          // [bS, 2*nOut]
+
+    NDArray dLdzr = dLdz({0,0, 0,nOut});             // [bS, nOut]
+    NDArray dLdzu = dLdz({0,0, nOut,2*nOut});        // [bS, nOut]
+    NDArray dLdzc = dLdz({0,0, 2*nOut,3*nOut});      // [bS, nOut]
+
+    NDArray r = (*gates)({0,0, 0,nOut});             // [bS, nOut]
+    NDArray u = (*gates)({0,0, nOut,2*nOut});        // [bS, nOut]
+    NDArray c = (*gates)({0,0, 2*nOut,3*nOut});      // [bS, nOut]
+
+    NDArray WhcT = (*Wh)({0,0, 2*nOut,3*nOut}).transpose();
+
+    if(dLdh)
+        *dLdhI += *dLdh;
+
+    NDArray temp1 = 1 - u;                           // [bS, nOut]
+
+    // dLdzc
+    dLdzc.assign(*dLdhI * temp1 * (1-c*c));          // [bS, nOut]
+
+    // dLdzu
+    dLdzu.assign(*dLdhI * (*hI - c) * u * temp1);    // [bS, nOut]
+
+    // dLdzr
+    NDArray temp2 = dLdzc * (*hI) * r * (1-r);
+    MmulHelper::mmul(&temp2, &WhcT, &dLdzr);         // [bS, nOut] x [nOut, nOut] = [bS, nOut]
+
+    // dLdx
+    NDArray WxT = Wx->transpose();
+    MmulHelper::mmul(&dLdz, &WxT, dLdx);             // [bS, 3*nOut] x [3*nOut, nIn] = [bS, nIn]
+
+    // dLdWx
+    *dLdWx += mmul(x->transpose(), dLdz);            // [nIn, bS] x [bS, 3*nOut] = [nIn, 3*nOut]
+
+    // dLdb
+    *dLdb += dLdz.reduceAlongDimension(reduce::Sum, {0});    // [bS, 3*nOut] -> reduce -> [3*nOut]
+
+    dLdzc *= r;
+
+    // dLdhI
+    NDArray WhT = Wh->transpose();
+    dLdhI->assign(*dLdhI*u + mmul(dLdz, WhT));       // [bS, 3*nOut] x [3*nOut, nOut] = [bS, nOut]
+
+    // dLdWh
+    *dLdWh += mmul(hI->transpose(), dLdz);           // [nOut, bS] x [bS, 3*nOut] = [nOut, 3*nOut]
+}
+
+
+//////////////////////////////////////////////////////////////////////////
+void gruTimeLoopBp(sd::LaunchContext * context,
+                   const NDArray* x, const NDArray* hI, const NDArray* Wx, const NDArray* Wh, const NDArray* b, const NDArray* dLdh,
+                   NDArray* dLdx, NDArray* dLdhI, NDArray* dLdWx, NDArray* dLdWh, NDArray* dLdb) {
+
+    // sL means sequence length (number of time steps)
+
+    // x       input [sL, bS, nIn]
+    // hI      initial cell output (at 
time step = 0) [bS, nOut] + // Wx input-to-hidden weights, [nIn, 3*nOut] + // Wh hidden-to-hidden weights, [nOut, 3*nOut] + // b biases, [3*nOut] + // dLdh gradient vs. ff output, [sL, bS, nOut] + + // dLdx gradient vs. x, [sL, bS, nIn], + // dLdhI gradient vs. hI, [bS, nOut] + // dLdWx gradient vs. W, [nIn, 3*nOut] + // dLdWh gradient vs. Wc, [nOut, 3*nOut] + // dLdb gradient vs. b [3*nOut] + + const int sL = x->sizeAt(0); + const int bS = x->sizeAt(1); + const int nOut = hI->sizeAt(1); + + NDArray gates(x->ordering(), {sL, bS, 3*nOut}, dLdh->dataType(), x->getContext()); + NDArray h(x->ordering(), {sL+1, bS, nOut}, dLdh->dataType(), x->getContext()); + + auto xSet = x->allTensorsAlongDimension({1,2}); // sub-arrays with shape [bS, nIn] + auto dLdhSet = dLdh->allTensorsAlongDimension({1,2}); // sub-arrays with shape [bS, nOut] + auto hSet = h.allTensorsAlongDimension({1,2}); // sub-arrays with shape [bS, nOut] + auto gatesSet = gates.allTensorsAlongDimension({1,2}); // sub-arrays with shape [bS, nOut] + auto dLdxSet = dLdx->allTensorsAlongDimension({1,2}); // sub-arrays with shape [bS, nIn] + + hSet.at(0)->assign(hI); + + // forward time loop + for (int t = 0; t < sL; ++t) + gruCell(context, xSet.at(t), hSet.at(t), Wx, Wh, b, gatesSet.at(t), hSet.at(t+1)); + + // backward time loop + for (int t = sL-1; t >= 0; --t) + gruCellBp(context, xSet.at(t), hSet.at(t), Wx, Wh, b, dLdhSet.at(t), gatesSet.at(t), + dLdxSet.at(t), dLdhI, dLdWx, dLdWh, dLdb); +} + + +} +} +} diff --git a/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp b/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp index 9fce17c4b..bffd13128 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/lstmLayer.cpp @@ -189,54 +189,6 @@ static FORCEINLINE int getBatchTimeTotalIndex(const int dataFormat, const int sL return b * sL + t; // NTS, NST: shape [bS, sL, nIn], [bS, nIn, sL] } -////////////////////////////////////////////////////////////////////////// -// x{M,K} x y{K,N} = z{M,N}, dzdy{K,N,M,N} - Jacobian derivative -> if x.rankOf() == 2 -// x{K} x y{K,N} = z{N}, dzdy{K,N,N} - Jacobian derivative -> if x.rankOf() == 1 -static NDArray mmulJacobianWeightsDeriv(const int nOut, const NDArray& x) { - - std::vector outShape = x.rankOf() == 1 ? 
std::vector({x.sizeAt(0), nOut, nOut}) : std::vector({x.sizeAt(1), nOut, x.sizeAt(0), nOut}); - - NDArray dzdy(x.ordering(), outShape, x.dataType(), x.getContext()); - - if(x.rankOf() == 1) { - auto func = PRAGMA_THREADS_FOR_3D { - - for (auto i0 = start_x; i0 < stop_x; ++i0) { - for (auto i1 = start_y; i1 < stop_y; ++i1) { - for (auto i2 = start_z; i2 < stop_z; ++i2) { - if(i1 == i2) - dzdy.p(i0,i1,i2, x.e(i0)); - else - dzdy.p(i0,i1,i2, 0); - } - } - } - }; - - samediff::Threads::parallel_for(func, 0,dzdy.sizeAt(0),1, 0,dzdy.sizeAt(1),1, 0,dzdy.sizeAt(2),1); - } - else { - auto func = PRAGMA_THREADS_FOR_3D { - - for (auto i0 = start_x; i0 < stop_x; ++i0) { - for (auto i1 = start_y; i1 < stop_y; ++i1) { - for (auto i2 = start_z; i2 < stop_z; ++i2) { - for (auto i3 = 0; i3 < dzdy.sizeAt(3); ++i3) { - if(i1 == i3) - dzdy.p(i0,i1,i2,i3, x.e(i2,i0)); - else - dzdy.p(i0,i1,i2,i3, 0); - } - } - } - } - }; - - samediff::Threads::parallel_for(func, 0,dzdy.sizeAt(0),1, 0,dzdy.sizeAt(1),1, 0,dzdy.sizeAt(2),1); - } - - return dzdy; -} ////////////////////////////////////////////////////////////////////////// void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr, @@ -245,25 +197,25 @@ void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr, NDArray* h, NDArray* c) { // * -> means element-wise multiplication - // ^ -> means matrix multiplication + // × -> means matrix multiplication /************************ THIS IS NOT OPTIMAZED CODE ***********************************/ /** the objective is to provide math-readable code **/ // equations (no peephole connections) - // it = σ(Wxi ^ xt + Wri ^ ht-1 + bi) - // ft = σ(Wxf ^ xt + Wrf ^ ht-1 + bf) - // c't = tanh(Wxc ^ xt + Wrc ^ ht-1 + bc) + // it = σ(Wxi × xt + Wri × ht-1 + bi) + // ft = σ(Wxf × xt + Wrf × ht-1 + bf) + // c't = tanh(Wxc × xt + Wrc × ht-1 + bc) // ct = ft * ct-1 + it * c't - // ot = σ(Wxo ^ xt + Wro ^ ht-1 + bo) + // ot = σ(Wxo × xt + Wro × ht-1 + bo) // ht = ot * tanh(ct) // equations (peephole connections are present) - // it = σ(Wxi ^ xt + Wri ^ ht-1 + Wpi * ct-1 + bi) - // ft = σ(Wxf ^ xt + Wrf ^ ht-1 + Wpf * ct-1 + bf) - // c't = tanh(Wxc ^ xt + Wrc ^ ht-1 + bc) + // it = σ(Wxi × xt + Wri × ht-1 + Wpi * ct-1 + bi) + // ft = σ(Wxf × xt + Wrf × ht-1 + Wpf * ct-1 + bf) + // c't = tanh(Wxc × xt + Wrc × ht-1 + bc) // ct = ft * ct-1 + it * c't - // ot = σ(Wxo ^ xt + Wro ^ ht-1 + Wpo * ct + bo) + // ot = σ(Wxo × xt + Wro × ht-1 + Wpo * ct + bo) // ht = ot * tanh(ct) @@ -399,7 +351,7 @@ void lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArray* Wr, ////////////////////////////////////////////////////////////////////////// void lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp, - const NDArray* dLdh, const NDArray* dLdc, + const NDArray* dLdh, const NDArray* dLdhL, const NDArray* dLdcL, const NDArray* z, const NDArray* a, const NDArray* c, const std::vector& params, NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdb, NDArray* dLdWp) { @@ -407,10 +359,10 @@ void lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, con /** the objective is to provide math-readable code **/ // equations (no peephole connections) - // zi = x ^ Wxi + hI ^ Wri + bi - // zf = x ^ Wxf + hI ^ Wrf + bf - // zg = x ^ Wxg + hI ^ Wrg + bg - // zo = x ^ Wxo + hI ^ Wro + bo + // zi = x × Wxi + hI × Wri + bi + // zf = x × Wxf + hI × Wrf + bf + // zg = x × Wxg + hI × Wrg + 
bg + // zo = x × Wxo + hI × Wro + bo // i = act(zi) // f = act(zf) // g = actC(zg) @@ -419,10 +371,10 @@ void lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, con // h = o * actH(c) // equations (peephole connections are present) - // zi = x ^ Wxi + hI ^ Wri + cI * Wpi + bi - // zf = x ^ Wxf + hI ^ Wrf + cI * Wpf + bf - // zg = x ^ Wxg + hI ^ Wrg + bg - // zo = x ^ Wxo + hI ^ Wro + c * Wpo + bo + // zi = x × Wxi + hI × Wri + cI * Wpi + bi + // zf = x × Wxf + hI × Wrf + cI * Wpf + bf + // zg = x × Wxg + hI × Wrg + bg + // zo = x × Wxo + hI × Wro + c * Wpo + bo // i = act(zi) // f = act(zf) // g = actC(zg) @@ -449,18 +401,19 @@ void lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, con // params[11] - beta value for output activation // INPUTS: - // x - current input at time t, [bS, nIn] or [nIn] if seqLen != nullptr - // Wx - input weights [nIn, 4*nOut] - // Wr - recurrent weights [nOut, 4*nOut] - // b - biases [4*nOut], optional, may be nullptr - // hI - (ht-1) previous (initial) output at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr - // cI - (ct-1) previous (initial) cell state at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr - // Wp - peephole weights [3*nOut], optional, may be nullptr - // dLdh - loss derivative with respect to h, [bS, nOut] or [nOut] if seqLen != nullptr - // dLdc - loss derivative with respect to c, [bS, nOut] or [nOut] if seqLen != nullptr - // z - zi,zf,zg,zo taken from ff outputs to reduce amount of calculations in bp, [bS, 4*nOut] - // a - i,f,g,o taken from ff outputs to reduce amount of calculations in bp, [bS, 4*nOut] - // c - taken from ff outputs to reduce amount of calculations in bp, [bS, nOut] + // x - current input at time t, [bS, nIn] or [nIn] if seqLen != nullptr + // Wx - input weights [nIn, 4*nOut] + // Wr - recurrent weights [nOut, 4*nOut] + // b - biases [4*nOut], optional, may be nullptr + // hI - (ht-1) previous (initial) output at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr + // cI - (ct-1) previous (initial) cell state at time t-1, [bS, nOut] or [nOut] if seqLen != nullptr + // Wp - peephole weights [3*nOut], optional, may be nullptr + // dLdh - loss derivative with respect to h at each time step, [bS, nOut] or [nOut] if seqLen != nullptr + // dLdhL - loss derivative with respect to h at last time step, [bS, nOut] or [nOut] if seqLen != nullptr + // dLdcL - loss derivative with respect to c at last time step, [bS, nOut] or [nOut] if seqLen != nullptr + // z - zi,zf,zg,zo taken from ff outputs to reduce amount of calculations in bp, [bS, 4*nOut] + // a - i,f,g,o taken from ff outputs to reduce amount of calculations in bp, [bS, 4*nOut] + // c - taken from ff outputs to reduce amount of calculations in bp, [bS, nOut] // OUTPUTS: // dLdx - loss derivative with respect to x, [bS, nIn] or [nIn] if seqLen != nullptr @@ -485,19 +438,19 @@ void lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, con // dLdzg = dLdcI*dcdg*dgdzg; [bS, nOut](or[nOut]) // dLdzo = dLdhI*dhdo*dodzo; [bS, nOut](or[nOut]) - // dLdx = dLdzi^WxiT + dLdzf^WxfT + dLdzg^WxgT + dLdzo^WxoT, [bS, nIn] - // dLdhI = dLdzi^WriT + dLdzf^WrfT + dLdzg^WrgT + dLdzo^WroT, [bS, nOut] + // dLdx = dLdzi×WxiT + dLdzf×WxfT + dLdzg×WxgT + dLdzo×WxoT, [bS, nIn] + // dLdhI = dLdzi×WriT + dLdzf×WrfT + dLdzg×WrgT + dLdzo×WroT, [bS, nOut] // dLdcI = dLdcI*dcdcI, [bS, nOut] - // dLdWxi = xT^dLdzi [nIn, bS] x [bS, nOut] = [nIn, nOut] - // dLdWxf = xT^dLdzf [nIn, bS] x [bS, nOut] = [nIn, nOut] - // dLdWxg = xT^dLdzg [nIn, 
bS] x [bS, nOut] = [nIn, nOut] - // dLdWxo = xT^dLdzo [nIn, bS] x [bS, nOut] = [nIn, nOut] + // dLdWxi = xT×dLdzi [nIn, bS] x [bS, nOut] = [nIn, nOut] + // dLdWxf = xT×dLdzf [nIn, bS] x [bS, nOut] = [nIn, nOut] + // dLdWxg = xT×dLdzg [nIn, bS] x [bS, nOut] = [nIn, nOut] + // dLdWxo = xT×dLdzo [nIn, bS] x [bS, nOut] = [nIn, nOut] - // dLdWri = hIT^dLdzi [nOut, bS] x [bS, nOut] = [nOut, nOut] - // dLdWrf = hIT^dLdzf [nOut, bS] x [bS, nOut] = [nOut, nOut] - // dLdWrg = hIT^dLdzg [nOut, bS] x [bS, nOut] = [nOut, nOut] - // dLdWro = hIT^dLdzo [nOut, bS] x [bS, nOut] = [nOut, nOut] + // dLdWri = hIT×dLdzi [nOut, bS] x [bS, nOut] = [nOut, nOut] + // dLdWrf = hIT×dLdzf [nOut, bS] x [bS, nOut] = [nOut, nOut] + // dLdWrg = hIT×dLdzg [nOut, bS] x [bS, nOut] = [nOut, nOut] + // dLdWro = hIT×dLdzo [nOut, bS] x [bS, nOut] = [nOut, nOut] // dLdbi = dLdzi.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] // dLdbf = dLdzf.reduce_sum_along_0_axis [bS, nOut] -> reduce -> [nOut] @@ -563,10 +516,12 @@ void lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, con if(dLdh) *dLdhI += *dLdh; - if(dLdc) - *dLdcI += *dLdc; - else - *dLdcI += *dLdhI * dhdc; + if(dLdhL) + *dLdhI += *dLdhL; + if(dLdcL) + *dLdcI += *dLdcL; + + *dLdcI += *dLdhI * dhdc; dLdzi *= *dLdcI; // [bS, nOut](or[nOut]) dLdzf *= *dLdcI; // [bS, nOut](or[nOut]) @@ -662,25 +617,27 @@ void lstmLayerTimeLoop(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const std::vector shapeOut = {bS, nOut}; + const auto type = h ? h->dataType() : (hL ? hL->dataType() : cL->dataType()); + auto h0 = const_cast(hI); if(!hI) { - h0 = new NDArray(x->ordering(), shapeOut, x->dataType(), x->getContext()); + h0 = new NDArray(x->ordering(), shapeOut, type, x->getContext()); h0->nullify(); } auto c0 = const_cast(cI); if(!cI) { - c0 = new NDArray(x->ordering(), shapeOut, x->dataType(), x->getContext()); + c0 = new NDArray(x->ordering(), shapeOut, type, x->getContext()); c0->nullify(); } auto ct = cL; if(!cL) - ct = new NDArray(x->ordering(), shapeOut, x->dataType(), x->getContext()); + ct = new NDArray(x->ordering(), shapeOut, type, x->getContext()); auto ht = hL; if(!h && !hL) - ht = new NDArray(x->ordering(), shapeOut, x->dataType(), x->getContext()); + ht = new NDArray(x->ordering(), shapeOut, type, x->getContext()); // create sets of required (depends on seqLen presence) sub-arrays std::vector dims; @@ -989,17 +946,19 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const int bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(1); const int nOut = Wx->sizeAt(-1) / 4; + const auto type = dLdh ? dLdh->dataType() : (dLdhL ? 
dLdhL->dataType() : dLdcL->dataType()); + auto dLdh0 = dLdhI; if(!hI) - dLdh0 = new NDArray(x->ordering(), {bS, nOut}, x->dataType(), x->getContext()); // this constructor nullifies array automatically + dLdh0 = new NDArray(x->ordering(), {bS, nOut}, type, x->getContext()); // this constructor nullifies array automatically auto dLdc0 = dLdcI; if(!cI) - dLdc0 = new NDArray(x->ordering(), {bS, nOut}, x->dataType(), x->getContext()); // this constructor nullifies array automatically + dLdc0 = new NDArray(x->ordering(), {bS, nOut}, type, x->getContext()); // this constructor nullifies array automatically - NDArray z(x->ordering(), {sL, bS, 4*nOut}, x->dataType(), x->getContext()); + NDArray z(x->ordering(), {sL, bS, 4*nOut}, type, x->getContext()); NDArray a = z.ulike(); - NDArray h(x->ordering(), {sL+1, bS, nOut}, x->dataType(), x->getContext()); + NDArray h(x->ordering(), {sL+1, bS, nOut}, type, x->getContext()); NDArray c = h.ulike(); // create sets of required (depends on seqLen presence) sub-arrays @@ -1041,9 +1000,9 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, if(dLdh) dLdhSet = new ResultSet(dLdh->allTensorsAlongDimension(dims)); // sub-arrays with shape [nOut] - if(!dLdh && dLdhL) + if(dLdhL) dLdhLSet = new ResultSet(dLdhL->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] - if(!dLdh && !dLdhL) + if(dLdcL) dLdcLSet = new ResultSet(dLdcL->allTensorsAlongDimension({1})); // sub-arrays with shape [nOut] } @@ -1054,13 +1013,13 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, if(!seqLen) { // seqLen is absent if(hI) - h({0,1, 0,0, 0,0}).assign(hI); + hSet->at(0)->assign(hI); else - h({0,1, 0,0, 0,0}).nullify(); + hSet->at(0)->nullify(); if(cI) - c({0,1, 0,0, 0,0}).assign(cI); + cSet->at(0)->assign(cI); else - c({0,1, 0,0, 0,0}).nullify(); + cSet->at(0)->nullify(); // ff for (int t = 0; t < sL; ++t) @@ -1068,9 +1027,10 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // bp for (int t = sL-1; t >= 0; --t) { - const NDArray* dLdhh = dLdh ? dLdhSet->at(t) : (t == sL-1 ? dLdhL : nullptr); - const NDArray* dLdcc = dLdhh ? nullptr : (t == sL-1 ? dLdcL : nullptr); - lstmLayerCellBp(xSet->at(t), Wx, Wr, b, hSet->at(t), cSet->at(t), Wp, dLdhh, dLdcc, + const NDArray* dLdhh = dLdh ? dLdhSet->at(t) : nullptr; + const NDArray* dLdhhL = (t == sL-1 && dLdhL) ? dLdhL : nullptr; + const NDArray* dLdccL = (t == sL-1 && dLdcL) ? dLdcL : nullptr; + lstmLayerCellBp(xSet->at(t), Wx, Wr, b, hSet->at(t), cSet->at(t), Wp, dLdhh, dLdhhL, dLdccL, zSet->at(t), aSet->at(t), cSet->at(t+1), params, dLdxSet->at(t), dLdWx, dLdWr, dLdh0, dLdc0, dLdb, dLdWp); } } @@ -1086,13 +1046,13 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, } if(hI) - h({0,1, e,e+1, 0,0}).assign(hISet->at(e)); + hSet->at(e)->assign(hISet->at(e)); else - h({0,1, e,e+1, 0,0}).nullify(); + hSet->at(e)->nullify(); if(cI) - c({0,1, e,e+1, 0,0}).assign(cISet->at(e)); + cSet->at(e)->assign(cISet->at(e)); else - c({0,1, e,e+1, 0,0}).nullify(); + cSet->at(e)->nullify(); // ff for (int t = 0; t < limit; ++t) @@ -1102,9 +1062,10 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // bp for (int t = limit-1; t >= 0; --t) { const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); - const NDArray* dLdhh = dLdh ? dLdhSet->at(ind) : (t == limit-1 && dLdhL ? dLdhLSet->at(e) : nullptr); - const NDArray* dLdcc = dLdhh ? nullptr : (t == limit-1 ? 
dLdcLSet->at(e) : nullptr); - lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at(t*bS + e), cSet->at(t*bS + e), Wp, dLdhh, dLdcc, + const NDArray* dLdhh = dLdh ? dLdhSet->at(ind) : nullptr; + const NDArray* dLdhhL = (t == limit-1 && dLdhL) ? dLdhLSet->at(e) : nullptr; + const NDArray* dLdccL = (t == limit-1 && dLdcL) ? dLdcLSet->at(e) : nullptr; + lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at(t*bS + e), cSet->at(t*bS + e), Wp, dLdhh, dLdhhL, dLdccL, zSet->at(t*bS + e), aSet->at(t*bS + e), cSet->at((t+1)*bS + e), params, dLdxSet->at(ind), dLdWx, dLdWr, dLdh0Set->at(e), dLdc0Set->at(e), dLdb, dLdWp); } @@ -1119,13 +1080,13 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, if(!seqLen) { // backward or bidirectional, seqLen is absent if(hI) - h({sL,sL+1, 0,0, 0,0}).assign(hI); + hSet->at(sL)->assign(hI); else - h({sL,sL+1, 0,0, 0,0}).nullify(); + hSet->at(sL)->nullify(); if(cI) - c({sL,sL+1, 0,0, 0,0}).assign(cI); + cSet->at(sL)->assign(cI); else - c({sL,sL+1, 0,0, 0,0}).nullify(); + cSet->at(sL)->nullify(); // ff for (int t = sL-1; t >= 0; --t) @@ -1133,9 +1094,10 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // bp for (int t = 0; t < sL; ++t) { - const NDArray* dLdhh = dLdh ? dLdhSet->at(t) : (t == 0 ? dLdhL : nullptr); - const NDArray* dLdcc = dLdhh ? nullptr : (t == 0 ? dLdcL : nullptr); - lstmLayerCellBp(xSet->at(t), Wx, Wr, b, hSet->at(t+1), cSet->at(t+1), Wp, dLdhh, dLdcc, + const NDArray* dLdhh = dLdh ? dLdhSet->at(t) : nullptr; + const NDArray* dLdhhL = (t == 0 && dLdhL) ? dLdhL : nullptr; + const NDArray* dLdccL = (t == 0 && dLdcL) ? dLdcL : nullptr; + lstmLayerCellBp(xSet->at(t), Wx, Wr, b, hSet->at(t+1), cSet->at(t+1), Wp, dLdhh, dLdhhL, dLdccL, zSet->at(t), aSet->at(t), cSet->at(t), params, dLdxSet->at(t), dLdWx, dLdWr, dLdh0, dLdc0, dLdb, dLdWp); } } @@ -1151,13 +1113,13 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, } if(hI) - h({sL,sL+1, e,e+1, 0,0}).assign(hISet->at(e)); + hSet->at(sL*bS + e)->assign(hISet->at(e)); else - h({sL,sL+1, e,e+1, 0,0}).nullify(); + hSet->at(sL*bS + e)->nullify(); if(cI) - c({sL,sL+1, e,e+1, 0,0}).assign(cISet->at(e)); + cSet->at(sL*bS + e)->assign(cISet->at(e)); else - c({sL,sL+1, e,e+1, 0,0}).nullify(); + cSet->at(sL*bS + e)->nullify(); // ff for (int t = sL - 1; t >= sL-limit; --t) @@ -1167,9 +1129,10 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // bp for (int t = sL-limit; t < sL; ++t) { const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); - const NDArray* dLdhh = dLdh ? dLdhSet->at(ind) : (t == sL-limit && dLdhL ? dLdhLSet->at(e) : nullptr); - const NDArray* dLdcc = dLdhh ? nullptr : (t == sL-limit ? dLdcLSet->at(e) : nullptr); - lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, dLdhh, dLdcc, + const NDArray* dLdhh = dLdh ? dLdhSet->at(ind) : nullptr; + const NDArray* dLdhhL = (t == sL-limit && dLdhL) ? dLdhLSet->at(e) : nullptr; + const NDArray* dLdccL = (t == sL-limit && dLdcL) ? 
dLdcLSet->at(e) : nullptr; + lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, dLdhh, dLdhhL, dLdccL, zSet->at(t*bS + e), aSet->at(t*bS + e), cSet->at(t*bS + e), params, dLdxSet->at(ind), dLdWx, dLdWr, dLdh0Set->at(e), dLdc0Set->at(e), dLdb, dLdWp); } @@ -1206,9 +1169,10 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // bp for (int t = 0; t < limit; ++t) { const auto ind = getBatchTimeTotalIndex(dataFormat, sL, bS, t, e); - const NDArray* dLdhh = dLdh ? dLdhSet->at(ind) : (t == 0 && dLdhL ? dLdhLSet->at(e) : nullptr); - const NDArray* dLdcc = dLdhh ? nullptr : (t == 0 ? dLdcLSet->at(e) : nullptr); - lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, dLdhh, dLdcc, + const NDArray* dLdhh = dLdh ? dLdhSet->at(ind) : nullptr; + const NDArray* dLdhhL = (t == 0 && dLdhL) ? dLdhLSet->at(e) : nullptr; + const NDArray* dLdccL = (t == 0 && dLdcL) ? dLdcLSet->at(e) : nullptr; + lstmLayerCellBp(xSet->at(ind), Wx, Wr, b, hSet->at((t+1)*bS + e), cSet->at((t+1)*bS + e), Wp, dLdhh, dLdhhL, dLdccL, zSet->at(t*bS + e), aSet->at(t*bS + e), cSet->at(t*bS + e), params, dLdxSet->at(ind), dLdWx, dLdWr, dLdh0Set->at(e), dLdc0Set->at(e), dLdb, dLdWp); } @@ -1248,10 +1212,10 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // /** the objective is to provide math-readable code **/ // // equations (no peephole connections) -// // zi = x ^ Wxi + hI ^ Wri + bi -// // zf = x ^ Wxf + hI ^ Wrf + bf -// // zg = x ^ Wxg + hI ^ Wrg + bg -// // zo = x ^ Wxo + hI ^ Wro + bo +// // zi = x × Wxi + hI × Wri + bi +// // zf = x × Wxf + hI × Wrf + bf +// // zg = x × Wxg + hI × Wrg + bg +// // zo = x × Wxo + hI × Wro + bo // // i = act(zi) // // f = act(zf) // // g = actC(zg) @@ -1260,10 +1224,10 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // // h = o * actH(c) // // equations (peephole connections are present) -// // zi = x ^ Wxi + hI ^ Wri + cI * Wpi + bi -// // zf = x ^ Wxf + hI ^ Wrf + cI * Wpf + bf -// // zg = x ^ Wxg + hI ^ Wrg + bg -// // zo = x ^ Wxo + hI ^ Wro + c * Wpo + bo +// // zi = x × Wxi + hI × Wri + cI * Wpi + bi +// // zf = x × Wxf + hI × Wrf + cI * Wpf + bf +// // zg = x × Wxg + hI × Wrg + bg +// // zo = x × Wxo + hI × Wro + c * Wpo + bo // // i = act(zi) // // f = act(zf) // // g = actC(zg) @@ -1333,13 +1297,13 @@ void lstmLayerTimeLoopBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, // // oFactor = *dLdh*dhdzo [bS, nOut] // // tempC = dcdcI + Wp ? 
dcdzi*dzidcI + dcdzf*dzfdcI : 0; -// // tempIFE = dcdzi^WriT + dcdzf^WrfT + dcdzg^WrgT -// // tempO = dhdzo^WroT +// // tempIFE = dcdzi×WriT + dcdzf×WrfT + dcdzg×WrgT +// // tempO = dhdzo×WroT // // dhIdcI = dhdc_from_previous_time_step -// // dLdx = iFactor^WxiT + fFactor^WxfT + eFactor^WxgT + oFactor^WxoT, [bS, nIn] -// // dLdhI = iFactor^WriT + fFactor^WrfT + eFactor^WrgT + oFactor^WroT, [bS, nOut] +// // dLdx = iFactor×WxiT + fFactor×WxfT + eFactor×WxgT + oFactor×WxoT, [bS, nIn] +// // dLdhI = iFactor×WriT + fFactor×WrfT + eFactor×WrgT + oFactor×WroT, [bS, nOut] // // dLdcI = factor*tempC + dLdhI * dhIdcI, dhIdcI=0 if firstIter, [bS, nOut] // // dcdWxi(dcIdWxi) = dcdzi*dzidWxi + tempIFE*dhIdWxi + tempC*dcIdWxi, dcIdWxi=dhIdWxi= 0 if firstIter, [nIn, nOut, bS, nOut] diff --git a/libnd4j/include/ops/declarable/helpers/lstmLayer.h b/libnd4j/include/ops/declarable/helpers/lstmLayer.h index 3a2d173b5..29c434865 100644 --- a/libnd4j/include/ops/declarable/helpers/lstmLayer.h +++ b/libnd4j/include/ops/declarable/helpers/lstmLayer.h @@ -42,7 +42,7 @@ void ND4J_EXPORT lstmLayerCell(const NDArray* x, const NDArray* Wx, const NDArra ////////////////////////////////////////////////////////////////////////// void ND4J_EXPORT lstmLayerCellBp(const NDArray* x, const NDArray* Wx, const NDArray* Wr, const NDArray* b, const NDArray* hI, const NDArray* cI, const NDArray* Wp, - const NDArray* dLdh, const NDArray* dLdc, + const NDArray* dLdh, const NDArray* dLdhL, const NDArray* dLdcL, const NDArray* z, const NDArray* a, const NDArray* c, const std::vector& params, NDArray* dLdx, NDArray* dLdWx, NDArray* dLdWr, NDArray* dLdhI, NDArray* dLdcI, NDArray* dLdb, NDArray* dLdWp); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp index cee574dec..4052e260d 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp @@ -2084,11 +2084,11 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_1) { const auto hasInitH = true; // initial output is provided const auto hasInitC = true; // initial cell state is provided const auto hasPH = true; // peephole connections are absent - const auto retFullSeq = false; // dLdh per each time step + const auto retFullSeq = true; // dLdh per each time step const auto retLastH = true; // output at last time step - const auto retLastC = true; // cells state at last time step + const auto retLastC = true; // cells state at last time step - const double cellClip = 0.5; // do not apply clipping + const double cellClip = 0.5; // clipping NDArray x('c', {sL, bS, nIn}, sd::DataType::DOUBLE); NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE); @@ -2097,6 +2097,7 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_1) { NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE); NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE); NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE); + NDArray dLdh('c', {sL, bS, nOut}, sd::DataType::DOUBLE); NDArray dLdhL('c', {bS, nOut}, sd::DataType::DOUBLE); NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE); @@ -2113,12 +2114,12 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_1) { std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); - const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); + const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, 
&dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); sd::ops::lstmLayer opFF; sd::ops::lstmLayer_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, std::vector(), {0., 1.}, GradCheck::LossFunc::SUM, {0}); + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); ASSERT_TRUE(isGradCorrect); } @@ -2131,63 +2132,6 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_2) { const int nIn = 2; const int nOut = 3; - const int dataFormat = 0; // [sL,bS,nIn] - const int directionMode = 0; // forward - const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates - const int cellAct = 0; // tanh activation for cell state - const int outAct = 0; // tanh activation for output - - const bool hasBiases = true; // biases array is provided - const bool hasSeqLen = false; // seqLen array is not provided - const auto hasInitH = true; // initial output is provided - const auto hasInitC = true; // initial cell state is provided - const auto hasPH = true; // peephole connections are absent - const auto retFullSeq = false; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] - const auto retLastH = false; // output at last time step - const auto retLastC = true; // cells state at last time step - - const double cellClip = 0.5; // do not apply clipping - - NDArray x('c', {sL, bS, nIn}, sd::DataType::DOUBLE); - NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE); - NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE); - NDArray b('c', {4*nOut}, sd::DataType::DOUBLE); - NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE); - NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE); - NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE); - NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE); - - x.linspace(-2,0.1); - hI.linspace(-1.5,0.1); - cI.linspace(0.7,-0.1); - Wx.linspace(1,-0.1); - Wr.linspace(-1,0.1); - Wp.linspace(0.2,0.2); - b.linspace(1,-0.15); - - std::vector tArgs = {cellClip}; - std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; - std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; - - const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); - const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdcL}, tArgs, iArgs, bArgs); - - sd::ops::lstmLayer opFF; - sd::ops::lstmLayer_bp opBP; - - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, std::vector(), {0., 1.}, GradCheck::LossFunc::MEAN); - - ASSERT_TRUE(isGradCorrect); -} - -/////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests13, lstmLayer_bp_3) { - - const int sL = 3; - const int bS = 2; - const int nIn = 2; - const int nOut = 3; - const int dataFormat = 1; // [bS,sL,nIn] const int directionMode = 0; // forward const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates @@ -2199,11 +2143,11 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_3) { const auto hasInitH = true; // initial output is provided const auto hasInitC = true; // initial cell state is provided const auto hasPH = true; // peephole connections are absent - const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] - const auto retLastH = false; // output at last time step + const auto retFullSeq = true; // return whole h {h_0, h_1, ... 
, h_sL-1}, [sL,bS,nOut] + const auto retLastH = false; // output at last time step const auto retLastC = true; // cells state at last time step - const double cellClip = 0.5; // do not apply clipping + const double cellClip = 0.5; // clipping NDArray x('c', {bS, sL, nIn}, sd::DataType::DOUBLE); NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE); @@ -2233,13 +2177,13 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_3) { sd::ops::lstmLayer opFF; sd::ops::lstmLayer_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, std::vector(), {0., 1.}, GradCheck::LossFunc::MEAN, {0}); + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, std::vector(), {0., 1.}, GradCheck::LossFunc::MEAN); ASSERT_TRUE(isGradCorrect); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests13, lstmLayer_bp_4) { +TEST_F(DeclarableOpsTests13, lstmLayer_bp_3) { const int sL = 4; const int bS = 3; @@ -2258,10 +2202,10 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_4) { const auto hasInitC = true; // initial cell state is provided const auto hasPH = true; // peephole connections are absent const auto retFullSeq = true; // dLdh per each time step - const auto retLastH = false; // output at last time step - const auto retLastC = false; // cells state at last time step + const auto retLastH = true; // output at last time step + const auto retLastC = true; // cells state at last time step - const double cellClip = 0.5; // do not apply clipping + const double cellClip = 0.5; // clipping NDArray x('c', {bS, nIn, sL}, sd::DataType::DOUBLE); NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE); @@ -2272,6 +2216,8 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_4) { NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE); NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE); NDArray dLdh('c', {bS, nOut, sL}, sd::DataType::DOUBLE); + NDArray dLdhL('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE); x.linspace(-2,0.1); hI.linspace(-1.5,0.1); @@ -2286,18 +2232,18 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_4) { std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); - const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs); + const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); sd::ops::lstmLayer opFF; sd::ops::lstmLayer_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}, {0., 1.}, GradCheck::LossFunc::MEAN, {0}); + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}); ASSERT_TRUE(isGradCorrect); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests13, lstmLayer_bp_5) { +TEST_F(DeclarableOpsTests13, lstmLayer_bp_4) { const int sL = 3; const int bS = 2; @@ -2315,11 +2261,11 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_5) { const auto hasInitH = true; // initial output is provided const auto hasInitC = true; // initial cell state is provided const auto hasPH = true; // peephole connections are absent - const auto retFullSeq = false; // dLdh per each time step + const auto retFullSeq = true; // dLdh per each time 
step const auto retLastH = true; // output at last time step - const auto retLastC = false; // cells state at last time step + const auto retLastC = true; // cells state at last time step - const double cellClip = 0.5; // do not apply clipping + const double cellClip = 0.5; // clipping NDArray x('c', {bS, sL, nIn}, sd::DataType::DOUBLE); NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE); @@ -2328,7 +2274,9 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_5) { NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE); NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE); NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE); + NDArray dLdh('c', {bS, sL, nOut}, sd::DataType::DOUBLE); NDArray dLdhL('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE); x.linspace(-2,0.1); hI.linspace(-1.5,0.1); @@ -2343,18 +2291,18 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_5) { std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); - const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdhL}, tArgs, iArgs, bArgs); + const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); sd::ops::lstmLayer opFF; sd::ops::lstmLayer_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, std::vector(), {0., 1.}, GradCheck::LossFunc::MEAN, {0}); + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); ASSERT_TRUE(isGradCorrect); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests13, lstmLayer_bp_6) { +TEST_F(DeclarableOpsTests13, lstmLayer_bp_5) { const int sL = 3; const int bS = 2; @@ -2373,10 +2321,10 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_6) { const auto hasInitC = true; // initial cell state is provided const auto hasPH = true; // peephole connections are absent const auto retFullSeq = true; // dLdh per each time step - const auto retLastH = false; // output at last time step - const auto retLastC = false; // cells state at last time step + const auto retLastH = true; // output at last time step + const auto retLastC = true; // cells state at last time step - const double cellClip = 0.5; // do not apply clipping + const double cellClip = 0.5; // clipping NDArray x('c', {bS, nIn, sL}, sd::DataType::DOUBLE); NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE); @@ -2387,6 +2335,8 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_6) { NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE); NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE); NDArray dLdh('c', {bS, nOut, sL}, sd::DataType::DOUBLE); + NDArray dLdhL('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE); x.linspace(-2,0.1); hI.linspace(-1.5,0.1); @@ -2401,18 +2351,18 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_6) { std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); - const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs); + const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); sd::ops::lstmLayer opFF; sd::ops::lstmLayer_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, 
{true, true, true, true, false, true, true, true}, {0., 1.}, GradCheck::LossFunc::MEAN, {0}); + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}); ASSERT_TRUE(isGradCorrect); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests13, lstmLayer_bp_7) { +TEST_F(DeclarableOpsTests13, lstmLayer_bp_6) { const int sL = 3; const int bS = 2; @@ -2430,11 +2380,11 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_7) { const auto hasInitH = true; // initial output is provided const auto hasInitC = true; // initial cell state is provided const auto hasPH = true; // peephole connections are absent - const auto retFullSeq = false; // dLdh per each time step + const auto retFullSeq = true; // dLdh per each time step const auto retLastH = true; // output at last time step - const auto retLastC = false; // cells state at last time step + const auto retLastC = true; // cells state at last time step - const double cellClip = 0.5; // do not apply clipping + const double cellClip = 0.5; // clipping NDArray x('c', {bS, nIn, sL}, sd::DataType::DOUBLE); NDArray Wx('c', {2, nIn, 4*nOut}, sd::DataType::DOUBLE); @@ -2444,7 +2394,9 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_7) { NDArray hI('c', {2, bS, nOut}, sd::DataType::DOUBLE); NDArray cI('c', {2, bS, nOut}, sd::DataType::DOUBLE); NDArray Wp('c', {2, 3*nOut}, sd::DataType::DOUBLE); + NDArray dLdh('c', {bS, nOut, sL}, sd::DataType::DOUBLE); NDArray dLdhL('c', {2, bS, nOut}, sd::DataType::DOUBLE); + NDArray dLdcL('c', {2, bS, nOut}, sd::DataType::DOUBLE); x.linspace(-2,0.1); hI.linspace(-1.5,0.1); @@ -2459,18 +2411,24 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_7) { std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); - const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL}, tArgs, iArgs, bArgs); + const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdcL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdcL}, tArgs, iArgs, bArgs); sd::ops::lstmLayer opFF; sd::ops::lstmLayer_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}, {0., 1.}, GradCheck::LossFunc::MEAN, {0}); + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}); ASSERT_TRUE(isGradCorrect); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests13, lstmLayer_bp_8) { +TEST_F(DeclarableOpsTests13, lstmLayer_bp_7) { const int sL = 3; const int bS = 2; @@ -2489,10 +2447,10 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_8) { const 
auto hasInitC = true; // initial cell state is provided const auto hasPH = true; // peephole connections are absent const auto retFullSeq = true; // dLdh per each time step - const auto retLastH = false; // output at last time step - const auto retLastC = false; // cells state at last time step + const auto retLastH = true; // output at last time step + const auto retLastC = true; // cells state at last time step - const double cellClip = 0.5; // do not apply clipping + const double cellClip = 0.5; // clipping NDArray x('c', {bS,sL,nIn}, sd::DataType::DOUBLE); NDArray Wx('c', {2, nIn, 4*nOut}, sd::DataType::DOUBLE); @@ -2503,6 +2461,8 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_8) { NDArray cI('c', {2, bS, nOut}, sd::DataType::DOUBLE); NDArray Wp('c', {2, 3*nOut}, sd::DataType::DOUBLE); NDArray dLdh('c', {bS,sL,2*nOut}, sd::DataType::DOUBLE); + NDArray dLdhL('c', {2, bS, nOut}, sd::DataType::DOUBLE); + NDArray dLdcL('c', {2, bS, nOut}, sd::DataType::DOUBLE); x.linspace(-2,0.1); hI.linspace(-1.5,0.1); @@ -2517,18 +2477,24 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_8) { std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); - const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs); + const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdcL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdcL}, tArgs, iArgs, bArgs); sd::ops::lstmLayer opFF; sd::ops::lstmLayer_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}, {0., 1.}, GradCheck::LossFunc::MEAN, {0}); + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}); ASSERT_TRUE(isGradCorrect); } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests13, lstmLayer_bp_9) { +TEST_F(DeclarableOpsTests13, lstmLayer_bp_8) { const int sL = 3; const int bS = 2; @@ -2547,10 +2513,10 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_9) { const auto hasInitC = true; // initial cell state is provided const auto hasPH = true; // peephole connections are absent const auto retFullSeq = true; // dLdh per each time step - const auto retLastH = false; // output at last time step - const auto retLastC = false; // cells state at last time step + const auto retLastH = true; // output at last time step + const auto retLastC = true; // cells state at last time step - const double cellClip = 0.5; // do not apply clipping + const double cellClip = 0.5; // clipping NDArray x('c', {sL, bS, nIn}, sd::DataType::DOUBLE); NDArray Wx('c', {2, nIn, 4*nOut}, sd::DataType::DOUBLE); @@ -2561,6 +2527,8 @@ 
TEST_F(DeclarableOpsTests13, lstmLayer_bp_9) { NDArray cI('c', {2, bS, nOut}, sd::DataType::DOUBLE); NDArray Wp('c', {2, 3*nOut}, sd::DataType::DOUBLE); NDArray dLdh('c', {sL, 2, bS, nOut}, sd::DataType::DOUBLE); + NDArray dLdhL('c', {2, bS, nOut}, sd::DataType::DOUBLE); + NDArray dLdcL('c', {2, bS, nOut}, sd::DataType::DOUBLE); x.linspace(-2,0.1); hI.linspace(-1.5,0.1); @@ -2575,12 +2543,18 @@ TEST_F(DeclarableOpsTests13, lstmLayer_bp_9) { std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); - const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs); + const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdcL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdcL}, tArgs, iArgs, bArgs); sd::ops::lstmLayer opFF; sd::ops::lstmLayer_bp opBP; - const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}, {0., 1.}, GradCheck::LossFunc::MEAN, {0}); + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}); ASSERT_TRUE(isGradCorrect); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp index 5fffa73c5..3d86cd92b 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp @@ -1904,6 +1904,7 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP16) { const bool isGradCorrect = GradCheck::checkGrad(op, op_bp, argsHolderFF, argsHolderBP, {1,0}); ASSERT_TRUE(isGradCorrect); } + ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests15, TestTensorMmul_BP17) { @@ -1922,3 +1923,68 @@ TEST_F(DeclarableOpsTests15, TestTensorMmul_BP17) { ASSERT_TRUE(isGradCorrect); } + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, gru_1) { + + const int sL = 3; + const int bS = 2; + const int nIn = 5; + const int nOut = 4; + + + NDArray x('c', {sL, bS, nIn}, {0.5, 1. , 1.5, 2. , 2.5, 3. , 3.5, 4. , 4.5, 5. , 5.5, 6. , 6.5, 7. , 7.5, 8. , 8.5, 9. , 9.5, 10. , 10.5, 11. , 11.5, 12. , 12.5, 13. , 13.5, 14. 
, 14.5, 15.}, sd::DataType::FLOAT32);
+    NDArray hI('c', {bS, nOut}, {-3,-2,-1,0,1,2,3,4}, sd::DataType::FLOAT32);
+    NDArray Wx('c', {nIn, 3*nOut}, sd::DataType::FLOAT32);
+    NDArray Wh('c', {nOut, 3*nOut}, sd::DataType::FLOAT32);
+    NDArray b('c', {3*nOut}, sd::DataType::FLOAT32);
+
+    NDArray expH('c', {sL, bS, nOut}, {-1.681847, -1.062565, -0.443283, 0.175998,0.837823, 1.488041, 2.13826 , 2.788478, -0.888747, -0.491826, -0.094907, 0.302014,
+                                        0.751355, 1.182715, 1.614075, 2.045434, -0.388876, -0.126716, 0.135444, 0.397604,0.710558, 1.002922, 1.295287, 1.587651}, sd::DataType::FLOAT32);
+
+    Wx = 0.003;
+    Wh = 0.006;
+    b = 0.5;
+
+    sd::ops::gru op;
+    auto results = op.evaluate({&x, &hI, &Wx, &Wh, &b}, {}, {});
+
+    ASSERT_EQ(ND4J_STATUS_OK, results.status());
+
+    auto* h = results.at(0);
+
+    ASSERT_TRUE(expH.isSameShape(h));
+    ASSERT_TRUE(expH.equalsTo(h));
+}
+
+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests15, gru_bp_1) {
+
+    const int sL = 3;
+    const int bS = 2;
+    const int nIn = 5;
+    const int nOut = 4;
+
+
+    NDArray x('c', {sL, bS, nIn}, {0.5, 1. , 1.5, 2. , 2.5, 3. , 3.5, 4. , 4.5, 5. , 5.5, 6. , 6.5, 7. , 7.5, 8. , 8.5, 9. , 9.5, 10. , 10.5, 11. , 11.5, 12. , 12.5, 13. , 13.5, 14. , 14.5, 15.}, sd::DataType::DOUBLE);
+    NDArray hI('c', {bS, nOut}, {-3,-2,-1,0,1,2,3,4}, sd::DataType::DOUBLE);
+    NDArray Wx('c', {nIn, 3*nOut}, sd::DataType::DOUBLE);
+    NDArray Wh('c', {nOut, 3*nOut}, sd::DataType::DOUBLE);
+    NDArray b('c', {3*nOut}, sd::DataType::DOUBLE);
+
+    NDArray dLdh('c', {sL, bS, nOut}, sd::DataType::DOUBLE);
+
+    Wx.linspace(1,-0.1);
+    Wh.linspace(0.2,0.2);
+    b.linspace(1,-0.15);
+
+    const OpArgsHolder argsHolderFF({&x, &hI, &Wx, &Wh, &b}, {}, {});
+    const OpArgsHolder argsHolderBP({&x, &hI, &Wx, &Wh, &b, &dLdh}, {}, {});
+
+    sd::ops::gru opFF;
+    sd::ops::gru_bp opBP;
+
+    const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
+
+    ASSERT_TRUE(isGradCorrect);
+}
From 12ba1fa4062c7fbfb9106971548a8b0ed66e8adc Mon Sep 17 00:00:00 2001
From: raver119
Date: Thu, 16 Apr 2020 13:25:13 +0300
Subject: [PATCH 04/18] Few minor fixes (#381)

* - 1D indexing fix
- couple of new tests for 1D indexing

Signed-off-by: raver119

* percentile fix + test

Signed-off-by: raver119

* wrong signature used in test

Signed-off-by: raver119
---
 .../nd4j/linalg/api/ndarray/BaseNDArray.java  | 18 ++++-
 .../java/org/nd4j/nativeblas/Nd4jCuda.java    |  2 +-
 .../java/org/nd4j/nativeblas/Nd4jCpu.java     | 74 ++++++++++++++++++-
 .../test/java/org/nd4j/linalg/Nd4jTestsC.java |  7 ++
 .../linalg/shape/indexing/IndexingTestsC.java | 19 +++++
 5 files changed, 116 insertions(+), 4 deletions(-)

diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java
index 07a2bf9b8..46daa869b 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java
@@ -2045,8 +2045,18 @@ public abstract class BaseNDArray implements INDArray, Iterable {
             throw new ND4JIllegalArgumentException("Indices must be a vector or matrix.");
         }
 
-        if(indices.rows() == rank()) {
-            INDArray ret = Nd4j.create(indices.dataType(), indices.columns());
+        if (rank() == 1) {
+            Preconditions.checkArgument(indices.rank() <= 1, "For 1D vector indices must be either
scalar or vector as well"); + val ret = Nd4j.createUninitialized(this.dataType(), indices.length()); + for (int e = 0; e < indices.length(); e++) { + val idx = indices.getLong(e); + val value = getDouble(idx); + ret.putScalar(e, value); + } + + return ret; + } else if(indices.rows() == rank()) { + INDArray ret = Nd4j.create(this.dataType(), indices.columns()); for(int i = 0; i < indices.columns(); i++) { int[] specifiedIndex = indices.getColumn(i).dup().data().asInt(); @@ -5391,6 +5401,10 @@ public abstract class BaseNDArray implements INDArray, Iterable { return sorted.getDouble(sorted.length() - 1); double pos = (quantile.doubleValue() / 100.0) * (double) (sorted.length() + 1); + if (pos < 1) + return sorted.getDouble(0); + else if (pos >= sorted.length()) + return sorted.getDouble(sorted.length() - 1); double fposition = FastMath.floor(pos); int position = (int)fposition; diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java index 8f30cdd82..6753b5ea1 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java @@ -1,4 +1,4 @@ -// Targeted by JavaCPP version 1.5.3-SNAPSHOT: DO NOT EDIT THIS FILE +// Targeted by JavaCPP version 1.5.3: DO NOT EDIT THIS FILE package org.nd4j.nativeblas; diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index e96325460..adddf5e42 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -1,4 +1,4 @@ -// Targeted by JavaCPP version 1.5.3-SNAPSHOT: DO NOT EDIT THIS FILE +// Targeted by JavaCPP version 1.5.3: DO NOT EDIT THIS FILE package org.nd4j.nativeblas; @@ -16040,6 +16040,41 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); } // #endif +// #if NOT_EXCLUDED(OP_lstmLayerCell) + @Namespace("sd::ops") public static class lstmLayerCell extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lstmLayerCell(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lstmLayerCell(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lstmLayerCell position(long position) { + return (lstmLayerCell)super.position(position); + } + + public lstmLayerCell() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif +// #if NOT_EXCLUDED(OP_lstmLayerCell) + @Namespace("sd::ops") public static class lstmLayerCellBp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lstmLayerCellBp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. 
*/ + public lstmLayerCellBp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lstmLayerCellBp position(long position) { + return (lstmLayerCellBp)super.position(position); + } + + public lstmLayerCellBp() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif + ////////////////////////////////////////////////////////////////////////// /** @@ -16169,6 +16204,25 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); } // #endif + ////////////////////////////////////////////////////////////////////////// +// #if NOT_EXCLUDED(OP_lstmLayer) + @Namespace("sd::ops") public static class lstmLayer_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public lstmLayer_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public lstmLayer_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public lstmLayer_bp position(long position) { + return (lstmLayer_bp)super.position(position); + } + + public lstmLayer_bp() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif + ////////////////////////////////////////////////////////////////////////// /** @@ -16336,6 +16390,24 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); } // #endif +// #if NOT_EXCLUDED(OP_gru) + @Namespace("sd::ops") public static class gru_bp extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public gru_bp(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. 
*/ + public gru_bp(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public gru_bp position(long position) { + return (gru_bp)super.position(position); + } + + public gru_bp() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif + ////////////////////////////////////////////////////////////////////////// /** * Implementation of operation "static RNN time sequences" with peep hole connections: diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index 162e123b8..886384f4a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -5607,6 +5607,13 @@ public class Nd4jTestsC extends BaseNd4jTest { assertEquals(exp, array.percentileNumber(75)); } + @Test + public void testPercentile5() { + val array = Nd4j.createFromArray(new int[]{1, 1982}); + val perc = array.percentileNumber(75); + assertEquals(1982.f, perc.floatValue(), 1e-5f); + } + @Test public void testTadPercentile1() { INDArray array = Nd4j.linspace(1, 10, 10, DataType.DOUBLE); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTestsC.java index 33a64c291..24c6b30d4 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/indexing/IndexingTestsC.java @@ -17,6 +17,7 @@ package org.nd4j.linalg.shape.indexing; import lombok.extern.slf4j.Slf4j; +import lombok.val; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ErrorCollector; @@ -190,6 +191,24 @@ public class IndexingTestsC extends BaseNd4jTest { assertTrue(last10b.getDouble(i) == 20 + i); } + @Test + public void test1dSubarray_1() { + val data = Nd4j.linspace(DataType.FLOAT,0, 10, 1); + val exp = Nd4j.createFromArray(new float[]{3.f, 4.f}); + val dataAtIndex = data.get(NDArrayIndex.interval(3, 5)); + + assertEquals(exp, dataAtIndex); + } + + @Test + public void test1dSubarray_2() { + val data = Nd4j.linspace(DataType.FLOAT,1, 10, 1); + val exp = Nd4j.createFromArray(new float[]{4.f, 6.f}); + val dataAtIndex = data.get(Nd4j.createFromArray(new int[]{3, 5})); + + assertEquals(exp, dataAtIndex); + } + @Test public void testGet() { // System.out.println("Testing sub-array put and get with a 3D array ..."); From 3d15706ffabcea9a36f63076d4ac037cfbc0aba2 Mon Sep 17 00:00:00 2001 From: Oleh Date: Thu, 16 Apr 2020 14:53:56 +0300 Subject: [PATCH 05/18] Lin_space operation improve (#373) * libnd4j update linspace op Signed-off-by: Oleg * libnd4j #8513 update lin_space op, tests added Signed-off-by: Oleg * - minor linspace tweaks (num_elements now iArg) - java linspace updates - couple of additional tests for linspace Signed-off-by: raver119 * roll back timeout change Signed-off-by: raver119 Co-authored-by: raver119 --- .../declarable/generic/tensor/lin_space.cpp | 32 +++++++++++++------ .../ops/declarable/headers/parity_ops.h | 10 ++++-- .../layers_tests/DeclarableOpsTests10.cpp | 28 ++++++++++++++++ .../layers_tests/JavaInteropTests.cpp | 14 ++++++++ .../linalg/api/ops/impl/shape/Linspace.java | 18 
+++++++++++++++++-
 .../ops/executioner/CudaExecutioner.java      |  2 +-
 .../nativecpu/ops/NativeOpExecutioner.java    |  2 +-
 .../java/org/nd4j/nativeblas/Nd4jCpu.java     |  8 +++--
 .../nd4j/linalg/custom/CustomOpsTests.java    | 11 +++++++
 9 files changed, 108 insertions(+), 17 deletions(-)

diff --git a/libnd4j/include/ops/declarable/generic/tensor/lin_space.cpp b/libnd4j/include/ops/declarable/generic/tensor/lin_space.cpp
index 54fd8fb0e..374456be6 100644
--- a/libnd4j/include/ops/declarable/generic/tensor/lin_space.cpp
+++ b/libnd4j/include/ops/declarable/generic/tensor/lin_space.cpp
@@ -26,24 +26,38 @@ namespace sd {
 namespace ops {
 
-    CUSTOM_OP_IMPL(lin_space, 3, 1, false, 0, 0) {
-        auto output = OUTPUT_VARIABLE(0);
-        auto start = INPUT_VARIABLE(0);
-        auto finish = INPUT_VARIABLE(1);
-        auto numOfElements = INPUT_VARIABLE(2);
+    CUSTOM_OP_IMPL(lin_space, 0, 1, false, 0, 0) {
 
-        if (numOfElements->e<Nd4jLong>(0) == 1) {
+        auto output = OUTPUT_VARIABLE(0);
+
+        const int nInputs = block.width();
+        bool bInputs = (3 == nInputs || 3 == block.numI() || (2 == block.numT() && block.numI() > 0));
+
+        REQUIRE_TRUE(bInputs, 0, "lin_space OP: inputs must be supplied either as 3 input arrays, or as 3 integer args, or as 2 T_ARGS plus an integer arg, but got inputs - %i, T_ARGS - %i!", nInputs, block.numT());
+
+        auto start = (nInputs > 0) ? INPUT_VARIABLE(0)->e<double>(0) : static_cast<double>(T_ARG(0));
+        auto finish = (nInputs > 0) ? INPUT_VARIABLE(1)->e<double>(0) : static_cast<double>(T_ARG(1));
+        auto numOfElements = (nInputs > 0) ? INPUT_VARIABLE(2)->e<Nd4jLong>(0) : static_cast<Nd4jLong>(I_ARG(0));
+
+        if (numOfElements == 1) {
             output->assign(start);
             return Status::OK();
         }
 
-        output->linspace(start->e<double>(0), (finish->e<double>(0) - start->e<double>(0)) / (numOfElements->e<Nd4jLong>(0) - 1.));
+        output->linspace(start, (finish - start) / (numOfElements - 1.0));
         return Status::OK();
     }
 
     DECLARE_SHAPE_FN(lin_space) {
-        auto dataType = ArrayOptions::dataType(inputShape->at(0));
-        Nd4jLong steps = INPUT_VARIABLE(2)->e<Nd4jLong>(0);
+
+        const int nInputs = block.width();
+        bool bInputs = (3 == nInputs || 3 == block.numI() || (2 == block.numT() && block.numI() > 0));
+        REQUIRE_TRUE(bInputs, 0, "lin_space OP: inputs must be supplied either as 3 input arrays, or as 3 integer args, or as 2 T_ARGS plus an integer arg, but got inputs - %i, T_ARGS - %i!", nInputs, block.numT());
+
+
+        auto dataType = (nInputs > 0) ? ArrayOptions::dataType(inputShape->at(0)) : (block.numD() > 0 ? static_cast<DataType>(D_ARG(0)) : DataType::FLOAT32);
+        Nd4jLong steps = (nInputs > 0) ? INPUT_VARIABLE(2)->e<Nd4jLong>(0) : static_cast<Nd4jLong>(I_ARG(0));
+
         return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(steps, dataType));
     }

diff --git a/libnd4j/include/ops/declarable/headers/parity_ops.h b/libnd4j/include/ops/declarable/headers/parity_ops.h
index f3131c193..8fae1b63c 100644
--- a/libnd4j/include/ops/declarable/headers/parity_ops.h
+++ b/libnd4j/include/ops/declarable/headers/parity_ops.h
@@ -1433,16 +1433,20 @@ namespace sd {
     /**
      * lin_space - op porting from TF (https://www.tensorflow.org/api_docs/python/tf/lin_space)
      *
-     * input params:
+     * optional input params:
      *    0 - startVal - NDArray scalar (float point)
      *    1 - finishVal - NDArray scalar (float point)
      *    2 - numOfElements - NDArray scalar (integer)
-     *
+     * Optional (when no input arrays are supplied):
+     *    T args
+     *    0 - startVal
+     *    1 - finishVal
+     *    I args: 0 - numOfElements
      * output:
      *    0 - 1D NDArray with the same type as input and length as given with numOfElements param.
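+     *
+     * usage sketch of the array-less form (it mirrors the LinSpace_Test2/3 cases added in this patch:
+     * start/finish passed as T args, numOfElements as an integer arg, output type as a D arg):
+     *    sd::ops::lin_space op;
+     *    auto result = op.evaluate({}, {1.0, 10.0}, {10}, {}, {sd::DataType::FLOAT32});
+     *    // result.at(0) -> [1., 2., ..., 10.]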
*/ #if NOT_EXCLUDED(OP_lin_space) - DECLARE_CUSTOM_OP(lin_space, 3, 1, false, 0, 0); + DECLARE_CUSTOM_OP(lin_space, 0, 1, false, 0, 0); #endif /** diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp index 03e5ae53f..6d89bd182 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp @@ -2010,6 +2010,34 @@ TEST_F(DeclarableOpsTests10, LinSpace_Test1) { ASSERT_TRUE(expect.equalsTo(res)); +} +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, LinSpace_Test2) { + + NDArray expect = NDArrayFactory::create({1., 1.5, 2., 2.5, 3., 3.5, 4., 4.5, 5., 5.5, 6., 6.5, 7., 7.5, + 8., 8.5, 9., 9.5, 10., 10.5, 11., 11.5, 12.}); + + sd::ops::lin_space op; + auto result = op.evaluate({}, {1, 12}, {23}); + ASSERT_EQ(result.status(), ND4J_STATUS_OK); + auto res = result.at(0); + ASSERT_EQ( res->dataType(), sd::DataType::FLOAT32 ); + ASSERT_TRUE(expect.equalsTo(res)); + +} +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, LinSpace_Test3) { + + NDArray expect('c', { 23 }, {1., 1.5, 2., 2.5, 3., 3.5, 4., 4.5, 5., 5.5, 6., 6.5, 7., 7.5, 8., 8.5, 9., 9.5, 10., 10.5, 11., 11.5, 12.}, sd::DataType::DOUBLE ); + + sd::ops::lin_space op; + auto result = op.evaluate({}, {1, 12}, {23}, {}, { sd::DOUBLE }); + ASSERT_EQ(result.status(), ND4J_STATUS_OK); + auto res = result.at(0); + + ASSERT_EQ( res->dataType(), expect.dataType()); + ASSERT_TRUE(expect.equalsTo(res)); + } //////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1) { diff --git a/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp b/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp index 6f559230b..29c681544 100644 --- a/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp @@ -1334,6 +1334,20 @@ TEST_F(JavaInteropTests, test_workspace_backed_arrays_1) { ASSERT_EQ(Status::OK(), status); } +TEST_F(JavaInteropTests, test_linspace_shape_1) { + if (!Environment::getInstance()->isCPU()) + return; + + sd::ops::lin_space op; + double tArgs[2] = {1.0, 10.0}; + Nd4jLong iArgs = 10L; + int dArg = (int) sd::DataType::FLOAT32; + auto result = ::calculateOutputShapes2(nullptr, op.getOpHash(), nullptr, nullptr, 0, tArgs, 2, &iArgs, 1, nullptr, 0, &dArg, 1); + + ASSERT_EQ(1, result->size()); + delete result; +} + /* TEST_F(JavaInteropTests, Test_Results_Conversion_1) { auto pl = sd::graph::readFlatBuffers("./resources/gru_dynamic_mnist.fb"); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java index 4bc3b3f63..a9f964844 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java @@ -42,6 +42,9 @@ import java.util.Map; public class Linspace extends DynamicCustomOp { private DataType dataType; + private double start; + private double stop; + private long elements; public Linspace(SameDiff sameDiff, DataType dataType, double start, double stop, long number) { this(sameDiff, sameDiff.constant(start), sameDiff.constant(stop), sameDiff.constant(number), dataType); @@ -54,7 
+57,7 @@ public class Linspace extends DynamicCustomOp { } public Linspace(DataType dataType, double start, double stop, long number) { - this(dataType, Nd4j.scalar(start), Nd4j.scalar(stop), Nd4j.scalar(number)); + this(start, stop, number, dataType); } public Linspace(DataType dataType, INDArray start, INDArray stop, INDArray number) { @@ -67,6 +70,19 @@ public class Linspace extends DynamicCustomOp { addDArgument(dataType); } + public Linspace(double start, double stop, long number, @NonNull DataType dataType) { + super(new INDArray[]{}, null); + this.dataType = dataType; + addDArgument(dataType); + + this.start = start; + this.stop = stop; + this.elements = number; + + addTArgument(this.start, this.stop); + addIArgument(elements); + } + public Linspace(){ } @Override diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java index bda208ce7..b21123085 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java @@ -1947,7 +1947,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { val result = new ArrayList(); int nIn = opContext != null ? opContext.numInputArguments() : op.numInputArguments(); - if(nIn == 0 && op.getDescriptor().getNumInputs() != -2) { + if(nIn == 0 && op.getDescriptor().getNumInputs() >= 1) { if(log.isTraceEnabled()){ log.trace("Could not calculate output shape for op {}: number of input args was 0", op.getClass().getName()); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java index f0488636f..93ff5cf52 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java @@ -1754,7 +1754,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { val result = new ArrayList(); int nIn = opContext != null ? 
opContext.numInputArguments() : op.numInputArguments();
-        if(nIn == 0 && op.getDescriptor().getNumInputs() != -2) {
+        if(nIn == 0 && op.getDescriptor().getNumInputs() >= 1) {
             if(log.isTraceEnabled()){
                 log.trace("Could not calculate output shape for op {}: number of input args was 0",
                         op.getClass().getName());
diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java
index adddf5e42..cd18e0f18 100644
--- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java
@@ -20475,11 +20475,15 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
 /**
  * lin_space - op porting from TF (https://www.tensorflow.org/api_docs/python/tf/lin_space)
  *
- * input params:
+ * optional input params:
  *    0 - startVal - NDArray scalar (float point)
  *    1 - finishVal - NDArray scalar (float point)
  *    2 - numOfElements - NDArray scalar (integer)
- *
+ * Optional (when no input arrays are supplied):
+ *    T args
+ *    0 - startVal
+ *    1 - finishVal
+ *    I args: 0 - numOfElements
  * output:
  *    0 - 1D NDArray with the same type as input and length as given with numOfElements param.
  */

diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java
index 01dc83ee4..04ebaa0d7 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java
@@ -37,6 +37,7 @@ import org.nd4j.linalg.api.ops.impl.image.ResizeBilinear;
 import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
 import org.nd4j.linalg.api.ops.impl.reduce.MmulBp;
 import org.nd4j.linalg.api.ops.impl.shape.Create;
+import org.nd4j.linalg.api.ops.impl.shape.Linspace;
 import org.nd4j.linalg.api.ops.impl.shape.OnesLike;
 import org.nd4j.linalg.api.ops.impl.shape.SequenceMask;
 import org.nd4j.linalg.api.ops.impl.transforms.Cholesky;
@@ -1803,6 +1804,16 @@ public class CustomOpsTests extends BaseNd4jTest {
         assertEquals(ret[0], in);
     }
 
+
+    @Test
+    public void testLinspaceSignature_1() throws Exception {
+        val array1 = Nd4j.exec(new Linspace(DataType.FLOAT, Nd4j.scalar(1.0f), Nd4j.scalar(10.f), Nd4j.scalar(10L)))[0];
+        val array2 = Nd4j.exec(new Linspace(DataType.FLOAT, 1.0f, 10.f, 10L))[0];
+
+        assertEquals(array1.dataType(), array2.dataType());
+        assertEquals(array1, array2);
+    }
+
     @Test
     public void testLogdet() {
         INDArray x = Nd4j.createFromArray(new double[]{
From bb9cdb251e9a1d739e06ffd127370108b694e6ce Mon Sep 17 00:00:00 2001
From: Paul Dubs
Date: Fri, 17 Apr 2020 02:19:21 +0200
Subject: [PATCH 06/18] Small ConvolutionalIterationListener improvements (#382)

* Add a message to the runtime exception

Signed-off-by: Paul Dubs

* Output Convolutions as PNG instead of JPG

A lossless encoding is useful in this case, as it allows small details to be preserved

Signed-off-by: Paul Dubs
---
 .../ui/weights/ConvolutionalIterationListener.java        | 2 +-
 .../module/convolutional/ConvolutionalListenerModule.java | 8 ++++----
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/src/main/java/org/deeplearning4j/ui/weights/ConvolutionalIterationListener.java
b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/src/main/java/org/deeplearning4j/ui/weights/ConvolutionalIterationListener.java index 037c44b40..4aae630ca 100644 --- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/src/main/java/org/deeplearning4j/ui/weights/ConvolutionalIterationListener.java +++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/src/main/java/org/deeplearning4j/ui/weights/ConvolutionalIterationListener.java @@ -140,7 +140,7 @@ public class ConvolutionalIterationListener extends BaseTrainingListener { ComputationGraph l = (ComputationGraph) model; Layer[] layers = l.getLayers(); if(layers.length != activations.size()) - throw new RuntimeException(); + throw new RuntimeException("layers.length != activations.size(). Got layers.length="+layers.length+", activations.size()="+activations.size()); for( int i=0; i Date: Fri, 17 Apr 2020 14:41:49 +1000 Subject: [PATCH 07/18] Switch Java-based updater implementations to C++ ops (#384) Signed-off-by: Alex Black --- .../api/ops/impl/updaters/AmsGradUpdater.java | 7 ++++-- .../api/ops/impl/updaters/NadamUpdater.java | 7 ++++-- .../nd4j/linalg/learning/AMSGradUpdater.java | 25 +++---------------- .../nd4j/linalg/learning/AdaDeltaUpdater.java | 12 +++------ .../nd4j/linalg/learning/AdaGradUpdater.java | 10 ++------ .../nd4j/linalg/learning/AdaMaxUpdater.java | 23 +++++------------ .../org/nd4j/linalg/learning/AdamUpdater.java | 20 +++------------ .../nd4j/linalg/learning/NadamUpdater.java | 19 +++----------- .../linalg/learning/NesterovsUpdater.java | 12 ++------- .../nd4j/linalg/learning/RmsPropUpdater.java | 6 ++--- .../org/nd4j/linalg/learning/SgdUpdater.java | 4 ++- 11 files changed, 40 insertions(+), 105 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AmsGradUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AmsGradUpdater.java index 35af113ad..5e8db1cfd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AmsGradUpdater.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/AmsGradUpdater.java @@ -30,11 +30,14 @@ public class AmsGradUpdater extends DynamicCustomOp { // } - public AmsGradUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, @NonNull INDArray stateH, double lr, double beta1, double beta2, double epsilon, int iteration) { + public AmsGradUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, @NonNull INDArray stateH, + double lr, double beta1, double beta2, double epsilon, int iteration) { this(gradients, stateV, stateM, stateH, gradients, stateV, stateM, stateH, lr, beta1, beta2, epsilon, iteration); } - public AmsGradUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, @NonNull INDArray stateH, @NonNull INDArray updates, @NonNull INDArray updatedStateV, @NonNull INDArray updatedStateM, @NonNull INDArray updatedStateH, double lr, double beta1, double beta2, double epsilon, int iteration) { + public AmsGradUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, @NonNull INDArray stateH, + @NonNull INDArray updates, @NonNull INDArray updatedStateV, @NonNull INDArray updatedStateM, + @NonNull INDArray updatedStateH, double lr, double beta1, double beta2, double epsilon, int iteration) { 
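+        // Sketch of the AMSGrad update this native op is expected to perform (standard formulation;
+        // it matches the Java math removed from AMSGradUpdater later in this patch):
+        //   m_t    = beta1 * m_{t-1} + (1 - beta1) * g_t
+        //   v_t    = beta2 * v_{t-1} + (1 - beta2) * g_t^2
+        //   vHat_t = max(vHat_{t-1}, v_t)
+        //   update = lr * sqrt(1 - beta2^t) / (1 - beta1^t) * m_t / (sqrt(vHat_t) + epsilon)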
addInputArgument(gradients, stateV, stateM, stateH); addOutputArgument(updates, updatedStateV, updatedStateM, updatedStateH); addTArgument(lr, beta1, beta2, epsilon); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/NadamUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/NadamUpdater.java index ad4f374b7..325c85af5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/NadamUpdater.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/updaters/NadamUpdater.java @@ -30,11 +30,14 @@ public class NadamUpdater extends DynamicCustomOp { // } - public NadamUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, double lr, double beta1, double beta2, double epsilon, int iteration) { + public NadamUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, double lr, + double beta1, double beta2, double epsilon, int iteration) { this(gradients, stateV, stateM, gradients, stateV, stateM, lr, beta1, beta2, epsilon, iteration); } - public NadamUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, @NonNull INDArray updates, @NonNull INDArray updatedStateV, @NonNull INDArray updatedStateM, double lr, double beta1, double beta2, double epsilon, int iteration) { + public NadamUpdater(@NonNull INDArray gradients, @NonNull INDArray stateV, @NonNull INDArray stateM, @NonNull INDArray updates, + @NonNull INDArray updatedStateV, @NonNull INDArray updatedStateM, double lr, double beta1, double beta2, + double epsilon, int iteration) { addInputArgument(gradients, stateV, stateM); addOutputArgument(updates, updatedStateV, updatedStateM); addTArgument(lr, beta1, beta2, epsilon); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AMSGradUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AMSGradUpdater.java index 79907a237..37d1cb01d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AMSGradUpdater.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AMSGradUpdater.java @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. 
* * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -19,14 +20,12 @@ package org.nd4j.linalg.learning; import lombok.Data; import lombok.NonNull; import lombok.val; -import org.apache.commons.math3.util.FastMath; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt; +import org.nd4j.linalg.api.ops.impl.updaters.AmsGradUpdater; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.learning.config.AMSGrad; -import org.nd4j.linalg.ops.transforms.Transforms; import java.util.HashMap; import java.util.Map; @@ -103,27 +102,11 @@ public class AMSGradUpdater implements GradientUpdater { double epsilon = config.getEpsilon(); //m_t = b_1 * m_{t-1} + (1-b_1) * g_t eq 1 pg 3 - INDArray oneMinusBeta1Grad = gradient.mul(1.0 - beta1); - m.muli(beta1).addi(oneMinusBeta1Grad); - //v_t = b_2 * v_{t-1} + (1-b_2) * (g_t)^2 eq 1 pg 3 - INDArray oneMinusBeta2GradSquared = gradient.mul(gradient).muli(1 - beta2); - v.muli(beta2).addi(oneMinusBeta2GradSquared); - - double beta1t = FastMath.pow(beta1, iteration + 1); - double beta2t = FastMath.pow(beta2, iteration + 1); - //vHat_t = max(vHat_{t-1}, v_t) - Transforms.max(vHat, v, false); - - double alphat = learningRate * FastMath.sqrt(1 - beta2t) / (1 - beta1t); - if (Double.isNaN(alphat) || alphat == 0.0) - alphat = epsilon; - //gradient array contains: sqrt(vHat) + eps - Nd4j.getExecutioner().exec(new Sqrt(vHat, gradient)).addi(epsilon); - //gradient = alphat * m_t / (sqrt(vHat) + eps) - gradient.rdivi(m).muli(alphat); + + Nd4j.exec(new AmsGradUpdater(gradient, v, m, vHat, learningRate, beta1, beta2, epsilon, iteration)); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdaDeltaUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdaDeltaUpdater.java index ced2a8c84..6aa7d7ab4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdaDeltaUpdater.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdaDeltaUpdater.java @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. 
* * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -19,9 +20,9 @@ package org.nd4j.linalg.learning; import lombok.Data; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.Shape; +import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.learning.config.AdaDelta; -import org.nd4j.linalg.ops.transforms.Transforms; import java.util.HashMap; import java.util.Map; @@ -104,16 +105,11 @@ public class AdaDeltaUpdater implements GradientUpdater { //Line 4 of Algorithm 1: https://arxiv.org/pdf/1212.5701v1.pdf //E[g^2]_t = rho * E[g^2]_{t-1} + (1-rho)*g^2_t - msg.muli(rho).addi(gradient.mul(gradient).muli(1 - rho)); - //Calculate update: //dX = - g * RMS[delta x]_{t-1} / RMS[g]_t //Note: negative is applied in the DL4J step function: params -= update rather than params += update - INDArray rmsdx_t1 = Transforms.sqrt(msdx.add(epsilon), false); - INDArray rmsg_t = Transforms.sqrt(msg.add(epsilon), false); - INDArray update = gradient.muli(rmsdx_t1.divi(rmsg_t)); - //Accumulate gradients: E[delta x^2]_t = rho * E[delta x^2]_{t-1} + (1-rho)* (delta x_t)^2 - msdx.muli(rho).addi(update.mul(update).muli(1 - rho)); + + Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdaDeltaUpdater(gradient, msg, msdx, rho, epsilon)); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdaGradUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdaGradUpdater.java index 09a530a51..ad355d1cc 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdaGradUpdater.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdaGradUpdater.java @@ -18,16 +18,14 @@ package org.nd4j.linalg.learning; import lombok.Data; -import lombok.EqualsAndHashCode; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.Shape; +import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.AdaGrad; import java.util.Collections; import java.util.Map; -import static org.nd4j.linalg.ops.transforms.Transforms.sqrt; - /** * Vectorized Learning Rate used per Connection Weight @@ -98,10 +96,6 @@ public class AdaGradUpdater implements GradientUpdater { double learningRate = config.getLearningRate(iteration, epoch); double epsilon = config.getEpsilon(); - historicalGradient.addi(gradient.mul(gradient)); - - INDArray sqrtHistory = sqrt(historicalGradient.dup(gradientReshapeOrder), false).addi(epsilon); - // lr * gradient / (sqrt(sumSquaredGradients) + epsilon) - gradient.muli(sqrtHistory.rdivi(learningRate)); + Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdaGradUpdater(gradient, historicalGradient, learningRate, epsilon)); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdaMaxUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdaMaxUpdater.java index 20a908f1e..06fbde54d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdaMaxUpdater.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdaMaxUpdater.java @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. 
* * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -18,14 +19,11 @@ package org.nd4j.linalg.learning; import lombok.Data; import lombok.NonNull; -import org.apache.commons.math3.util.FastMath; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.impl.transforms.custom.Max; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.learning.config.AdaMax; -import org.nd4j.linalg.ops.transforms.Transforms; import java.util.HashMap; import java.util.Map; @@ -99,22 +97,13 @@ public class AdaMaxUpdater implements GradientUpdater { throw new IllegalStateException("Updater has not been initialized with view state"); //m = B_1 * m + (1-B_1)*grad - m.muli(config.getBeta1()).addi(gradient.mul(1 - config.getBeta1())); - //u = max(B_2 * u, |grad|) - u.muli(config.getBeta2()); - Transforms.abs(gradient, false); //In-place should be OK here, original gradient values aren't used again later - Nd4j.getExecutioner().exec(new Max(u, gradient, u)); - double beta1t = FastMath.pow(config.getBeta1(), iteration + 1); + double lr = config.getLearningRate(iteration, epoch); + double b1 = config.getBeta1(); + double b2 = config.getBeta2(); + double eps = config.getEpsilon(); - double learningRate = config.getLearningRate(iteration, epoch); - double alphat = learningRate / (1.0 - beta1t); - if (Double.isNaN(alphat) || Double.isInfinite(alphat) || alphat == 0.0) { - alphat = config.getEpsilon(); - } - - u.addi(1e-32); // prevent NaNs in params - gradient.assign(m).muli(alphat).divi(u); + Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdaMaxUpdater(gradient, u, m, lr, b1, b2, eps, iteration)); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdamUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdamUpdater.java index e68af09f7..e72bfe5a4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdamUpdater.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/AdamUpdater.java @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. 
* * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -18,12 +19,11 @@ package org.nd4j.linalg.learning; import lombok.Data; import lombok.NonNull; -import org.apache.commons.math3.util.FastMath; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.Shape; +import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.learning.config.Adam; -import org.nd4j.linalg.ops.transforms.Transforms; import java.util.HashMap; import java.util.Map; @@ -102,20 +102,6 @@ public class AdamUpdater implements GradientUpdater { double learningRate = config.getLearningRate(iteration, epoch); double epsilon = config.getEpsilon(); - INDArray oneMinusBeta1Grad = gradient.mul(1.0 - beta1); - m.muli(beta1).addi(oneMinusBeta1Grad); - - INDArray oneMinusBeta2GradSquared = gradient.mul(gradient).muli(1 - beta2); - v.muli(beta2).addi(oneMinusBeta2GradSquared); - - double beta1t = FastMath.pow(beta1, iteration + 1); - double beta2t = FastMath.pow(beta2, iteration + 1); - - double alphat = learningRate * FastMath.sqrt(1 - beta2t) / (1 - beta1t); - if (Double.isNaN(alphat) || alphat == 0.0) - alphat = epsilon; - INDArray sqrtV = Transforms.sqrt(v.dup(gradientReshapeOrder), false).addi(epsilon); - - gradient.assign(m).muli(alphat).divi(sqrtV); + Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdamUpdater(gradient, v, m, learningRate, beta1, beta2, epsilon, iteration)); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/NadamUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/NadamUpdater.java index 18a29cc25..6432e288e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/NadamUpdater.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/NadamUpdater.java @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. 
* * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -21,6 +22,7 @@ import lombok.NonNull; import org.apache.commons.math3.util.FastMath; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.Shape; +import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.learning.config.Nadam; import org.nd4j.linalg.ops.transforms.Transforms; @@ -101,21 +103,6 @@ public class NadamUpdater implements GradientUpdater { double learningRate = config.getLearningRate(iteration, epoch); double epsilon = config.getEpsilon(); - INDArray oneMinusBeta1Grad = gradient.mul(1.0 - beta1); - m.muli(beta1).addi(oneMinusBeta1Grad); - - INDArray oneMinusBeta2GradSquared = gradient.mul(gradient).muli(1.0 - beta2); - v.muli(beta2).addi(oneMinusBeta2GradSquared); - - double beta1t = FastMath.pow(beta1, iteration + 1); - - INDArray biasCorrectedEstimateOfMomentum = m.mul(beta1).divi(1.0 - beta1t); - INDArray secondTerm = oneMinusBeta1Grad.divi(1 - beta1t); - - INDArray alphat = biasCorrectedEstimateOfMomentum.add(secondTerm).muli(learningRate); - - INDArray sqrtV = Transforms.sqrt(v.dup(gradientReshapeOrder), false).addi(epsilon); - - gradient.assign(alphat).divi(sqrtV); + Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.NadamUpdater(gradient, v, m, learningRate, beta1, beta2, epsilon, iteration)); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/NesterovsUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/NesterovsUpdater.java index 64a9a6f87..2a18b78d8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/NesterovsUpdater.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/NesterovsUpdater.java @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. 
* * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -19,7 +20,6 @@ package org.nd4j.linalg.learning; import lombok.Data; import lombok.NonNull; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Nesterovs; @@ -95,16 +95,8 @@ public class NesterovsUpdater implements GradientUpdater { //DL4J default is negative step function thus we flipped the signs: // x += mu * v_prev + (-1 - mu) * v //i.e., we do params -= updatedGradient, not params += updatedGradient - //v = mu * v - lr * gradient - INDArray vPrev = v.dup(gradientReshapeOrder); - v.muli(momentum).subi(gradient.dup(gradientReshapeOrder).muli(learningRate)); //Modify state array in-place - /* - Next line is equivalent to: - INDArray ret = vPrev.muli(momentum).addi(v.mul(-momentum - 1)); - gradient.assign(ret); - */ - Nd4j.getExecutioner().exec(new AddOp(vPrev.muli(momentum), v.mul(-momentum - 1), gradient)); + Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.NesterovsUpdater(gradient, v, learningRate, momentum)); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/RmsPropUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/RmsPropUpdater.java index e2d68c4bf..866f9ce0d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/RmsPropUpdater.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/RmsPropUpdater.java @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -20,8 +21,8 @@ import lombok.Data; import lombok.NonNull; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.Shape; +import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.RmsProp; -import org.nd4j.linalg.ops.transforms.Transforms; import java.util.Collections; import java.util.Map; @@ -85,8 +86,7 @@ public class RmsPropUpdater implements GradientUpdater { double rmsDecay = config.getRmsDecay(); double epsilon = config.getEpsilon(); - lastGradient.muli(rmsDecay).addi(gradient.mul(gradient).muli(1 - rmsDecay)); // lr * gradient / (sqrt(cache) + 1e-8) - gradient.muli(learningRate).divi(Transforms.sqrt(lastGradient.dup(gradientReshapeOrder), false).addi(epsilon)); + Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.RmsPropUpdater(gradient, lastGradient, learningRate, rmsDecay, epsilon)); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/SgdUpdater.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/SgdUpdater.java index 1eca487c1..a2d0b0214 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/SgdUpdater.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/learning/SgdUpdater.java @@ -1,5 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. 
* * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -19,6 +20,7 @@ package org.nd4j.linalg.learning; import lombok.Data; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.learning.config.Sgd; import java.util.Collections; @@ -56,6 +58,6 @@ public class SgdUpdater implements GradientUpdater { @Override public void applyUpdater(INDArray gradient, int iteration, int epoch) { double lr = config.getLearningRate(iteration, epoch); - gradient.muli(lr); + Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.SgdUpdater(gradient, lr)); } } From 5fbb04531d0a049180ec05c778f07e4a95cca1ca Mon Sep 17 00:00:00 2001 From: Andrii T <39699084+atuzhykov@users.noreply.github.com> Date: Fri, 17 Apr 2020 08:16:14 +0300 Subject: [PATCH 08/18] At cpp ops (#378) * crelu op added * crelu op added Signed-off-by: Andrii Tuzhykov * minor fixes Signed-off-by: Andrii Tuzhykov * crelu(bp)+transformOpValidation op Signed-off-by: Andrii Tuzhykov * added ClipByAvgNorm and DepthwiseConv2DBp Signed-off-by: Andrii Tuzhykov * ClipByAvgNorm passes forward check Signed-off-by: Andrii Tuzhykov * EmbeddingLookup draft Signed-off-by: Andrii Tuzhykov * DepthwiseConv2DB gradient check Signed-off-by: Andrii Tuzhykov * EmbeddingLookup and DepthwiseConv2dBp finished + tests added Signed-off-by: Andrii Tuzhykov * ImageResize draft Signed-off-by: Andrii Tuzhykov * DepthwiseConv2DB gradient check Signed-off-by: Andrii Tuzhykov * ImageResize passed tests except helper::resizeFunctor:Non implemented Signed-off-by: Andrii Tuzhykov * replaced ImageResizeMethods enum by codegen Signed-off-by: Andrii Tuzhykov * minor fixes Signed-off-by: Andrii Tuzhykov * polished checkpoint (OPValidationSuite passed and mvn install build succesfull after codegen) Signed-off-by: Andrii Tuzhykov * manually merged LSTMLayerTestCases from master Signed-off-by: Andrii Tuzhykov Signed-off-by: Andrii Tuzhykov * MaximumBp added and tested Signed-off-by: Andrii Tuzhykov * MergeAddBp draft Signed-off-by: Andrii Tuzhykov * MergeMaxBp and MergeAvgBP added and tests passed Signed-off-by: Andrii Tuzhykov * minor fix * draft LSTMLayerBp (big relative layer in gradient check) * LSTMLayerBp check Signed-off-by: Andrii Tuzhykov * LSTMLayerBp check v2 Signed-off-by: Andrii Tuzhykov * requested changes (test passes) Signed-off-by: Andrii Tuzhykov * LSTMLayer testcases passed gradientcheck Signed-off-by: Andrii Tuzhykov * small LSTMLayer testcase1 improvement (cLast, yLast) Signed-off-by: Andrii Tuzhykov * Warnings issue solved Signed-off-by: Andrii Tuzhykov * Fixes for MKLDNN LSTM layer helper Signed-off-by: Alex Black * stable version Signed-off-by: Andrii Tuzhykov Co-authored-by: raver119 Co-authored-by: Alex Black --- .../declarable/platform/mkldnn/lstmLayer.cpp | 6 +- .../functions/DifferentialFunction.java | 1 + .../nd4j/autodiff/samediff/ops/SDImage.java | 93 +++ .../nd4j/autodiff/samediff/ops/SDMath.java | 62 ++ .../org/nd4j/autodiff/samediff/ops/SDNN.java | 24 + .../org/nd4j/enums/ImageResizeMethod.java | 43 ++ .../java/org/nd4j/enums/PartitionMode.java | 27 + .../converters/ImportClassMapping.java | 12 + .../api/ops/impl/image/ImageResize.java | 67 +++ .../layers/convolution/DepthwiseConv2D.java | 14 +- .../layers/convolution/DepthwiseConv2DBp.java | 150 +++++ .../ops/impl/layers/recurrent/LSTMLayer.java | 31 +- .../impl/layers/recurrent/LSTMLayerBp.java | 176 ++++++ 
.../recurrent/config/LSTMLayerConfig.java | 17 +- .../linalg/api/ops/impl/shape/MergeAvg.java | 14 +- .../linalg/api/ops/impl/shape/MergeMax.java | 16 +- .../api/ops/impl/shape/bp/MergeAvgBp.java | 57 ++ .../api/ops/impl/shape/bp/MergeMaxBp.java | 56 ++ .../impl/shape/tensorops/EmbeddingLookup.java | 71 +++ .../impl/transforms/clip/ClipByAvgNorm.java | 71 +++ .../api/ops/impl/transforms/custom/CReLU.java | 65 ++ .../ops/impl/transforms/custom/CReluBp.java | 59 ++ .../api/ops/impl/transforms/custom/Max.java | 7 +- .../ops/impl/transforms/custom/MaximumBp.java | 48 ++ .../pairwise/arithmetic/MergeAddOp.java | 12 +- .../pairwise/arithmetic/bp/MergeAddBp.java | 54 ++ .../org/nd4j/linalg/factory/ops/NDImage.java | 44 ++ .../org/nd4j/linalg/factory/ops/NDMath.java | 29 + .../org/nd4j/linalg/factory/ops/NDNN.java | 11 + .../opvalidation/LayerOpValidation.java | 162 +++-- .../opvalidation/TransformOpValidation.java | 553 +++++++++++++----- 31 files changed, 1794 insertions(+), 258 deletions(-) create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/ImageResizeMethod.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/PartitionMode.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ImageResize.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2DBp.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayerBp.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/bp/MergeAvgBp.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/bp/MergeMaxBp.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/EmbeddingLookup.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByAvgNorm.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CReLU.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CReluBp.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MaximumBp.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/bp/MergeAddBp.java diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp index d09a40120..6763d1403 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp @@ -369,6 +369,7 @@ PLATFORM_IMPL(lstmLayer, ENGINE_CPU) { REQUIRE_TRUE(dataFormat < 2, 0, "LSTM_LAYER_MKLDNN operation: wrong data format, only two formats are allowed for input/output tensors in mkl dnn library: TNC and NTC!"); REQUIRE_TRUE(directionMode < 4, 0, "LSTM_LAYER_MKLDNN operation: option for bidirectional extra output dimension is not valid in mkl dnn library !"); REQUIRE_TRUE(retLastH == retLastC, 0, "LSTM_LAYER_MKLDNN operation: only two options are present: 1) calculate 
both output at last time and cell state at last time; 2) do not calculate both !"); + REQUIRE_TRUE(hasInitH == hasInitC, 0, "LSTM_LAYER_MKLDNN operation: either both of or neither of initial C and initial H must be provided"); count = 0; auto h = retFullSeq ? OUTPUT_VARIABLE(count++) : nullptr; // output @@ -498,7 +499,7 @@ PLATFORM_CHECK(lstmLayer, ENGINE_CPU) { DataType WrType = Wr->dataType(); DataType bType = b != nullptr ? b->dataType() : (xType == DataType::HALF ? xType : DataType::FLOAT32); DataType hIType = hI != nullptr ? hI->dataType() : xType; - DataType cIType = cI != nullptr ? hI->dataType() : xType; + DataType cIType = cI != nullptr ? cI->dataType() : xType; DataType hType = h != nullptr ? h->dataType() : xType; DataType hLType = hL != nullptr ? hL->dataType() : xType; DataType cLType = cL != nullptr ? cL->dataType() : xType; @@ -509,7 +510,8 @@ PLATFORM_CHECK(lstmLayer, ENGINE_CPU) { && !hasSeqLen //Sequence length array not supported in MKL DNN && dataFormat < 2 //Data format - only 0 and 1 supported in MKL DNN- 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn] && directionMode < 4 //Direction mode - only 0-3 supported in MKL DNN (no extra dim option) - 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat - && retLastH == retLastC; //Return both lastH and lastC, or return neither (not just 1 or other) + && retLastH == retLastC //Return both lastH and lastC, or return neither (not just 1 or other) + && hasInitH == hasInitC; //Need both or neither initial H and C return block.isUseMKLDNN() && featuresSupported && ( (xType==DataType::FLOAT32 && WxType==DataType::FLOAT32 && WrType==DataType::FLOAT32 && bType==DataType::FLOAT32 && hIType==DataType::FLOAT32 && cIType==DataType::FLOAT32 && hType==DataType::FLOAT32 && hLType==DataType::FLOAT32 && cLType==DataType::FLOAT32) || diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java index 94bda0b78..8a629bc66 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java @@ -153,6 +153,7 @@ public abstract class DifferentialFunction { public Map propertiesForFunction() { Map fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(this); Map ret = new LinkedHashMap<>(); + Preconditions.checkNotNull(fields, "DifferentialFunctionClassHolder returned null fields for %s - op has not been added to ImportClassMapping?", getClass()); for(val entry : fields.entrySet()) { try { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java index 70940863a..a58d4d180 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java @@ -24,6 +24,7 @@ import java.lang.String; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; +import org.nd4j.enums.ImageResizeMethod; public class SDImage extends SDOps { public SDImage(SameDiff sameDiff) { @@ -254,6 +255,98 @@ public class SDImage extends 
SDOps { return sd.updateVariableNameAndReference(out, name); } + /** + * Resize images to size using the specified method.
+ * + * @param input 4D image [NHWC] (NUMERIC type) + * @param size new height and width (INT type) + * @param preserveAspectRatio Whether to preserve the aspect ratio. If this is set, then images will be resized to a size that fits in size while preserving the aspect ratio of the original image. Scales up the image if size is bigger than the current size of the image. Defaults to False. + * @param antialias Whether to use an anti-aliasing filter when downsampling an image + * @param ImageResizeMethod ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling. + * ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing. + * ResizeBicubic: Cubic interpolant of Keys. Equivalent to Catmull-Rom kernel. Reasonably good quality and faster than Lanczos3Kernel, particularly when upsampling. + * ResizeGaussian: Gaussian kernel with radius 3, sigma = 1.5 / 3.0. + * ResizeNearest: Nearest neighbor interpolation. 'antialias' has no effect when used with nearest neighbor interpolation. + * ResizeArea: Anti-aliased resampling with area interpolation. 'antialias' has no effect when used with area interpolation; it always anti-aliases. + * ResizeMitchelcubic: Mitchell-Netravali Cubic non-interpolating filter. For synthetic images (especially those lacking proper prefiltering), less ringing than Keys cubic kernel but less sharp. + * @return output Output image (NUMERIC type) + */ + public SDVariable imageResize(SDVariable input, SDVariable size, boolean preserveAspectRatio, + boolean antialias, ImageResizeMethod ImageResizeMethod) { + SDValidation.validateNumerical("imageResize", "input", input); + SDValidation.validateInteger("imageResize", "size", size); + return new org.nd4j.linalg.api.ops.impl.image.ImageResize(sd,input, size, preserveAspectRatio, antialias, ImageResizeMethod).outputVariable(); + } + + /** + * Resize images to size using the specified method.
+ * + * @param name name May be null. Name for the output variable + * @param input 4D image [NHWC] (NUMERIC type) + * @param size new height and width (INT type) + * @param preserveAspectRatio Whether to preserve the aspect ratio. If this is set, then images will be resized to a size that fits in size while preserving the aspect ratio of the original image. Scales up the image if size is bigger than the current size of the image. Defaults to False. + * @param antialias Whether to use an anti-aliasing filter when downsampling an image + * @param ImageResizeMethod ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling. + * ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing. + * ResizeBicubic: Cubic interpolant of Keys. Equivalent to Catmull-Rom kernel. Reasonably good quality and faster than Lanczos3Kernel, particularly when upsampling. + * ResizeGaussian: Gaussian kernel with radius 3, sigma = 1.5 / 3.0. + * ResizeNearest: Nearest neighbor interpolation. 'antialias' has no effect when used with nearest neighbor interpolation. + * ResizeArea: Anti-aliased resampling with area interpolation. 'antialias' has no effect when used with area interpolation; it always anti-aliases. + * ResizeMitchelcubic: Mitchell-Netravali Cubic non-interpolating filter. For synthetic images (especially those lacking proper prefiltering), less ringing than Keys cubic kernel but less sharp. + * @return output Output image (NUMERIC type) + */ + public SDVariable imageResize(String name, SDVariable input, SDVariable size, + boolean preserveAspectRatio, boolean antialias, ImageResizeMethod ImageResizeMethod) { + SDValidation.validateNumerical("imageResize", "input", input); + SDValidation.validateInteger("imageResize", "size", size); + SDVariable out = new org.nd4j.linalg.api.ops.impl.image.ImageResize(sd,input, size, preserveAspectRatio, antialias, ImageResizeMethod).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Resize images to size using the specified method.
+ * + * @param input 4D image [NHWC] (NUMERIC type) + * @param size new height and width (INT type) + * @param ImageResizeMethod ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling. + * ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing. + * ResizeBicubic: Cubic interpolant of Keys. Equivalent to Catmull-Rom kernel. Reasonably good quality and faster than Lanczos3Kernel, particularly when upsampling. + * ResizeGaussian: Gaussian kernel with radius 3, sigma = 1.5 / 3.0. + * ResizeNearest: Nearest neighbor interpolation. 'antialias' has no effect when used with nearest neighbor interpolation. + * ResizeArea: Anti-aliased resampling with area interpolation. 'antialias' has no effect when used with area interpolation; it always anti-aliases. + * ResizeMitchelcubic: Mitchell-Netravali Cubic non-interpolating filter. For synthetic images (especially those lacking proper prefiltering), less ringing than Keys cubic kernel but less sharp. + * @return output Output image (NUMERIC type) + */ + public SDVariable imageResize(SDVariable input, SDVariable size, + ImageResizeMethod ImageResizeMethod) { + SDValidation.validateNumerical("imageResize", "input", input); + SDValidation.validateInteger("imageResize", "size", size); + return new org.nd4j.linalg.api.ops.impl.image.ImageResize(sd,input, size, false, false, ImageResizeMethod).outputVariable(); + } + + /** + * Resize images to size using the specified method.
+ * + * @param name name May be null. Name for the output variable + * @param input 4D image [NHWC] (NUMERIC type) + * @param size new height and width (INT type) + * @param ImageResizeMethod ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling. + * ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing. + * ResizeBicubic: Cubic interpolant of Keys. Equivalent to Catmull-Rom kernel. Reasonably good quality and faster than Lanczos3Kernel, particularly when upsampling. + * ResizeGaussian: Gaussian kernel with radius 3, sigma = 1.5 / 3.0. + * ResizeNearest: Nearest neighbor interpolation. 'antialias' has no effect when used with nearest neighbor interpolation. + * ResizeArea: Anti-aliased resampling with area interpolation. 'antialias' has no effect when used with area interpolation; it always anti-aliases. + * ResizeMitchelcubic: Mitchell-Netravali Cubic non-interpolating filter. For synthetic images (especially those lacking proper prefiltering), less ringing than Keys cubic kernel but less sharp. + * @return output Output image (NUMERIC type) + */ + public SDVariable imageResize(String name, SDVariable input, SDVariable size, + ImageResizeMethod ImageResizeMethod) { + SDValidation.validateNumerical("imageResize", "input", input); + SDValidation.validateInteger("imageResize", "size", size); + SDVariable out = new org.nd4j.linalg.api.ops.impl.image.ImageResize(sd,input, size, false, false, ImageResizeMethod).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Greedily selects a subset of bounding boxes in descending order of score
* diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java index ead137a57..1f89ba1d1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java @@ -24,6 +24,7 @@ import java.lang.String; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; +import org.nd4j.enums.PartitionMode; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.indexing.conditions.Condition; @@ -32,6 +33,67 @@ public class SDMath extends SDOps { super(sameDiff); } + /** + * Clips tensor values to a maximum average L2-norm.
+ * + * @param x Input variable (NUMERIC type) + * @param clipValue Value for clipping + * @param dimensions Dimensions to reduce over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable clipByAvgNorm(SDVariable x, double clipValue, int... dimensions) { + SDValidation.validateNumerical("ClipByAvgNorm", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByAvgNorm(sd,x, clipValue, dimensions).outputVariable(); + } + + /** + * Clips tensor values to a maximum average L2-norm.
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param clipValue Value for clipping + * @param dimensions Dimensions to reduce over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable clipByAvgNorm(String name, SDVariable x, double clipValue, int... dimensions) { + SDValidation.validateNumerical("ClipByAvgNorm", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByAvgNorm(sd,x, clipValue, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Looks up ids in a list of embedding tensors.
+ * + * @param x Input tensor (NUMERIC type) + * @param indices A Tensor containing the ids to be looked up. (INT type) + * @param PartitionMode partition_mode == 0 - i.e. 'mod' , 1 - 'div' + * @return output The looked-up embedding vectors (NUMERIC type) + */ + public SDVariable embeddingLookup(SDVariable x, SDVariable indices, PartitionMode PartitionMode) { + SDValidation.validateNumerical("EmbeddingLookup", "x", x); + SDValidation.validateInteger("EmbeddingLookup", "indices", indices); + return new org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup(sd,x, indices, PartitionMode).outputVariable(); + } + + /** + * Looks up ids in a list of embedding tensors.
+ * + * @param name name May be null. Name for the output variable + * @param x Input tensor (NUMERIC type) + * @param indices A Tensor containing the ids to be looked up. (INT type) + * @param PartitionMode partition_mode == 0 - i.e. 'mod' , 1 - 'div' + * @return output The looked-up embedding vectors (NUMERIC type) + */ + public SDVariable embeddingLookup(String name, SDVariable x, SDVariable indices, + PartitionMode PartitionMode) { + SDValidation.validateNumerical("EmbeddingLookup", "x", x); + SDValidation.validateInteger("EmbeddingLookup", "indices", indices); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup(sd,x, indices, PartitionMode).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** + * Elementwise absolute value operation: out = abs(x)
* diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java index 7b18c3614..9633a0186 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java @@ -30,6 +30,30 @@ public class SDNN extends SDOps { super(sameDiff); } + /** + * Concatenates a ReLU which selects only the positive part of the activation with a ReLU which selects only the negative part of the activation. Note that as a result this non-linearity doubles the depth of the activations.
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable cReLU(SDVariable x) { + SDValidation.validateNumerical("CReLU", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.CReLU(sd,x).outputVariable(); + } + + /** + * Concatenates a ReLU which selects only the positive part of the activation with a ReLU which selects only the negative part of the activation. Note that as a result this non-linearity doubles the depth of the activations.
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable cReLU(String name, SDVariable x) { + SDValidation.validateNumerical("CReLU", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CReLU(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Neural network batch normalization operation.
* For details, see https://arxiv.org/abs/1502.03167
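Taken together, the generated wrappers above expose the new image_resize, clip_by_avg_norm, embedding_lookup and crelu ops through the fluent SameDiff API. A minimal usage sketch follows; variable names and shapes are illustrative only, and it assumes the corresponding native ops are available on the active backend:

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.enums.ImageResizeMethod;
import org.nd4j.enums.PartitionMode;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;

public class NewOpsSketch {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();

        // image_resize: NHWC input, target size given as an INT vector [newHeight, newWidth]
        SDVariable img = sd.var("img", Nd4j.rand(DataType.FLOAT, 1, 32, 32, 3));
        SDVariable size = sd.constant("size", Nd4j.createFromArray(64, 64));
        SDVariable resized = sd.image().imageResize("resized", img, size, ImageResizeMethod.ResizeBilinear);

        // clip_by_avg_norm: clip to a maximum average L2-norm of 1.0 over all dimensions
        SDVariable x = sd.var("x", Nd4j.rand(DataType.FLOAT, 4, 4));
        SDVariable clipped = sd.math().clipByAvgNorm("clipped", x, 1.0);

        // embedding_lookup: rows 0, 3 and 7 of a 10 x 8 embedding table, 'mod' partitioning
        SDVariable table = sd.var("table", Nd4j.rand(DataType.FLOAT, 10, 8));
        SDVariable ids = sd.constant("ids", Nd4j.createFromArray(0, 3, 7));
        SDVariable emb = sd.math().embeddingLookup("emb", table, ids, PartitionMode.MOD);

        // crelu: concatenates relu(x) and relu(-x), doubling the depth dimension
        SDVariable act = sd.nn().cReLU("act", x);

        System.out.println(resized.eval());
    }
}
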
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/ImageResizeMethod.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/ImageResizeMethod.java new file mode 100644 index 000000000..42043dad7 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/ImageResizeMethod.java @@ -0,0 +1,43 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + +package org.nd4j.enums; + +/** + * ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling. + * ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing. + * ResizeBicubic: Cubic interpolant of Keys. Equivalent to Catmull-Rom kernel. Reasonably good quality and faster than Lanczos3Kernel, particularly when upsampling. + * ResizeGaussian: Gaussian kernel with radius 3, sigma = 1.5 / 3.0. + * ResizeNearest: Nearest neighbor interpolation. 'antialias' has no effect when used with nearest neighbor interpolation. + * ResizeArea: Anti-aliased resampling with area interpolation. 'antialias' has no effect when used with area interpolation; it always anti-aliases. + * ResizeMitchelcubic: Mitchell-Netravali Cubic non-interpolating filter. For synthetic images (especially those lacking proper prefiltering), less ringing than Keys cubic kernel but less sharp. */ +public enum ImageResizeMethod { + ResizeBilinear, + + ResizeBicubic, + + ResizeNearest, + + ResizeGaussian, + + ResizeLanczos5, + + ResizeMitchelcubic, + + ResizeArea +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/PartitionMode.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/PartitionMode.java new file mode 100644 index 000000000..565ffd792 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/PartitionMode.java @@ -0,0 +1,27 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + +package org.nd4j.enums; + +/** + * partition_mode == 0 - i.e. 'mod' , 1 - 'div' */ +public enum PartitionMode { + MOD, + + DIV +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java index 043a16e87..6af2d462a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java @@ -93,6 +93,7 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.grid.FreeGridOp.class, org.nd4j.linalg.api.ops.impl.image.CropAndResize.class, org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches.class, + org.nd4j.linalg.api.ops.impl.image.ImageResize.class, org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression.class, org.nd4j.linalg.api.ops.impl.image.NonMaxSuppressionV3.class, org.nd4j.linalg.api.ops.impl.image.ResizeBilinear.class, @@ -127,6 +128,7 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv3DDerivative.class, org.nd4j.linalg.api.ops.impl.layers.convolution.DepthToSpace.class, org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2D.class, + org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2DBp.class, org.nd4j.linalg.api.ops.impl.layers.convolution.Im2col.class, org.nd4j.linalg.api.ops.impl.layers.convolution.Im2colBp.class, org.nd4j.linalg.api.ops.impl.layers.convolution.LocalResponseNormalization.class, @@ -146,6 +148,7 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell.class, org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMCell.class, org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer.class, + org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayerBp.class, org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlock.class, org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU.class, org.nd4j.linalg.api.ops.impl.layers.recurrent.SRUCell.class, @@ -322,9 +325,12 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.shape.Unstack.class, org.nd4j.linalg.api.ops.impl.shape.ZerosLike.class, org.nd4j.linalg.api.ops.impl.shape.bp.ConcatBp.class, + org.nd4j.linalg.api.ops.impl.shape.bp.MergeMaxBp.class, + org.nd4j.linalg.api.ops.impl.shape.bp.MergeAvgBp.class, org.nd4j.linalg.api.ops.impl.shape.bp.SliceBp.class, org.nd4j.linalg.api.ops.impl.shape.bp.StridedSliceBp.class, org.nd4j.linalg.api.ops.impl.shape.bp.TileBp.class, + org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup.class, org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray.class, org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayConcat.class, org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayGather.class, @@ -354,6 +360,7 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.transforms.bool.IsInf.class, org.nd4j.linalg.api.ops.impl.transforms.bool.IsNaN.class, org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform.class, + org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByAvgNorm.class, org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByNorm.class, org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByNormBp.class, 
org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue.class, @@ -365,6 +372,8 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpace.class, org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpaceND.class, org.nd4j.linalg.api.ops.impl.transforms.custom.Choose.class, + org.nd4j.linalg.api.ops.impl.transforms.custom.CReLU.class, + org.nd4j.linalg.api.ops.impl.transforms.custom.CReluBp.class, org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd.class, org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum.class, org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance.class, @@ -406,6 +415,7 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixInverse.class, org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixSetDiag.class, org.nd4j.linalg.api.ops.impl.transforms.custom.Max.class, + org.nd4j.linalg.api.ops.impl.transforms.custom.MaximumBp.class, org.nd4j.linalg.api.ops.impl.transforms.custom.Min.class, org.nd4j.linalg.api.ops.impl.transforms.custom.MirrorPad.class, org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention.class, @@ -492,11 +502,12 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorDivBpOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorModBpOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.ModBpOp.class, + org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.MergeAddBp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.MulBpOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.RDivBpOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.RSubBpOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.SquaredDifferenceBpOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.SubBpOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.And.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Not.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or.class, diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ImageResize.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ImageResize.java new file mode 100644 index 000000000..4bdca62a6 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ImageResize.java @@ -0,0 +1,67 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License.
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.image; + +import lombok.NoArgsConstructor; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.enums.ImageResizeMethod; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +@NoArgsConstructor +public class ImageResize extends DynamicCustomOp { + + + + @Override + public String opName() { + return "image_resize"; + } + + + public ImageResize(@NonNull SameDiff sameDiff, @NonNull SDVariable in, @NonNull SDVariable size, boolean preserveAspectRatio, boolean antialias, ImageResizeMethod method) { + super("image_resize", sameDiff, new SDVariable[]{in, size}); + addBArgument(preserveAspectRatio, antialias); + addIArgument(method.ordinal()); + } + + public ImageResize(@NonNull INDArray in, @NonNull INDArray size, boolean preserveAspectRatio, boolean antialias, ImageResizeMethod method) { + super("image_resize", new INDArray[]{in, size}, null); + Preconditions.checkArgument(in.rank()==4,"expected input image in NHWC format, i.e. [batchSize, height, width, channels]"); + addBArgument(preserveAspectRatio, antialias); + addIArgument(method.ordinal()); + } + + + + @Override + public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) { + Preconditions + .checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes); + Preconditions.checkArgument(dataTypes.get(0).isFPType(), "Input datatype must be floating point, got %s", dataTypes); + + return Collections.singletonList(dataTypes.get(0)); + } + + +} \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java index afb51af58..798b544b8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java @@ -56,9 +56,11 @@ public class DepthwiseConv2D extends DynamicCustomOp { protected Conv2DConfig config; + public DepthwiseConv2D(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv2DConfig conv2DConfig) { this(sameDiff, wrapFilterNull(input, weights, bias), conv2DConfig); + } @Builder(builderMethodName = "sameDiffBuilder") @@ -71,14 +73,14 @@ public class DepthwiseConv2D extends DynamicCustomOp { addArgs(); } - public DepthwiseConv2D(INDArray[] inputs, INDArray[] outputs, Conv2DConfig config){ + public DepthwiseConv2D(INDArray[] inputs, INDArray[] outputs, Conv2DConfig config) { super(inputs, outputs); this.config = config; addArgs(); } - public DepthwiseConv2D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull Conv2DConfig config){ + public DepthwiseConv2D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull Conv2DConfig config) { this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config); } @@
-127,7 +129,7 @@ public class DepthwiseConv2D extends DynamicCustomOp { @Override public Map propertiesForFunction() { - if(config == null && !iArguments.isEmpty()){ + if (config == null && !iArguments.isEmpty()) { config = Conv2DConfig.builder() .kH(iArguments.get(0)) .kW(iArguments.get(1)) @@ -308,7 +310,9 @@ public class DepthwiseConv2D extends DynamicCustomOp { @Override public List doDiff(List f1) { - throw new UnsupportedOperationException("Not implemented yet"); + SDVariable bias = args().length==2 ? null : arg(2); + return Arrays.asList(new DepthwiseConv2DBp(sameDiff, arg(0), arg(1), bias, f1.get(0), this.config).outputVariables()); + } @@ -323,7 +327,7 @@ public class DepthwiseConv2D extends DynamicCustomOp { } @Override - public List calculateOutputDataTypes(List inputDataTypes){ + public List calculateOutputDataTypes(List inputDataTypes) { int n = args().length; Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); return Collections.singletonList(inputDataTypes.get(0)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2DBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2DBp.java new file mode 100644 index 000000000..482944fe2 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2DBp.java @@ -0,0 +1,150 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.layers.convolution; + +import lombok.*; +import lombok.extern.slf4j.Slf4j; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.imports.converters.DifferentialFunctionClassHolder; +import org.nd4j.imports.descriptors.properties.AttributeAdapter; +import org.nd4j.imports.descriptors.properties.PropertyMapping; +import org.nd4j.imports.descriptors.properties.adapters.ConditionalFieldValueIntIndexArrayAdapter; +import org.nd4j.imports.descriptors.properties.adapters.NDArrayShapeAdapter; +import org.nd4j.imports.descriptors.properties.adapters.SizeThresholdIntArrayIntIndexAdpater; +import org.nd4j.imports.descriptors.properties.adapters.StringEqualsAdapter; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; +import org.nd4j.linalg.util.ArrayUtil; + +import java.lang.reflect.Field; +import java.util.*; + + +/** + * Backpropagation for Depthwise Conv2D operation + */ +@Slf4j +@Getter +@NoArgsConstructor +public class DepthwiseConv2DBp extends DynamicCustomOp { + + protected Conv2DConfig config; + + + public DepthwiseConv2DBp(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull SDVariable gradO, @NonNull Conv2DConfig config){ + super(sameDiff, wrapFilterNull(input, weights, bias, gradO)); + this.config = config; + addArgs(); + + } + + public DepthwiseConv2DBp(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, @NonNull SDVariable gradO, @NonNull Conv2DConfig config){ + super(sameDiff, wrapFilterNull(input, weights, gradO)); + this.config = config; + addArgs(); + + } + + + @Override + public long[] iArgs() { + if (iArguments.size() == 0) + addArgs(); + + return super.iArgs(); + } + + protected void addArgs() { + addIArgument(config.getKH(), + config.getKW(), + config.getSH(), + config.getSW(), + config.getPH(), + config.getPW(), + config.getDH(), + config.getDW(), + ArrayUtil.fromBoolean(config.isSameMode()), + config.getDataFormat().equalsIgnoreCase(Conv2DConfig.NCHW) ? 0 : 1); + + } + + @Override + public Object getValue(Field property) { + if (config == null) { + config = Conv2DConfig.builder().build(); + } + + try { + val t = config.getValue(property); + return t; + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Override + public Map propertiesForFunction() { + if (config == null && !iArguments.isEmpty()) { + config = Conv2DConfig.builder() + .kH(iArguments.get(0)) + .kW(iArguments.get(1)) + .sH(iArguments.get(2)) + .sW(iArguments.get(3)) + .pH(iArguments.get(4)) + .pW(iArguments.get(5)) + .dH(iArguments.get(6)) + .dW(iArguments.get(7)) + .isSameMode(iArguments.get(8) == 1) + .dataFormat(iArguments.get(9) == 1 ? 
Conv2DConfig.NHWC : Conv2DConfig.NCHW) + .build(); + } + return config.toProperties(); + } + + + @Override + public boolean isConfigProperties() { + return true; + } + + @Override + public String configFieldName() { + return "config"; + } + + @Override + public String opName() { + return "depthwise_conv2d_bp"; + } + + + @Override + public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) { + int n = args().length; + List<DataType> list = new ArrayList<>(); + // one gradient output per forward input (the incoming gradO excluded), each matching the first input's type + for (int i = 0; i < n - 1; i++) { list.add(inputDataTypes.get(0)); } + return list; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayer.java ... * 2: cell state at last step cL - same shape as in hL
*/ +@NoArgsConstructor public class LSTMLayer extends DynamicCustomOp { @Getter @@ -68,14 +71,18 @@ public class LSTMLayer extends DynamicCustomOp { @Getter private LSTMLayerWeights weights; + private SDVariable cLast; + private SDVariable yLast; + private SDVariable maxTSLength; - public LSTMLayer() { - } public LSTMLayer(@NonNull SameDiff sameDiff, SDVariable x, SDVariable cLast, SDVariable yLast, SDVariable maxTSLength, LSTMLayerWeights weights, LSTMLayerConfig configuration) { super(null, sameDiff, weights.argsWithInputs(x, maxTSLength, cLast, yLast)); this.configuration = configuration; this.weights = weights; + this.cLast = cLast; + this.yLast = yLast; + this.maxTSLength = maxTSLength; addIArgument(iArgs()); addTArgument(tArgs()); addBArgument(bArgs(weights, maxTSLength, yLast, cLast)); @@ -124,7 +131,13 @@ public class LSTMLayer extends DynamicCustomOp { @Override public List doDiff(List grads) { - throw new UnsupportedOperationException("Not yet implemented"); + int i=0; + SDVariable grad0 = this.configuration.isRetFullSequence() ? grads.get(i++): null; + SDVariable grad1 = this.configuration.isRetLastH() ? grads.get(i++): null; + SDVariable grad2 = this.configuration.isRetLastC() ? grads.get(i++): null; + + return Arrays.asList(new LSTMLayerBp(sameDiff, arg(0), this.cLast, this.yLast, this.maxTSLength, + this.weights, this.configuration, grad0, grad1,grad2).outputVariables()); } @@ -155,7 +168,7 @@ public class LSTMLayer extends DynamicCustomOp { } - public boolean[] bArgs(LSTMLayerWeights weights, T maxTSLength, T yLast, T cLast) { + protected boolean[] bArgs(LSTMLayerWeights weights, T maxTSLength, T yLast, T cLast) { return new boolean[]{ weights.hasBias(), // hasBiases: B_ARG(0) maxTSLength != null, // hasSeqLen: B_ARG(1) @@ -169,6 +182,16 @@ public class LSTMLayer extends DynamicCustomOp { } + @Override + public boolean isConfigProperties() { + return true; + } + + @Override + public String configFieldName() { + return "configuration"; + } + @Override public int getNumOutputs(){ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayerBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayerBp.java new file mode 100644 index 000000000..d6ffcd6e5 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayerBp.java @@ -0,0 +1,176 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.layers.recurrent; + +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights; +import org.nd4j.shade.guava.primitives.Booleans; + +import javax.xml.crypto.Data; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + + +/** + * LSTM layer backpropagation + */ +@NoArgsConstructor +public class LSTMLayerBp extends DynamicCustomOp { + + @Getter + private LSTMLayerConfig configuration; + + @Getter + private LSTMLayerWeights weights; + + private SDVariable cLast; + private SDVariable yLast; + private SDVariable maxTSLength; + + + public LSTMLayerBp(@NonNull SameDiff sameDiff, @NonNull SDVariable x, SDVariable cLast, SDVariable yLast, SDVariable maxTSLength, @NonNull LSTMLayerWeights weights, @NonNull LSTMLayerConfig configuration, + SDVariable dLdh, SDVariable dLdhL, SDVariable dLdcL) { + super("lstmLayer_bp", sameDiff, wrapFilterNull(x, weights.getWeights(), weights.getRWeights(), weights.getBias(), + maxTSLength, yLast, cLast, weights.getPeepholeWeights(), dLdh, dLdhL, dLdcL)); + this.configuration = configuration; + this.weights = weights; + this.cLast = cLast; + this.yLast = yLast; + this.maxTSLength = maxTSLength; + addIArgument(iArgs()); + addTArgument(tArgs()); + addBArgument(bArgs(weights, maxTSLength, yLast, cLast)); + + + Preconditions.checkState(this.configuration.isRetLastH() || this.configuration.isRetLastC() || this.configuration.isRetFullSequence(), + "You have to specify at least one output you want to return. 
Use the retLastC, retLastH and retFullSequence methods in the LSTMLayerConfig builder to specify them"); + + + } + + @Override + public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) { + + DataType dt = inputDataTypes.get(1); + Preconditions.checkState(dt.isFPType(), "Input type 1 must be a floating point type, got %s", dt); + ArrayList<DataType> list = new ArrayList<>(); + list.add(dt); // dLdx + list.add(dt); // dLdWx + list.add(dt); // dLdWr + + if (this.weights.hasBias()) { + list.add(dt); + } // dLdb + + if (this.maxTSLength != null) { + list.add(dt); + } // dLdSl + if (this.yLast != null) { + list.add(dt); + } //dLdhI + if (this.cLast != null) { + list.add(dt); + } // dLdcI + if (this.weights.hasPH()) { + list.add(dt); + } // dLdWp + + return list; + } + + + @Override + public String opName() { + return "lstmLayer_bp"; + } + + @Override + public Map<String, Object> propertiesForFunction() { + return configuration.toProperties(true, true); + } + + + public long[] iArgs() { + return new long[]{ + configuration.getLstmdataformat().ordinal(),// INT_ARG(0) + configuration.getDirectionMode().ordinal(), // INT_ARG(1) + configuration.getGateAct().ordinal(), // INT_ARG(2) + configuration.getOutAct().ordinal(), // INT_ARG(3) + configuration.getCellAct().ordinal() // INT_ARG(4) + + }; + } + + public double[] tArgs() { + return new double[]{this.configuration.getCellClip()}; // T_ARG(0) + } + + + protected <T> boolean[] bArgs(LSTMLayerWeights weights, T maxTSLength, T yLast, T cLast) { + return new boolean[]{ + weights.hasBias(), // hasBiases: B_ARG(0) + maxTSLength != null, // hasSeqLen: B_ARG(1) + yLast != null, // hasInitH: B_ARG(2) + cLast != null, // hasInitC: B_ARG(3) + weights.hasPH(), // hasPH: B_ARG(4) + configuration.isRetFullSequence(), //retFullSequence: B_ARG(5) + configuration.isRetLastH(), // retLastH: B_ARG(6) + configuration.isRetLastC() // retLastC: B_ARG(7) + }; + + } + + @Override + public boolean isConfigProperties() { + return true; + } + + @Override + public String configFieldName() { + return "configuration"; + } + + + @Override + public int getNumOutputs() { + + return Booleans.countTrue( + true, + true, + true, + weights.hasBias(), + this.maxTSLength != null, + this.yLast != null, + this.cLast != null, + weights.hasPH() + ); + } + + +} + + diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMLayerConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMLayerConfig.java index 9901213da..226150e8b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMLayerConfig.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMLayerConfig.java @@ -15,8 +15,10 @@ ******************************************************************************/ package org.nd4j.linalg.api.ops.impl.layers.recurrent.config; +import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Data; +import lombok.NoArgsConstructor; import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell; import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer; @@ -26,9 +28,10 @@ import java.util.Map; @Builder @Data +@AllArgsConstructor +@NoArgsConstructor public class LSTMLayerConfig { - /** * notations
* for unidirectional: @@ -90,23 +93,23 @@ public class LSTMLayerConfig { * Cell clipping value, if it = 0 then do not apply clipping */ @Builder.Default - private double cellClip; //T_ARG(0) + private double cellClip = 0; //T_ARG(0) public Map toProperties(boolean includeLSTMDataFormat, boolean includeLSTMDirectionMode) { Map ret = new LinkedHashMap<>(); - ret.put("gateAct", gateAct.ordinal()); - ret.put("outAct", outAct.ordinal()); - ret.put("cellAct", cellAct.ordinal()); + ret.put("gateAct", gateAct.toString()); + ret.put("outAct", outAct.toString()); + ret.put("cellAct", cellAct.toString()); ret.put("retFullSequence", retFullSequence); ret.put("retLastH", retLastH); ret.put("retLastC", retLastC); ret.put("cellClip", cellClip); if (includeLSTMDataFormat) - ret.put("LSTMDataFormat", lstmdataformat.ordinal()); + ret.put("lstmdataformat", lstmdataformat.toString()); if (includeLSTMDirectionMode) - ret.put("LSTMDirectionMode", directionMode.ordinal()); + ret.put("directionMode", directionMode.toString()); return ret; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeAvg.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeAvg.java index b63052eb5..3fc734b6b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeAvg.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeAvg.java @@ -24,15 +24,13 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.shape.bp.MergeAvgBp; import org.nd4j.linalg.factory.Nd4j; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Map; +import java.util.*; @Slf4j public class MergeAvg extends DynamicCustomOp { @@ -74,12 +72,8 @@ public class MergeAvg extends DynamicCustomOp { @Override public List doDiff(List i_v) { - int nArgs = args().length; - SDVariable gradient = sameDiff.setupFunction(i_v.get(0)).div(nArgs); - List ret = new ArrayList<>(); - for (int i = 0; i < args().length; i++) - ret.add(gradient); - return ret; + return Arrays.asList(new MergeAvgBp(sameDiff, args(), i_v.get(0)).outputVariables()); + } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeMax.java index 2b954e8b7..4e41344fb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeMax.java @@ -24,14 +24,12 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.shape.bp.MergeMaxBp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Map; +import java.util.*; @Slf4j public class 
MergeMax extends DynamicCustomOp { @@ -71,14 +69,8 @@ public class MergeMax extends DynamicCustomOp { @Override public List doDiff(List i_v) { - SDVariable gradient = sameDiff.setupFunction(i_v.get(0)); - List ret = new ArrayList<>(); - SDVariable out = outputVariable(); - for (int i = 0; i < args().length; i++){ - SDVariable isMax = out.eq(arg(i)).castTo(arg(i).dataType()); - ret.add(isMax.mul(gradient)); - } - return ret; + return Arrays.asList(new MergeMaxBp(sameDiff, args(), i_v.get(0)).outputVariables()); + } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/bp/MergeAvgBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/bp/MergeAvgBp.java new file mode 100644 index 000000000..54d39ce89 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/bp/MergeAvgBp.java @@ -0,0 +1,57 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.shape.bp; + +import lombok.NoArgsConstructor; +import lombok.NonNull; +import org.apache.commons.lang3.ArrayUtils; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.ArrayList; +import java.util.List; + + +@NoArgsConstructor +public class MergeAvgBp extends DynamicCustomOp { + + public MergeAvgBp(SameDiff sameDiff, @NonNull SDVariable[] inputs, @NonNull SDVariable gradO) { + super("mergeavg_bp", sameDiff, ArrayUtils.add(inputs, gradO)); + } + + @Override + public String opName() { + return "mergeavg_bp"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes) { + ArrayList list = new ArrayList(); + for (int i = 0; i < args().length - 1; i++) { + list.add(inputDataTypes.get(0)); + } + return list; + + } + + @Override + public int getNumOutputs() { + return args().length - 1; + } + +} \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/bp/MergeMaxBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/bp/MergeMaxBp.java new file mode 100644 index 000000000..792036b76 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/bp/MergeMaxBp.java @@ -0,0 +1,56 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.shape.bp; + +import lombok.NoArgsConstructor; +import lombok.NonNull; +import org.apache.commons.lang3.ArrayUtils; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.ArrayList; +import java.util.List; + + +@NoArgsConstructor +public class MergeMaxBp extends DynamicCustomOp { + + public MergeMaxBp(SameDiff sameDiff, @NonNull SDVariable[] inputs, @NonNull SDVariable gradO) { + super("mergemax_bp", sameDiff, ArrayUtils.add(inputs, gradO)); + } + + @Override + public String opName() { + return "mergemax_bp"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes) { + List list = new ArrayList(); + for (int i=0; i< args().length-1;i++){ + list.add(inputDataTypes.get(0)); + } + return list; + + } + + @Override + public int getNumOutputs(){ + return args().length-1; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/EmbeddingLookup.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/EmbeddingLookup.java new file mode 100644 index 000000000..e59abc268 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/EmbeddingLookup.java @@ -0,0 +1,71 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.shape.tensorops; + +import lombok.NoArgsConstructor; +import lombok.NonNull; +import lombok.val; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.enums.PartitionMode; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.shape.LongShapeDescriptor; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.Collections; +import java.util.List; + +@NoArgsConstructor +public class EmbeddingLookup extends DynamicCustomOp { + + public EmbeddingLookup(@NonNull SameDiff sameDiff, @NonNull SDVariable in, @NonNull SDVariable indices, PartitionMode partitionMode) { + super("embedding_lookup", sameDiff, new SDVariable[]{in, indices}); + addIArgument(partitionMode.ordinal()); + } + + public EmbeddingLookup(@NonNull INDArray in, @NonNull INDArray indices, PartitionMode partitionMode, INDArray output) { + super("embedding_lookup", new INDArray[]{in, indices}, wrapOrNull(output)); + addIArgument(partitionMode.ordinal()); + + } + + public EmbeddingLookup(@NonNull INDArray in, INDArray output, PartitionMode partitionMode, @NonNull int... indices) { + super("embedding_lookup", new INDArray[]{in, Nd4j.createFromArray(indices)}, wrapOrNull(output)); + addIArgument(partitionMode.ordinal()); + + + } + + @Override + public String opName() { + return "embedding_lookup"; + } + + @Override + public List calculateOutputDataTypes(List dataTypes) { + Preconditions + .checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes); + Preconditions.checkArgument(dataTypes.get(0).isFPType(), "Input datatype must be floating point, got %s", dataTypes); + Preconditions.checkArgument(dataTypes.get(1).isIntType(), "Input datatype must be integer point, got %s", dataTypes); + + return Collections.singletonList(dataTypes.get(0)); + } + + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByAvgNorm.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByAvgNorm.java new file mode 100644 index 000000000..a5f53622b --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByAvgNorm.java @@ -0,0 +1,71 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.transforms.clip; + +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + + +@NoArgsConstructor +public class ClipByAvgNorm extends DynamicCustomOp { + + private double clipValue; + + + public ClipByAvgNorm(SameDiff sameDiff, SDVariable x, double clipValue, int... dimensions) { + super("clipbyavgnorm", sameDiff, new SDVariable[]{x}); + this.clipValue = clipValue; + this.dimensions = dimensions; + addIArgument(dimensions); + addTArgument(clipValue); + } + + public ClipByAvgNorm(INDArray in, double clipValue, int... dimensions){ + this(in, null, clipValue, dimensions); + } + + public ClipByAvgNorm(INDArray in, INDArray out, double clipValue, int... dimensions){ + super("clipbyavgnorm", new INDArray[]{in}, wrapOrNull(out), Collections.singletonList(clipValue), dimensions); + } + + @Override + public String opName() { + return "clipbyavgnorm"; + } + + + + @Override + public List doDiff(List grad) { + throw new UnsupportedOperationException("Not yet implemented"); } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected exactly 1 input datatype for %s, got %s", getClass(), inputDataTypes); + return inputDataTypes; + } + +} + + diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CReLU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CReLU.java new file mode 100644 index 000000000..d442bc141 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CReLU.java @@ -0,0 +1,65 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.base.Preconditions; + +import java.util.Collections; +import java.util.List; +import lombok.Getter; +import lombok.NonNull; + +@NoArgsConstructor +public class CReLU extends DynamicCustomOp { + + + public CReLU(SameDiff sd, SDVariable input) { + super(sd, new SDVariable[]{input}); + } + + public CReLU(@NonNull INDArray input) { + super(new INDArray[]{input}, null); + + } + + + @Override + public String opName() { + return "crelu"; + } + + @Override + public List calculateOutputDataTypes(List dataTypes) { + Preconditions + .checkArgument(dataTypes != null && dataTypes.size() == 1, "Expected exactly 1 input datatypes, got %s", dataTypes); + Preconditions.checkArgument(dataTypes.get(0).isFPType(), "Input datatype must be floating point, got %s", dataTypes); + + return Collections.singletonList(dataTypes.get(0)); + } + + @Override + public List doDiff(List i_v) { + + return Collections.singletonList(new CReluBp(sameDiff, arg(), i_v.get(0)).outputVariable()); + } + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CReluBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CReluBp.java new file mode 100644 index 000000000..7b96afffd --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CReluBp.java @@ -0,0 +1,59 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.base.Preconditions; + +import java.util.Collections; +import java.util.List; +import lombok.Getter; +import lombok.NonNull; + + +@NoArgsConstructor +public class CReluBp extends DynamicCustomOp { + + public CReluBp(SameDiff sd, SDVariable input, SDVariable epsilonNext) { + super(sd, new SDVariable[]{input, epsilonNext}); + } + + public CReluBp(@NonNull INDArray input, @NonNull INDArray epsilonNext, INDArray output) { + super(new INDArray[]{input, epsilonNext}, wrapOrNull(output)); + } + + + @Override + public String opName() { + return "crelu_bp"; + } + + @Override + public List calculateOutputDataTypes(List dataTypes) { + Preconditions + .checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes); + Preconditions.checkArgument(dataTypes.get(0).isFPType(), "Input datatype must be floating point, got %s", dataTypes); + + return Collections.singletonList(dataTypes.get(0)); + } + + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Max.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Max.java index e8653d4c0..d2451c0f8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Max.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Max.java @@ -73,12 +73,7 @@ public class Max extends BaseDynamicTransformOp { @Override public List doDiff(List f1) { - //TODO Switch to maximum_bp op - https://github.com/deeplearning4j/deeplearning4j/blob/master/libnd4j/include/ops/declarable/generic/broadcastable/maximum.cpp - SDVariable max = outputVariables()[0]; - SDVariable eq1 = sameDiff.eq(larg(), max).castTo(arg(0).dataType()); - SDVariable eq2 = sameDiff.eq(rarg(), max).castTo(arg(1).dataType()); - - return Arrays.asList(eq1.mul(f1.get(0)), eq2.mul(f1.get(0))); + return Arrays.asList(new MaximumBp(sameDiff, arg(0), arg(1), f1.get(0)).outputVariables()); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MaximumBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MaximumBp.java new file mode 100644 index 000000000..92fb3b0eb --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MaximumBp.java @@ -0,0 +1,48 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import lombok.NoArgsConstructor; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.ArrayList; +import java.util.List; + +@NoArgsConstructor +public class MaximumBp extends DynamicCustomOp { + + public MaximumBp(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable y, @NonNull SDVariable gradO) { + super("maximum_bp",sameDiff, new SDVariable[]{x,y, gradO}); + } + + @Override + public String opName() { + return "maximum_bp"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes) { + List list = new ArrayList(); + list.add(inputDataTypes.get(0)); + list.add(inputDataTypes.get(0)); + return list; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MergeAddOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MergeAddOp.java index fc89333f4..51f2e449d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MergeAddOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MergeAddOp.java @@ -18,14 +18,19 @@ package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic; import lombok.NoArgsConstructor; import lombok.NonNull; +import org.apache.commons.lang3.ArrayUtils; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; +import org.nd4j.linalg.api.ops.impl.transforms.custom.MaximumBp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.MergeAddBp; +import org.nd4j.linalg.util.ArrayUtil; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -70,11 +75,8 @@ public class MergeAddOp extends BaseDynamicTransformOp { @Override public List doDiff(List i_v) { - SDVariable gradient = sameDiff.setupFunction(i_v.get(0)); - List ret = new ArrayList<>(); - for (int i = 0; i < args().length; i++) - ret.add(gradient); - return ret; + return Arrays.asList(new MergeAddBp(sameDiff, args(), i_v.get(0)).outputVariables()); + } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/bp/MergeAddBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/bp/MergeAddBp.java new file mode 100644 index 000000000..b0403ecff --- /dev/null +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/bp/MergeAddBp.java @@ -0,0 +1,54 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp; + +import lombok.NoArgsConstructor; +import lombok.NonNull; +import org.apache.commons.lang3.ArrayUtils; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +@NoArgsConstructor +public class MergeAddBp extends DynamicCustomOp { + + public MergeAddBp(SameDiff sameDiff, @NonNull SDVariable[] inputs, @NonNull SDVariable gradO) { + super("mergeadd_bp", sameDiff, ArrayUtils.add(inputs, gradO)); + } + + @Override + public String opName() { + return "mergeadd_bp"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes) { + ArrayList list = new ArrayList(); + for (int i=0; i< args().length-1;i++){list.add(inputDataTypes.get(0));} + return list; + + } + + @Override + public int getNumOutputs(){ + return args().length-1; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDImage.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDImage.java index 859ad43c3..03b9f8571 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDImage.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDImage.java @@ -21,6 +21,7 @@ package org.nd4j.linalg.factory.ops; import static org.nd4j.linalg.factory.NDValidation.isSameType; import org.nd4j.base.Preconditions; +import org.nd4j.enums.ImageResizeMethod; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.NDValidation; import org.nd4j.linalg.factory.Nd4j; @@ -134,6 +135,49 @@ public class NDImage { return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.HsvToRgb(input))[0]; } + /** + * Resize images to size using the specified method.
+ * + * @param input 4D image [NHWC] (NUMERIC type) + * @param size new height and width (INT type) + * @param preserveAspectRatio Whether to preserve the aspect ratio. If this is set, then images will be resized to a size that fits in size while preserving the aspect ratio of the original image. Scales up the image if size is bigger than the current size of the image. Defaults to False. + * @param antialias Whether to use an anti-aliasing filter when downsampling an image + * @param ImageResizeMethod ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling. + * ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing. + * ResizeBicubic: Cubic interpolant of Keys. Equivalent to Catmull-Rom kernel. Reasonably good quality and faster than Lanczos3Kernel, particularly when upsampling. + * ResizeGaussian: Gaussian kernel with radius 3, sigma = 1.5 / 3.0. + * ResizeNearest: Nearest neighbor interpolation. 'antialias' has no effect when used with nearest neighbor interpolation. + * ResizeArea: Anti-aliased resampling with area interpolation. 'antialias' has no effect when used with area interpolation; it always anti-aliases. + * ResizeMitchelcubic: Mitchell-Netravali Cubic non-interpolating filter. For synthetic images (especially those lacking proper prefiltering), less ringing than Keys cubic kernel but less sharp. + * @return output Output image (NUMERIC type) + */ + public INDArray imageResize(INDArray input, INDArray size, boolean preserveAspectRatio, + boolean antialias, ImageResizeMethod ImageResizeMethod) { + NDValidation.validateNumerical("imageResize", "input", input); + NDValidation.validateInteger("imageResize", "size", size); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.image.ImageResize(input, size, preserveAspectRatio, antialias, ImageResizeMethod))[0]; + } + + /** + * Resize images to size using the specified method.<br>
+ * + * @param input 4D image [NHWC] (NUMERIC type) + * @param size new height and width (INT type) + * @param ImageResizeMethod ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling. + * ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing. + * ResizeBicubic: Cubic interpolant of Keys. Equivalent to Catmull-Rom kernel. Reasonably good quality and faster than Lanczos3Kernel, particularly when upsampling. + * ResizeGaussian: Gaussian kernel with radius 3, sigma = 1.5 / 3.0. + * ResizeNearest: Nearest neighbor interpolation. 'antialias' has no effect when used with nearest neighbor interpolation. + * ResizeArea: Anti-aliased resampling with area interpolation. 'antialias' has no effect when used with area interpolation; it always anti-aliases. + * ResizeMitchelcubic: Mitchell-Netravali Cubic non-interpolating filter. For synthetic images (especially those lacking proper prefiltering), less ringing than Keys cubic kernel but less sharp. + * @return output Output image (NUMERIC type) + */ + public INDArray imageResize(INDArray input, INDArray size, ImageResizeMethod ImageResizeMethod) { + NDValidation.validateNumerical("imageResize", "input", input); + NDValidation.validateInteger("imageResize", "size", size); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.image.ImageResize(input, size, false, false, ImageResizeMethod))[0]; + } + /** * Greedily selects a subset of bounding boxes in descending order of score
* diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java index bee0da889..8e8923834 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java @@ -21,6 +21,7 @@ package org.nd4j.linalg.factory.ops; import static org.nd4j.linalg.factory.NDValidation.isSameType; import org.nd4j.base.Preconditions; +import org.nd4j.enums.PartitionMode; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.NDValidation; @@ -31,6 +32,34 @@ public class NDMath { public NDMath() { } + /** + * Clips tensor values to a maximum average L2-norm.
+ * + * @param x Input variable (NUMERIC type) + * @param clipValue Value for clipping + * @param dimensions Dimensions to reduce over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public INDArray clipByAvgNorm(INDArray x, double clipValue, int... dimensions) { + NDValidation.validateNumerical("ClipByAvgNorm", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByAvgNorm(x, clipValue, dimensions))[0]; + } + + /** + * Looks up ids in a list of embedding tensors.
+ * + * @param x Input tensor (NUMERIC type) + * @param indices A Tensor containing the ids to be looked up. (INT type) + * @param PartitionMode partition_mode == 0 - i.e. 'mod' , 1 - 'div' + * @return output Shifted output (NUMERIC type) + */ + public INDArray embeddingLookup(INDArray x, INDArray indices, PartitionMode PartitionMode) { + NDValidation.validateNumerical("EmbeddingLookup", "x", x); + NDValidation.validateInteger("EmbeddingLookup", "indices", indices); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup(x, indices, PartitionMode))[0]; + } + /** * Elementwise absolute value operation: out = abs(x)
* diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java index 3f9e1431a..06fb92b64 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java @@ -29,6 +29,17 @@ public class NDNN { public NDNN() { } + /** + * Concatenates a ReLU which selects only the positive part of the activation with a ReLU which selects only the negative part of the activation. Note that as a result this non-linearity doubles the depth of the activations.
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray cReLU(INDArray x) { + NDValidation.validateNumerical("CReLU", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CReLU(x))[0]; + } + /** * Neural network batch normalization operation.
* For details, see https://arxiv.org/abs/1502.03167<br>
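
For reference, a minimal usage sketch of the new namespace wrappers added above (not part of the patch). It relies only on signatures shown in these hunks; the public no-arg NDImage constructor and the PartitionMode.MOD constant name are assumptions, while ImageResizeMethod.ResizeBilinear is taken from the javadoc above.

    import org.nd4j.enums.ImageResizeMethod;
    import org.nd4j.enums.PartitionMode;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.factory.ops.NDImage;
    import org.nd4j.linalg.factory.ops.NDMath;
    import org.nd4j.linalg.factory.ops.NDNN;

    public class NewNamespaceOpsSketch {
        public static void main(String[] args) {
            NDMath math = new NDMath();
            NDNN nn = new NDNN();
            NDImage image = new NDImage(); // no-arg constructor assumed, by symmetry with NDMath/NDNN

            INDArray x = Nd4j.randn(DataType.FLOAT, 2, 4);

            // CReLU concatenates the positive and negative parts of the activation,
            // doubling the depth of the activations per the javadoc above
            INDArray crelu = nn.cReLU(x);

            // clip values so the average L2-norm over dimension 1 does not exceed 1.0
            INDArray clipped = math.clipByAvgNorm(x, 1.0, 1);

            // gather rows 0 and 3 from an embedding table; PartitionMode.MOD is an assumed
            // constant name, inferred from the javadoc ("0 - 'mod', 1 - 'div'")
            INDArray table = Nd4j.randn(DataType.FLOAT, 10, 4);
            INDArray rows = math.embeddingLookup(table, Nd4j.createFromArray(0, 3), PartitionMode.MOD);

            // bilinear resize of one 5x5 RGB image (NHWC) to 3x3, via the two-arg overload
            INDArray img = Nd4j.rand(DataType.FLOAT, 1, 5, 5, 3);
            INDArray resized = image.imageResize(img, Nd4j.createFromArray(3, 3), ImageResizeMethod.ResizeBilinear);

            System.out.println(crelu.shapeInfoToString());
            System.out.println(clipped.shapeInfoToString());
            System.out.println(rows.shapeInfoToString());
            System.out.println(resized.shapeInfoToString());
        }
    }

Each wrapper simply builds the corresponding DynamicCustomOp and runs it through Nd4j.exec, as the method bodies above show.
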
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java index 794348369..c83a55d08 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java @@ -20,7 +20,6 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; - import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.Ignore; @@ -35,6 +34,7 @@ import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling2D; +import org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2D; import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D; import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2DDerivative; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*; @@ -265,7 +265,7 @@ public class LayerOpValidation extends BaseOpValidation { msg = "7 - upsampling2d, NCHW, 2x2 - " + Arrays.toString(inSizeNCHW); inSize = inSizeNCHW; in = sd.var("in", inSize); - out = sd.cnn().upsampling2d(in, 2, 2, true); + out = sd.cnn().upsampling2d(in, 2, 2, true); break; default: throw new RuntimeException(); @@ -1469,6 +1469,43 @@ public class LayerOpValidation extends BaseOpValidation { } } + @Test + public void testDepthwiseConv2D(){ + + int bS = 10; + + int kernelHeight = 2; + int kernelWidth = 2; + int strideHeight = 2; + int strideWidth = 2; + int inChannels = 2; + int outChannels = 3; + Nd4j.getRandom().setSeed(12345); + SameDiff sd = SameDiff.create(); + SDVariable in = sd.var("in", Nd4j.rand(bS, inChannels, 5,5)); + SDVariable weights = sd.var("weights", Nd4j.rand(DataType.DOUBLE, kernelHeight, kernelWidth, inChannels, outChannels)); + SDVariable bias = sd.var("bias", Nd4j.rand(DataType.DOUBLE, inChannels*outChannels)); + Conv2DConfig config = Conv2DConfig.builder() + .kH(kernelHeight) + .kW(kernelWidth) + .sH(strideHeight) + .sW(strideWidth) + .dataFormat("NCHW") + .build(); + + SDVariable out = sd.cnn.depthWiseConv2d(in, weights, bias, config); + SDVariable loss = sd.standardDeviation("loss", out, true); + loss.markAsLoss(); + + String err = OpValidation.validate(new TestCase(sd) + .gradientCheck(true) + ); + assertNull(err); + + + + } + @Test public void LSTMLayerTestCase1() { @@ -1476,9 +1513,8 @@ public class LayerOpValidation extends BaseOpValidation { int bS = 5; int nIn = 3; int numUnits = 7; - int sL = 10; //small just for test + int sL = 3; //small just for test - SameDiff sd = SameDiff.create(); // notations: // bS - batch size, numExamples @@ -1492,50 +1528,66 @@ public class LayerOpValidation extends BaseOpValidation { // T2NS: 3 = [timeLength, 2, numExamples, inOutSize] (for ONNX) - SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, bS, nIn, sL)); + for (boolean useCLast : new boolean[]{false, true}) { + for (boolean useYLast : new boolean[]{false, true}) { + + SameDiff sd = SameDiff.create(); + SDVariable in = sd.var("in", Nd4j.randn(DataType.DOUBLE, bS, nIn, sL)); - SDVariable cLast = sd.var("cLast", Nd4j.zeros(DataType.FLOAT, bS, numUnits)); - SDVariable yLast = sd.var("yLast", Nd4j.zeros(DataType.FLOAT, bS, numUnits)); + SDVariable cLast = useCLast ? 
sd.var("cLast", Nd4j.zeros(DataType.DOUBLE, bS, numUnits)) : null; + SDVariable yLast = useYLast ? sd.var("yLast", Nd4j.zeros(DataType.DOUBLE, bS, numUnits)) : null; - LSTMLayerConfig c = LSTMLayerConfig.builder() - .lstmdataformat(LSTMDataFormat.NST) - .directionMode(LSTMDirectionMode.FWD) - .gateAct(LSTMActivations.SIGMOID) - .cellAct(LSTMActivations.TANH) - .outAct(LSTMActivations.TANH) - .retFullSequence(true) - .retLastC(true) - .retLastH(true) - .build(); - LSTMLayerOutputs outputs = new LSTMLayerOutputs(sd.rnn.lstmLayer( - in, cLast, yLast, null, - LSTMLayerWeights.builder() - .weights(sd.var("weights", Nd4j.rand(DataType.FLOAT, nIn, 4 * numUnits))) - .rWeights(sd.var("rWeights", Nd4j.rand(DataType.FLOAT, numUnits, 4 * numUnits))) - .peepholeWeights(sd.var("inputPeepholeWeights", Nd4j.rand(DataType.FLOAT, 3 * numUnits))) - .bias(sd.var("bias", Nd4j.rand(DataType.FLOAT, 4 * numUnits))).build(), - c), c); + LSTMLayerConfig c = LSTMLayerConfig.builder() + .lstmdataformat(LSTMDataFormat.NST) + .directionMode(LSTMDirectionMode.FWD) + .gateAct(LSTMActivations.SIGMOID) + .cellAct(LSTMActivations.TANH) + .outAct(LSTMActivations.TANH) + .retFullSequence(true) + .retLastC(true) + .retLastH(true) + .build(); - long[] out = new long[]{bS, numUnits, sL}; - long[] hL = new long[]{bS, numUnits}; - long[] cL = new long[]{bS, numUnits}; + LSTMLayerOutputs outputs = new LSTMLayerOutputs(sd.rnn.lstmLayer( + in, cLast, yLast, null, + LSTMLayerWeights.builder() + .weights(sd.var("weights", Nd4j.randn(DataType.DOUBLE, nIn, 4 * numUnits))) + .rWeights(sd.var("rWeights", Nd4j.randn(DataType.DOUBLE, numUnits, 4 * numUnits))) + .peepholeWeights(sd.var("inputPeepholeWeights", Nd4j.randn(DataType.DOUBLE, 3 * numUnits))) + .bias(sd.var("bias", Nd4j.rand(DataType.DOUBLE, 4 * numUnits))).build(), + c), c); - assertArrayEquals(out, outputs.getOutput().eval().shape()); - assertArrayEquals(hL, outputs.getLastTimeStepOutput().eval().shape()); - assertArrayEquals(cL, outputs.getLastCellStateOutput().eval().shape()); + long[] out = new long[]{bS, numUnits, sL}; + long[] hL = new long[]{bS, numUnits}; + long[] cL = new long[]{bS, numUnits}; + + assertArrayEquals(out, outputs.getOutput().eval().shape()); + assertArrayEquals(hL, outputs.getLastOutput().eval().shape()); + assertArrayEquals(cL, outputs.getLastState().eval().shape()); + + sd.setLossVariables(outputs.getOutput(), outputs.getLastTimeStepOutput(), outputs.getTimeSeriesOutput()); + + String err = OpValidation.validate(new TestCase(sd) + .gradientCheck(true) + .testName("cLast=" + cLast + ", yLast=" + yLast) + ); + + assertNull(err); + } + } } - @Test @Ignore //AB 2020/04/08 - https://github.com/eclipse/deeplearning4j/issues/8824 + @Test public void LSTMLayerTestCase2() { int bS = 5; int nIn = 3; int numUnits = 7; - int sL = 10; //small just for test + int sL = 3; //small just for test SameDiff sd = SameDiff.create(); @@ -1549,11 +1601,11 @@ public class LayerOpValidation extends BaseOpValidation { // NTS: shape [numExamples, timeLength, inOutSize]
// for bidirectional: // T2NS: 3 = [timeLength, 2, numExamples, inOutSize] (for ONNX) - SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, sL, bS, nIn)); + SDVariable in = sd.var("in", Nd4j.rand(DataType.DOUBLE, sL, bS, nIn)); - SDVariable cLast = sd.var("cLast", Nd4j.zeros(DataType.FLOAT, bS, numUnits)); - SDVariable yLast = sd.var("yLast", Nd4j.zeros(DataType.FLOAT, bS, numUnits)); + SDVariable cLast = sd.var("cLast", Nd4j.zeros(DataType.DOUBLE, bS, numUnits)); + SDVariable yLast = sd.var("yLast", Nd4j.zeros(DataType.DOUBLE, bS, numUnits)); LSTMLayerConfig c = LSTMLayerConfig.builder() .lstmdataformat(LSTMDataFormat.TNS) @@ -1569,8 +1621,8 @@ public class LayerOpValidation extends BaseOpValidation { LSTMLayerOutputs outputs = new LSTMLayerOutputs(sd.rnn.lstmLayer( in, cLast, yLast, null, LSTMLayerWeights.builder() - .weights(sd.var("weights", Nd4j.rand(DataType.FLOAT, nIn, 4 * numUnits))) - .rWeights(sd.var("rWeights", Nd4j.rand(DataType.FLOAT, numUnits, 4 * numUnits))) + .weights(sd.var("weights", Nd4j.rand(DataType.DOUBLE, nIn, 4 * numUnits))) + .rWeights(sd.var("rWeights", Nd4j.rand(DataType.DOUBLE, numUnits, 4 * numUnits))) .build(), c), c); @@ -1578,14 +1630,22 @@ public class LayerOpValidation extends BaseOpValidation { long[] out = new long[]{sL, bS, numUnits}; assertArrayEquals(out, outputs.getOutput().eval().shape()); + sd.setLossVariables(outputs.getOutput()); + + String err = OpValidation.validate(new TestCase(sd) + .gradientCheck(true) + ); + + assertNull(err); + } - @Test @Ignore //AB 2020/04/08 - https://github.com/eclipse/deeplearning4j/issues/8824 + @Test public void LSTMLayerTestCase3() { int bS = 5; int nIn = 3; int numUnits = 7; - int sL = 10; //small just for test + int sL = 3; //small just for test SameDiff sd = SameDiff.create(); @@ -1599,14 +1659,14 @@ public class LayerOpValidation extends BaseOpValidation { // NTS: shape [numExamples, timeLength, inOutSize]
// for bidirectional: // T2NS: 3 = [timeLength, 2, numExamples, inOutSize] (for ONNX) - SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, bS, sL, nIn)); + SDVariable in = sd.var("in", Nd4j.rand(DataType.DOUBLE, bS, sL, nIn)); // when directionMode >= 2 (BIDIR_CONCAT=3) // Wx, Wr [2, nIn, 4*nOut] // hI, cI [2, bS, nOut] - SDVariable cLast = sd.var("cLast", Nd4j.zeros(DataType.FLOAT, 2, bS, numUnits)); - SDVariable yLast = sd.var("yLast", Nd4j.zeros(DataType.FLOAT, 2, bS, numUnits)); + SDVariable cLast = sd.var("cLast", Nd4j.zeros(DataType.DOUBLE, 2, bS, numUnits)); + SDVariable yLast = sd.var("yLast", Nd4j.zeros(DataType.DOUBLE, 2, bS, numUnits)); LSTMLayerConfig c = LSTMLayerConfig.builder() .lstmdataformat(LSTMDataFormat.NTS) @@ -1622,8 +1682,8 @@ public class LayerOpValidation extends BaseOpValidation { LSTMLayerOutputs outputs = new LSTMLayerOutputs(sd.rnn.lstmLayer(new String[]{"out"}, in, cLast, yLast, null, LSTMLayerWeights.builder() - .weights(sd.var("weights", Nd4j.rand(DataType.FLOAT, 2, nIn, 4 * numUnits))) - .rWeights(sd.var("rWeights", Nd4j.rand(DataType.FLOAT, 2, numUnits, 4 * numUnits))) + .weights(sd.var("weights", Nd4j.rand(DataType.DOUBLE, 2, nIn, 4 * numUnits))) + .rWeights(sd.var("rWeights", Nd4j.rand(DataType.DOUBLE, 2, numUnits, 4 * numUnits))) .build(), c), c); @@ -1631,5 +1691,17 @@ public class LayerOpValidation extends BaseOpValidation { long[] out = new long[]{bS, sL, 2 * numUnits}; assertArrayEquals(out, outputs.getOutput().eval().shape()); + + sd.setLossVariables(outputs.getOutput()); + + String err = OpValidation.validate(new TestCase(sd) + .gradientCheck(true) + ); + + assertNull(err); } + + + + } \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java index 27a15b517..1812c62b0 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java @@ -30,28 +30,31 @@ import org.nd4j.enums.DataFormat; import org.nd4j.autodiff.validation.OpTestCase; import org.nd4j.autodiff.validation.OpValidation; import org.nd4j.autodiff.validation.TestCase; -import org.nd4j.linalg.api.blas.params.MMulTranspose; +import org.nd4j.enums.ImageResizeMethod; +import org.nd4j.enums.PartitionMode; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.image.ImageResize; import org.nd4j.linalg.api.ops.impl.layers.convolution.DepthToSpace; import org.nd4j.linalg.api.ops.impl.layers.convolution.SpaceToDepth; import org.nd4j.linalg.api.ops.impl.scalar.ScalarFMod; import org.nd4j.linalg.api.ops.impl.scalar.ScalarMultiplication; import org.nd4j.linalg.api.ops.impl.shape.Cross; +import org.nd4j.linalg.api.ops.impl.shape.MergeAvg; +import org.nd4j.linalg.api.ops.impl.shape.MergeMax; +import org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup; import org.nd4j.linalg.api.ops.impl.transforms.Pad; -import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual; -import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual; -import org.nd4j.linalg.api.ops.impl.transforms.custom.Max; -import org.nd4j.linalg.api.ops.impl.transforms.custom.Min; -import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; 
-import org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize; +import org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByAvgNorm; +import org.nd4j.linalg.api.ops.impl.transforms.custom.*; import org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp; import org.nd4j.linalg.api.ops.impl.transforms.strict.*; import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; +import org.nd4j.linalg.function.Function; import org.nd4j.linalg.indexing.BooleanIndexing; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.conditions.Condition; @@ -104,7 +107,7 @@ public class TransformOpValidation extends BaseOpValidation { List failed = new ArrayList<>(); - for( int i=0; i<11; i++ ) { + for (int i = 0; i < 11; i++) { for (char inOrder : new char[]{'c', 'f'}) { SameDiff sd = SameDiff.create(); @@ -114,7 +117,7 @@ public class TransformOpValidation extends BaseOpValidation { SDVariable out; String msg; - switch (i){ + switch (i) { case 0: out = in.mul(2); tc.expectedOutput(out.name(), inArr.mul(2)); @@ -146,7 +149,7 @@ public class TransformOpValidation extends BaseOpValidation { msg = "rsub - " + inOrder; break; case 6: - out = sd.math().pow(in,2); + out = sd.math().pow(in, 2); tc.expectedOutput(out.name(), Transforms.pow(inArr, 2)); msg = "pow - " + inOrder; break; @@ -183,7 +186,7 @@ public class TransformOpValidation extends BaseOpValidation { log.info("Starting test: " + msg); String err = OpValidation.validate(tc, true); - if(err != null){ + if (err != null) { failed.add(err); } } @@ -192,10 +195,10 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testScalarMulCF(){ + public void testScalarMulCF() { - INDArray in = Nd4j.linspace(1,12,12, DataType.DOUBLE).reshape('c',3,4); - INDArray outC = Nd4j.createUninitialized(3,4); + INDArray in = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape('c', 3, 4); + INDArray outC = Nd4j.createUninitialized(3, 4); INDArray outF = Nd4j.createUninitialized(3, 4); Nd4j.getExecutioner().exec(new ScalarMultiplication(in, null, outC, 2.0)); @@ -206,9 +209,9 @@ public class TransformOpValidation extends BaseOpValidation { @Test - public void testScalarMulCF2(){ + public void testScalarMulCF2() { - INDArray in = Nd4j.linspace(1,12,12, DataType.DOUBLE).reshape('c',3,4); + INDArray in = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape('c', 3, 4); INDArray outC = Nd4j.getExecutioner().exec(new ScalarMultiplication(in.dup('c'), 2.0)); INDArray outF = Nd4j.getExecutioner().exec(new ScalarMultiplication(in.dup('f'), 2.0)); @@ -221,7 +224,7 @@ public class TransformOpValidation extends BaseOpValidation { INDArray a = Nd4j.create(new double[]{4, 2, 1}, new int[]{1, 3}); INDArray b = Nd4j.create(new double[]{1, 3, 4}, new int[]{1, 3}); - INDArray expOut = Nd4j.create(DataType.DOUBLE,1, 3); + INDArray expOut = Nd4j.create(DataType.DOUBLE, 1, 3); val op = new Cross(a, b, expOut); Nd4j.getExecutioner().exec(op); @@ -239,8 +242,8 @@ public class TransformOpValidation extends BaseOpValidation { SDVariable loss = sd.mean("loss", t); String err = OpValidation.validate(new TestCase(sd) - .expectedOutput("cross", expOut) - .gradientCheck(true)); + .expectedOutput("cross", expOut) + .gradientCheck(true)); assertNull(err, err); } @@ -263,7 +266,7 @@ public class TransformOpValidation extends BaseOpValidation { 
sd.associateArrayWithVariable(input, sdInput); SDVariable t = sd.cnn().spaceToDepth("std", sdInput, blockSize, DataFormat.NHWC); - //new SpaceToDepth(sd, sdInput, blockSize, dataFormat).outputVariable(); + //new SpaceToDepth(sd, sdInput, blockSize, dataFormat).outputVariable(); SDVariable loss = sd.mean("loss", t); String err = OpValidation.validate(new TestCase(sd) @@ -291,7 +294,7 @@ public class TransformOpValidation extends BaseOpValidation { sd.associateArrayWithVariable(input, sdInput); SDVariable t = sd.cnn().depthToSpace("dts", sdInput, blockSize, DataFormat.NHWC); - SDVariable loss = sd.mean("loss", t); + SDVariable loss = sd.mean("loss", t); String err = OpValidation.validate(new TestCase(sd) .expectedOutput("dts", expOut) @@ -415,7 +418,7 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testDynamicPartition2(){ + public void testDynamicPartition2() { INDArray data = Nd4j.createFromArray(2, 1, 2, 0); INDArray partitions = Nd4j.createFromArray(0, 2, 1, 0); INDArray[] out = Nd4j.exec(DynamicCustomOp.builder("dynamic_partition") @@ -448,7 +451,7 @@ public class TransformOpValidation extends BaseOpValidation { .addOutputs(expOut).build(); Nd4j.getExecutioner().exec(dynamicStitch); - INDArray expOut2 = Nd4j.create(new double[]{5,1,7,2,3,4}); + INDArray expOut2 = Nd4j.create(new double[]{5, 1, 7, 2, 3, 4}); assertEquals(expOut2, expOut); SDVariable in1 = sd.var("in1", ia); @@ -473,11 +476,11 @@ public class TransformOpValidation extends BaseOpValidation { public void testDiag() { SameDiff sd = SameDiff.create(); - INDArray ia = Nd4j.create(new double[]{1, 2}, new int[] {2}); + INDArray ia = Nd4j.create(new double[]{1, 2}, new int[]{2}); SDVariable in = sd.var("in", DataType.DOUBLE, new long[]{2}); - INDArray expOut = Nd4j.create(new double[][]{{1, 0},{0,2}}); + INDArray expOut = Nd4j.create(new double[][]{{1, 0}, {0, 2}}); - INDArray expOut2 = Nd4j.create(DataType.DOUBLE, 2,2); + INDArray expOut2 = Nd4j.create(DataType.DOUBLE, 2, 2); DynamicCustomOp diag = DynamicCustomOp.builder("diag").addInputs(ia).addOutputs(expOut2).build(); Nd4j.getExecutioner().exec(diag); @@ -485,7 +488,7 @@ public class TransformOpValidation extends BaseOpValidation { SDVariable t = sd.math().diag("diag", in); - SDVariable loss = sd.standardDeviation("loss", t,false,0, 1); + SDVariable loss = sd.standardDeviation("loss", t, false, 0, 1); sd.associateArrayWithVariable(ia, in); @@ -499,7 +502,7 @@ public class TransformOpValidation extends BaseOpValidation { public void testDiagPart() { SameDiff sd = SameDiff.create(); - INDArray input = Nd4j.linspace(1,16,16, DataType.DOUBLE).reshape(4,4); + INDArray input = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(4, 4); INDArray expOut = Nd4j.create(new float[]{1, 6, 11, 16}).castTo(DataType.DOUBLE); SDVariable in = sd.var("in", input); @@ -515,26 +518,26 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testEye(){ - int[] rows = new int[]{3,3,3,3}; - int[] cols = new int[]{3,2,2,2}; - int[][] batch = new int[][]{{}, {}, {4}, {3,3}}; + public void testEye() { + int[] rows = new int[]{3, 3, 3, 3}; + int[] cols = new int[]{3, 2, 2, 2}; + int[][] batch = new int[][]{{}, {}, {4}, {3, 3}}; INDArray[] expOut = new INDArray[4]; expOut[0] = Nd4j.eye(3).castTo(DataType.DOUBLE); - expOut[1] = Nd4j.create(new double[][]{{1,0},{0,1},{0,0}}); - expOut[2] = Nd4j.create(DataType.DOUBLE, 4,3,2); - for( int i=0; i<4; i++ ){ + expOut[1] = Nd4j.create(new double[][]{{1, 0}, {0, 1}, {0, 0}}); + expOut[2] = 
Nd4j.create(DataType.DOUBLE, 4, 3, 2); + for (int i = 0; i < 4; i++) { expOut[2].get(NDArrayIndex.point(i), NDArrayIndex.all(), NDArrayIndex.all()).assign(expOut[1]); } - expOut[3] = Nd4j.create(DataType.DOUBLE, 3,3,3,2); - for( int i=0; i<3; i++ ){ - for( int j=0; j<3; j++ ) { + expOut[3] = Nd4j.create(DataType.DOUBLE, 3, 3, 3, 2); + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 3; j++) { expOut[3].get(NDArrayIndex.point(i), NDArrayIndex.point(j), NDArrayIndex.all(), NDArrayIndex.all()).assign(expOut[1]); } } - for(int i=0; i<3; i++ ) { + for (int i = 0; i < 3; i++) { SameDiff sd = SameDiff.create(); SDVariable eye = sd.math().eye("e", rows[i], cols[i], DataType.DOUBLE, batch[i]); @@ -549,15 +552,15 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testEyeShape(){ + public void testEyeShape() { DynamicCustomOp dco = DynamicCustomOp.builder("eye") - .addIntegerArguments(3,3) + .addIntegerArguments(3, 3) //.addIntegerArguments(-99,3,3) //Also fails .build(); val list = Nd4j.getExecutioner().calculateOutputShape(dco); assertEquals(1, list.size()); //Fails here - empty list - assertArrayEquals(new long[]{3,3}, list.get(0).getShape()); + assertArrayEquals(new long[]{3, 3}, list.get(0).getShape()); } @Test @@ -687,7 +690,7 @@ public class TransformOpValidation extends BaseOpValidation { break; case 23: //TODO SHOULDN'T THIS HAVE A DIMENSION ARG??? - t = sd.nn().softmax(in,-1); + t = sd.nn().softmax(in, -1); ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut); tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new SoftMax(ia.dup()))[0]); break; @@ -756,7 +759,7 @@ public class TransformOpValidation extends BaseOpValidation { tc.expectedOutput(t.name(), Transforms.leakyRelu(ia, true)); break; case 39: - if(OpValidationSuite.IGNORE_FAILING) + if (OpValidationSuite.IGNORE_FAILING) continue; t = sd.nn().logSoftmax(in); ia = Nd4j.rand(minibatch, nOut).muli(10).subi(5); @@ -852,17 +855,17 @@ public class TransformOpValidation extends BaseOpValidation { tc.expectedOutput(t.name(), expOut51); break; case 52: - if(OpValidationSuite.IGNORE_FAILING){ + if (OpValidationSuite.IGNORE_FAILING) { continue; } boolean ex = false; boolean revBool = false; t = sd.cumprod(in, ex, revBool, 0); INDArray expOut52 = Nd4j.create(DataType.DOUBLE, ia.shape()); - for( int s0=0; s0 failed = new ArrayList<>(); - for( int i=0; i<4; i++ ){ + for (int i = 0; i < 4; i++) { SameDiff sd = SameDiff.create(); SDVariable in = sd.var("in", 4); @@ -1248,26 +1252,26 @@ public class TransformOpValidation extends BaseOpValidation { SDVariable out; INDArray exp; INDArray inArr; - switch (i){ + switch (i) { case 0: - inArr = Nd4j.create(new double[]{10,Double.POSITIVE_INFINITY, 0, Double.NEGATIVE_INFINITY}); - exp = Nd4j.create(new boolean[]{true,false,true,false}); + inArr = Nd4j.create(new double[]{10, Double.POSITIVE_INFINITY, 0, Double.NEGATIVE_INFINITY}); + exp = Nd4j.create(new boolean[]{true, false, true, false}); out = sd.math().isFinite(in); break; case 1: - inArr = Nd4j.create(new double[]{10,Double.POSITIVE_INFINITY, 0, Double.NEGATIVE_INFINITY}); - exp = Nd4j.create(new boolean[]{false,true,false,true}); + inArr = Nd4j.create(new double[]{10, Double.POSITIVE_INFINITY, 0, Double.NEGATIVE_INFINITY}); + exp = Nd4j.create(new boolean[]{false, true, false, true}); out = sd.math().isInfinite(in); break; case 2: //TODO: IsMax supports both bool and float out: https://github.com/deeplearning4j/deeplearning4j/issues/6872 - inArr = Nd4j.create(new double[]{-3,5,0,2}); - exp = Nd4j.create(new 
boolean[]{false,true,false,false}); + inArr = Nd4j.create(new double[]{-3, 5, 0, 2}); + exp = Nd4j.create(new boolean[]{false, true, false, false}); out = sd.math().isMax(in); break; case 3: - inArr = Nd4j.create(new double[]{0,Double.NaN,10,Double.NaN}); - exp = Nd4j.create(new boolean[]{false,true,false,true}); + inArr = Nd4j.create(new double[]{0, Double.NaN, 10, Double.NaN}); + exp = Nd4j.create(new boolean[]{false, true, false, true}); out = sd.math().isNaN(in); break; default: @@ -1284,7 +1288,7 @@ public class TransformOpValidation extends BaseOpValidation { in.setArray(inArr); String err = OpValidation.validate(tc, true); - if(err != null){ + if (err != null) { failed.add(err); } } @@ -1292,11 +1296,11 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testReplaceWhereScalar(){ - for(Condition c : new Condition[]{Conditions.lessThan(0.5), Conditions.greaterThan(0.5), Conditions.equals(0.5)}){ + public void testReplaceWhereScalar() { + for (Condition c : new Condition[]{Conditions.lessThan(0.5), Conditions.greaterThan(0.5), Conditions.equals(0.5)}) { log.info("Testing condition: " + c.getClass().getSimpleName()); - INDArray inArr = Nd4j.rand(DataType.DOUBLE, 3,4); + INDArray inArr = Nd4j.rand(DataType.DOUBLE, 3, 4); SameDiff sd = SameDiff.create(); SDVariable in = sd.var("in", inArr); SDVariable where = sd.replaceWhere(in, 10, c); @@ -1314,10 +1318,10 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testReplaceWhereArray(){ - for(Condition c : new Condition[]{Conditions.lessThan(0.5), Conditions.greaterThan(0.5), Conditions.equals(0.5)}){ + public void testReplaceWhereArray() { + for (Condition c : new Condition[]{Conditions.lessThan(0.5), Conditions.greaterThan(0.5), Conditions.equals(0.5)}) { - INDArray inArr = Nd4j.rand(3,4); + INDArray inArr = Nd4j.rand(3, 4); INDArray inArr2 = Nd4j.valueArrayOf(3, 4, 10); SameDiff sd = SameDiff.create(); SDVariable in = sd.var("in", inArr); @@ -1356,7 +1360,7 @@ public class TransformOpValidation extends BaseOpValidation { SDVariable input = sameDiff.var("x", inputs.get("x")); SDVariable sigmoid = sameDiff.nn().sigmoid(input); SDVariable sum = sameDiff.sum(sigmoid, Integer.MAX_VALUE); - Map m = sameDiff.calculateGradients(Collections.emptyMap(), sameDiff.getVariables().keySet()); + Map m = sameDiff.calculateGradients(Collections.emptyMap(), sameDiff.getVariables().keySet()); INDArray arr = m.get(input.name()); assertTrue(Nd4j.create(new double[][]{ {0.1966, 0.1050}, @@ -1375,22 +1379,22 @@ public class TransformOpValidation extends BaseOpValidation { }*/ @Test - public void testRank0EdgeCase(){ + public void testRank0EdgeCase() { SameDiff sd = SameDiff.create(); SDVariable v1 = sd.sum(sd.var(Nd4j.create(new double[]{4, 4}))); double d0 = v1.eval().getDouble(0); assertEquals(8, d0, 0); SDVariable v2 = sd.sum(sd.var(Nd4j.create(new double[]{4, 4}))).div(2.0); - Map m = sd.outputAll(Collections.emptyMap()); + Map m = sd.outputAll(Collections.emptyMap()); double d1 = m.get(v2.name()).getDouble(0); assertEquals(4, d1, 0); } @Test - public void testAtan2BroadcastShape(){ - INDArray arr1 = Nd4j.create(new long[]{3,1,4}); - INDArray arr2 = Nd4j.create(new long[]{1,2,4}); + public void testAtan2BroadcastShape() { + INDArray arr1 = Nd4j.create(new long[]{3, 1, 4}); + INDArray arr2 = Nd4j.create(new long[]{1, 2, 4}); DynamicCustomOp op = DynamicCustomOp.builder("tf_atan2") .addInputs(arr1, arr2) @@ -1399,15 +1403,15 @@ public class TransformOpValidation extends BaseOpValidation { 
val outShapes = Nd4j.getExecutioner().calculateOutputShape(op); assertEquals(1, outShapes.size()); - assertArrayEquals(Arrays.toString(outShapes.get(0).getShape()), new long[]{3,2,4}, outShapes.get(0).getShape()); + assertArrayEquals(Arrays.toString(outShapes.get(0).getShape()), new long[]{3, 2, 4}, outShapes.get(0).getShape()); } @Test - public void testBooleanAnd(){ + public void testBooleanAnd() { Nd4j.setDataType(DataType.FLOAT); - INDArray arr1 = Nd4j.create(new long[]{3,4}); - INDArray arr2 = Nd4j.create(new long[]{3,4}); - INDArray out = Nd4j.create(new long[]{3,4}); + INDArray arr1 = Nd4j.create(new long[]{3, 4}); + INDArray arr2 = Nd4j.create(new long[]{3, 4}); + INDArray out = Nd4j.create(new long[]{3, 4}); DynamicCustomOp op = DynamicCustomOp.builder("boolean_and") .addInputs(arr1, arr2) @@ -1417,8 +1421,8 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testScatterOpsScalar(){ - for(String s : new String[]{"add", "sub", "mul", "div"}) { + public void testScatterOpsScalar() { + for (String s : new String[]{"add", "sub", "mul", "div"}) { INDArray ref = Nd4j.linspace(1, 30, 30, DataType.DOUBLE).reshape(10, 3); INDArray indices = Nd4j.scalar(5); INDArray upd = Nd4j.create(new double[]{10, 20, 30}); @@ -1428,7 +1432,7 @@ public class TransformOpValidation extends BaseOpValidation { // INDArray upd = Nd4j.create(new double[]{10, 20, 30}, new int[]{1, 3}); INDArray exp = ref.dup(); - switch (s){ + switch (s) { case "add": exp.getRow(5).addi(upd); break; @@ -1462,9 +1466,9 @@ public class TransformOpValidation extends BaseOpValidation { @Ignore("12/16/2019 https://github.com/eclipse/deeplearning4j/issues/8540") @Test - public void testPad(){ + public void testPad() { INDArray in = Nd4j.valueArrayOf(new long[]{5}, 1.0); - INDArray pad = Nd4j.create(new double[]{1,1}, new long[]{1,2}).castTo(DataType.LONG); + INDArray pad = Nd4j.create(new double[]{1, 1}, new long[]{1, 2}).castTo(DataType.LONG); INDArray value = Nd4j.scalar(10.0); INDArray out = Nd4j.create(new long[]{7}); @@ -1482,18 +1486,18 @@ public class TransformOpValidation extends BaseOpValidation { SameDiff sd = SameDiff.create(); SDVariable s = sd.var("in", in); - SDVariable padded = sd.f().pad(s, sd.constant(pad), Pad.Mode.CONSTANT,10.0); + SDVariable padded = sd.f().pad(s, sd.constant(pad), Pad.Mode.CONSTANT, 10.0); String err2 = OpValidation.validate(new TestCase(sd).expected(padded, exp).gradientCheck(false)); assertNull(err2); } @Test - public void testMirrorPad(){ - INDArray in = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2,3); - INDArray pad = Nd4j.create(new double[][]{{1,1},{2,2}}).castTo(DataType.INT); + public void testMirrorPad() { + INDArray in = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3); + INDArray pad = Nd4j.create(new double[][]{{1, 1}, {2, 2}}).castTo(DataType.INT); - INDArray out = Nd4j.create(DataType.DOUBLE, 4,7); + INDArray out = Nd4j.create(DataType.DOUBLE, 4, 7); DynamicCustomOp op = DynamicCustomOp.builder("mirror_pad") .addInputs(in, pad) @@ -1509,24 +1513,24 @@ public class TransformOpValidation extends BaseOpValidation { {6, 5, 4, 5, 6, 5, 4}, {3, 2, 1, 2, 3, 2, 1}}); String err = OpValidation.validate(new OpTestCase(op) - .expectedOutput(0, exp)); + .expectedOutput(0, exp)); assertNull(err); SameDiff sd = SameDiff.create(); SDVariable s = sd.var("in", in); - SDVariable padded = sd.f().pad(s, sd.constant(Nd4j.createFromArray(new int[][]{{1,1},{2,2}})), Pad.Mode.REFLECT, 0.0); + SDVariable padded = sd.f().pad(s, sd.constant(Nd4j.createFromArray(new 
int[][]{{1, 1}, {2, 2}})), Pad.Mode.REFLECT, 0.0); String err2 = OpValidation.validate(new TestCase(sd).expected(padded, exp).gradientCheck(false)); assertNull(err2); } @Test - public void testMirrorPad2(){ - INDArray in = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2,3); - INDArray pad = Nd4j.create(new double[][]{{1,1},{2,2}}).castTo(DataType.INT); + public void testMirrorPad2() { + INDArray in = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3); + INDArray pad = Nd4j.create(new double[][]{{1, 1}, {2, 2}}).castTo(DataType.INT); - INDArray out = Nd4j.create(DataType.DOUBLE, 4,7); + INDArray out = Nd4j.create(DataType.DOUBLE, 4, 7); DynamicCustomOp op = DynamicCustomOp.builder("mirror_pad") .addInputs(in, pad) @@ -1548,11 +1552,11 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testMirrorPadSymmetric(){ - INDArray in = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(3,4); - INDArray pad = Nd4j.create(new double[][]{{1,1},{1,1}}).castTo(DataType.INT); + public void testMirrorPadSymmetric() { + INDArray in = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(3, 4); + INDArray pad = Nd4j.create(new double[][]{{1, 1}, {1, 1}}).castTo(DataType.INT); - INDArray out = Nd4j.create(DataType.DOUBLE, 5,6); + INDArray out = Nd4j.create(DataType.DOUBLE, 5, 6); DynamicCustomOp op = DynamicCustomOp.builder("mirror_pad") .addInputs(in, pad) @@ -1563,11 +1567,11 @@ public class TransformOpValidation extends BaseOpValidation { Nd4j.getExecutioner().exec(op); INDArray exp = Nd4j.create(new double[][]{ - { 1, 1, 2, 3, 4, 4}, - { 1, 1, 2, 3, 4, 4}, - { 5, 5, 6, 7, 8, 8}, - { 9, 9, 10, 11, 12, 12}, - { 9, 9, 10, 11, 12, 12}}); + {1, 1, 2, 3, 4, 4}, + {1, 1, 2, 3, 4, 4}, + {5, 5, 6, 7, 8, 8}, + {9, 9, 10, 11, 12, 12}, + {9, 9, 10, 11, 12, 12}}); String err = OpValidation.validate(new OpTestCase(op) .expectedOutput(0, exp)); @@ -1575,7 +1579,7 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testUnique(){ + public void testUnique() { INDArray in = Nd4j.create(new double[]{3, 4, 3, 1, 3, 0, 2, 4, 2, 4}); INDArray expUnique = Nd4j.create(new double[]{3, 4, 1, 0, 2}); @@ -1597,7 +1601,7 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testTopK(){ + public void testTopK() { OpValidationSuite.ignoreFailing(); //Can't assume sorted here INDArray in = Nd4j.create(new double[]{7, 3, 1, 2, 5, 0, 4, 6, 9, 8}); @@ -1607,7 +1611,7 @@ public class TransformOpValidation extends BaseOpValidation { INDArray expTopK_sorted = Nd4j.create(new double[]{9, 8, 7, 6, 5}); INDArray expIndices_sorted = Nd4j.create(new double[]{8, 9, 0, 7, 4}); - for(boolean sort : new boolean[]{false, true}) { + for (boolean sort : new boolean[]{false, true}) { INDArray outUnique = Nd4j.create(expTopK.shape()); INDArray outUniqueIdxs = Nd4j.create(expIndices.shape()); @@ -1626,7 +1630,7 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testTopK1(){ + public void testTopK1() { INDArray x = Nd4j.createFromArray(0.0, 0.0, 0.0, 10.0, 0.0); INDArray k = Nd4j.scalar(1); INDArray outValue = Nd4j.create(DataType.DOUBLE, 1); @@ -1648,13 +1652,13 @@ public class TransformOpValidation extends BaseOpValidation { @Test public void testInTopK() { - for( int k=4; k>= 1; k--){ + for (int k = 4; k >= 1; k--) { log.info("Testing: k=" + k); INDArray in = Nd4j.linspace(1, 20, 20, DataType.DOUBLE).reshape(4, 5); INDArray idxs = Nd4j.create(new double[]{1, 2, 3, 4}).castTo(DataType.INT); INDArray expOut; - switch (k){ 
+ switch (k) { case 4: expOut = Nd4j.create(new boolean[]{true, true, true, true}); break; @@ -1672,7 +1676,6 @@ public class TransformOpValidation extends BaseOpValidation { } - INDArray out = Nd4j.create(DataType.BOOL, expOut.shape()); DynamicCustomOp op = DynamicCustomOp.builder("in_top_k") @@ -1689,14 +1692,14 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testZeta(){ + public void testZeta() { OpValidationSuite.ignoreFailing(); //https://github.com/deeplearning4j/deeplearning4j/issues/6182 - INDArray x = Nd4j.rand(3,4).addi(1.0); - INDArray q = Nd4j.rand(3,4); + INDArray x = Nd4j.rand(3, 4).addi(1.0); + INDArray q = Nd4j.rand(3, 4); - INDArray out = Nd4j.create(3,4); + INDArray out = Nd4j.create(3, 4); DynamicCustomOp op = DynamicCustomOp.builder("zeta") - .addInputs(x,q) + .addInputs(x, q) .addOutputs(out) .build(); @@ -1706,7 +1709,7 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testMaxEmptyScalar(){ + public void testMaxEmptyScalar() { INDArray empty = Nd4j.empty(DataType.FLOAT); INDArray scalar = Nd4j.scalar(1.0f); @@ -1723,7 +1726,7 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testBroadcastEmpty(){ + public void testBroadcastEmpty() { // Nd4j.getExecutioner().enableVerboseMode(true); // Nd4j.getExecutioner().enableDebugMode(true); //Check broadcast behaviour with empty arrays. The idea is to match TF import behaviour, for import @@ -1745,11 +1748,11 @@ public class TransformOpValidation extends BaseOpValidation { out = sess.run([out]) */ - for( int i=0; i<3; i++ ){ - for(boolean scalar : new boolean[]{true, false}){ - INDArray x = scalar ? Nd4j.scalar(2f) : Nd4j.create(DataType.FLOAT, 3,4); - INDArray y = scalar ? Nd4j.scalar(3f) : Nd4j.create(DataType.FLOAT, 3,4); - switch (i){ + for (int i = 0; i < 3; i++) { + for (boolean scalar : new boolean[]{true, false}) { + INDArray x = scalar ? Nd4j.scalar(2f) : Nd4j.create(DataType.FLOAT, 3, 4); + INDArray y = scalar ? 
Nd4j.scalar(3f) : Nd4j.create(DataType.FLOAT, 3, 4); + switch (i) { case 0: //x only empty x = Nd4j.empty(DataType.FLOAT); @@ -1768,16 +1771,16 @@ public class TransformOpValidation extends BaseOpValidation { } - for( String opName : new String[]{"maximum", "minimum", "add", "subtract", "multiply", "divide", "assign", + for (String opName : new String[]{"maximum", "minimum", "add", "subtract", "multiply", "divide", "assign", "boolean_and", "boolean_or", "boolean_xor", "tf_atan2", "equals", "floordiv", "floormod", "greater", "greater_equal", "less", "less_equal", "mod", "not_equals", "realdiv", "reversedivide", "reversesubtract", - "squaredsubtract", "truncatediv"} ){ + "squaredsubtract", "truncatediv"}) { // log.info("Starting op: {}, case {} - x.isScalar()={}, x.isEmpty()={}, y.isScalar()={}, y.isEmpty()={}", opName, i, // x.isScalar(), x.isEmpty(), y.isScalar(), y.isEmpty()); DynamicCustomOp op = DynamicCustomOp.builder(opName) - .addInputs(x,y) + .addInputs(x, y) .build(); List l = op.calculateOutputShape(); @@ -1786,7 +1789,7 @@ public class TransformOpValidation extends BaseOpValidation { boolean empty = l.get(0).isEmpty(); boolean isBool = isBoolBroadcast(opName); - if(isBool){ + if (isBool) { assertEquals(DataType.BOOL, l.get(0).dataType()); } else { assertEquals(DataType.FLOAT, l.get(0).dataType()); @@ -1805,8 +1808,8 @@ public class TransformOpValidation extends BaseOpValidation { } } - private static boolean isBoolBroadcast(String opName){ - if(opName.startsWith("greater") || opName.startsWith("less") || opName.contains("equals")) + private static boolean isBoolBroadcast(String opName) { + if (opName.startsWith("greater") || opName.startsWith("less") || opName.contains("equals")) return true; //Note that "boolean" ops are inherit return false; @@ -1852,7 +1855,7 @@ public class TransformOpValidation extends BaseOpValidation { public void testStandardizeNoDeviation() { final INDArray random = Nd4j.rand(new int[]{10, 4}); for (int i = 0; i < 4; i++) { - random.putScalar(1,i, 7); + random.putScalar(1, i, 7); } final int[] axis = new int[]{1}; @@ -1875,7 +1878,7 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testMatMulTensor(){ + public void testMatMulTensor() { final INDArray a = Nd4j.rand(new int[]{1, 2, 3, 4, 5}); final INDArray b = Nd4j.rand(new int[]{1, 2, 3, 5, 6}); @@ -1895,20 +1898,76 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testMatMulTensorTranspose(){ - for(boolean transposeA: new boolean[]{false, true}) { + public void testMatMulTensorTranspose() { + for (boolean transposeA : new boolean[]{false, true}) { for (boolean transposeB : new boolean[]{false, true}) { for (boolean transposeResult : new boolean[]{false, true}) { log.info("Testing with transposeA={}; transposeB={}; transposeResult={};", transposeA, transposeB, transposeResult); int m = 0, n = 0, k = 0, l = 0, i = 0, j = 0; - if(!transposeA && !transposeB && !transposeResult){ m = 4; n = 5; k = 5; l = 6; i = 4; j = 6;} - if(!transposeA && transposeB && !transposeResult){ m = 4; n = 5; k = 6; l = 5; i = 4; j = 6;} - if(!transposeA && !transposeB && transposeResult){ m = 4; n = 5; k = 5; l = 6; i = 6; j = 4;} - if(!transposeA && transposeB && transposeResult){ m = 4; n = 5; k = 6; l = 5; i = 6; j = 4;} - if( transposeA && !transposeB && !transposeResult){ m = 5; n = 4; k = 5; l = 6; i = 4; j = 6;} - if( transposeA && transposeB && !transposeResult){ m = 5; n = 4; k = 6; l = 5; i = 4; j = 6;} - if( transposeA && !transposeB && 
transposeResult){ m = 5; n = 4; k = 5; l = 6; i = 6; j = 4;} - if( transposeA && transposeB && transposeResult){ m = 5; n = 4; k = 6; l = 5; i = 6; j = 4;} + if (!transposeA && !transposeB && !transposeResult) { + m = 4; + n = 5; + k = 5; + l = 6; + i = 4; + j = 6; + } + if (!transposeA && transposeB && !transposeResult) { + m = 4; + n = 5; + k = 6; + l = 5; + i = 4; + j = 6; + } + if (!transposeA && !transposeB && transposeResult) { + m = 4; + n = 5; + k = 5; + l = 6; + i = 6; + j = 4; + } + if (!transposeA && transposeB && transposeResult) { + m = 4; + n = 5; + k = 6; + l = 5; + i = 6; + j = 4; + } + if (transposeA && !transposeB && !transposeResult) { + m = 5; + n = 4; + k = 5; + l = 6; + i = 4; + j = 6; + } + if (transposeA && transposeB && !transposeResult) { + m = 5; + n = 4; + k = 6; + l = 5; + i = 4; + j = 6; + } + if (transposeA && !transposeB && transposeResult) { + m = 5; + n = 4; + k = 5; + l = 6; + i = 6; + j = 4; + } + if (transposeA && transposeB && transposeResult) { + m = 5; + n = 4; + k = 6; + l = 5; + i = 6; + j = 4; + } final INDArray a = Nd4j.rand(new int[]{1, 2, 3, m, n}); final INDArray b = Nd4j.rand(new int[]{1, 2, 3, k, l}); @@ -1932,7 +1991,7 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testSoftmaxCF(){ + public void testSoftmaxCF() { INDArray arrC = Nd4j.rand(DataType.FLOAT, 2, 5); INDArray arrF = arrC.dup('f'); @@ -1953,7 +2012,7 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testLogSumExp(){ + public void testLogSumExp() { Nd4j.getRandom().setSeed(12345); INDArray inputArr = Nd4j.rand(DataType.FLOAT, 1, 4); SameDiff sd = SameDiff.create(); @@ -1968,9 +2027,9 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testLogSumExp2(){ + public void testLogSumExp2() { - for( int dim=0; dim<=2; dim++ ) { + for (int dim = 0; dim <= 2; dim++) { Nd4j.getRandom().setSeed(12345); INDArray inputArr = Nd4j.rand(DataType.DOUBLE, 3, 4, 5); SameDiff sd = SameDiff.create(); @@ -1986,4 +2045,174 @@ public class TransformOpValidation extends BaseOpValidation { .gradientCheck(true)); } } -} + + + @Test + public void testCRELU() { + + Nd4j.getRandom().setSeed(12345); + INDArray inputArr = Nd4j.rand(DataType.DOUBLE, 2, 2); + SameDiff sd = SameDiff.create(); + SDVariable in = sd.var(inputArr); + + SDVariable crelu = new CReLU(sd, in).outputVariable(); + INDArray expected = Nd4j.concat(1, Nd4j.nn.relu(inputArr, 0), Nd4j.nn.relu(inputArr.neg(), 0)); + + String err = OpValidation.validate(new TestCase(sd) + .expectedOutput("crelu", expected) + .gradientCheck(true) + ); + + assertNull(err); + + + + } + + @Test + public void testClipByAvgNorm() { + + Nd4j.getRandom().setSeed(12345); + INDArray inputArr = Nd4j.rand(DataType.DOUBLE, 2, 2, 2); + SameDiff sd = SameDiff.create(); + SDVariable in = sd.var(inputArr); + SDVariable out = new ClipByAvgNorm(sd, in, 1e-2, 0, 1, 2).outputVariable(); + SDVariable expected = sd.math.clipByNorm(in, 1e-2, 0, 1, 2).mul(inputArr.length()); + + SDVariable loss = sd.standardDeviation("loss", out, true); + loss.markAsLoss(); + + String err = OpValidation.validate(new TestCase(sd) + .expectedOutput("clipbyavgnorm", expected.eval()) + .gradientCheck(false) + ); + assertNull(err); + + } + + + @Test + public void testEmbeddingLookup() { + + Nd4j.getRandom().setSeed(12345); + SameDiff sd = SameDiff.create(); + SDVariable input = sd.var("in", Nd4j.rand(1024, 10)); + SDVariable indices = sd.constant("indices", Nd4j.createFromArray(new long[]{0, 5, 17, 
33})); + SDVariable out = new EmbeddingLookup(sd, input, indices, PartitionMode.MOD).outputVariable(); + // should be matrix of shape [4, 10] + assertArrayEquals(new long[]{4, 10}, out.eval().shape()); + + } + + @Test + public void testImageResize() { + + //TODO: Methods failed ResizeLanczos5, ResizeMitchelcubic, ResizeArea + + for (ImageResizeMethod method : ImageResizeMethod.values()) { + if (method==ImageResizeMethod.ResizeLanczos5 || method==ImageResizeMethod.ResizeArea || method==ImageResizeMethod.ResizeMitchelcubic) + {continue;} + + Nd4j.getRandom().setSeed(12345); + SameDiff sd = SameDiff.create(); + boolean preserveAspectRatio = true; + boolean antialias = true; + SDVariable inputImage = sd.var(Nd4j.rand(1, 5, 5, 3)); + // NHWC format + long[] expectedShape = new long[]{1, 3, 3, 3}; + SDVariable requestedSize = sd.constant(Nd4j.createFromArray( new long[]{3, 3})); + + Function checkFunction = in -> { + boolean shapeOk = Arrays.equals(expectedShape, in.shape()); + if (shapeOk) return null; + return "Failed: shape differs - expected " + Arrays.toString(expectedShape) + " vs " + Arrays.toString(in.shape()) + " on method " + method; + }; + + + SDVariable out = new ImageResize(sd, inputImage, requestedSize, preserveAspectRatio, antialias, method).outputVariable(); + + String err = OpValidation.validate(new TestCase(sd) + .gradientCheck(false) + .expected("image_resize", checkFunction)); + + assertNull(err); + + + } + } + + + + + @Test + public void testMaximumBp() { + + Nd4j.getRandom().setSeed(12345); + SameDiff sd = SameDiff.create(); + SDVariable inputX = sd.var(Nd4j.rand(2, 3)); + SDVariable inputY = sd.var(Nd4j.rand(2, 3)); + + + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Max(sd, inputX, inputY).outputVariable(); + String err = OpValidation.validate(new TestCase(sd) + .gradientCheck(true)); + assertNull(err); + + + } + + @Test + public void testMergeAddBp() { + + Nd4j.getRandom().setSeed(12345); + SameDiff sd = SameDiff.create(); + SDVariable inputX = sd.var(Nd4j.rand(2, 3)); + SDVariable inputY = sd.var(Nd4j.rand(2, 3)); + SDVariable inputZ = sd.var(Nd4j.rand(2, 3)); + SDVariable out = new MergeAddOp(sd, new SDVariable[]{inputX, inputY, inputZ}).outputVariable(); + out.markAsLoss(); + String err = OpValidation.validate(new TestCase(sd) + .gradientCheck(true)); + assertNull(err); + + + } + + @Test + public void testMergeMaxBp() { + + Nd4j.getRandom().setSeed(12345); + SameDiff sd = SameDiff.create(); + SDVariable inputX = sd.var(Nd4j.rand(2, 3)); + SDVariable inputY = sd.var(Nd4j.rand(2, 3)); + SDVariable inputZ = sd.var(Nd4j.rand(2, 3)); + SDVariable out = new MergeMax(sd, new SDVariable[]{inputX, inputY, inputZ}).outputVariable(); + out.markAsLoss(); + String err = OpValidation.validate(new TestCase(sd) + .gradientCheck(true)); + assertNull(err); + + + } + + + @Test + public void testMergeAvgBp() { + + Nd4j.getRandom().setSeed(12345); + SameDiff sd = SameDiff.create(); + SDVariable inputX = sd.var(Nd4j.rand(2, 3)); + SDVariable inputY = sd.var(Nd4j.rand(2, 3)); + SDVariable inputZ = sd.var(Nd4j.rand(2, 3)); + SDVariable out = new MergeAvg(sd, new SDVariable[]{inputX, inputY, inputZ}).outputVariable(); + out.markAsLoss(); + String err = OpValidation.validate(new TestCase(sd) + .gradientCheck(true)); + assertNull(err); + + + } + + + } + From 18d4eaa68deb9819525b092fe6091e66ecfbbb6b Mon Sep 17 00:00:00 2001 From: Alex Black Date: Fri, 17 Apr 2020 19:47:57 +1000 Subject: [PATCH 09/18] DL4J SameDiff loss function (#251) * Copied and pasted 
RegressionTest100b4.java to RegressionTest100b6.java with renamed b4->b6 * SameDiffLoss draft * very very draft * Copied and pasted RegressionTest100b4.java to RegressionTest100b6.java with renamed b4->b6 * temporary commit for clarification Signed-off-by: atuzhykov * temporary commit for clarification v2 Signed-off-by: atuzhykov * temporary commit for clarification v3 Signed-off-by: atuzhykov * temporary commit for clarification v3 Signed-off-by: atuzhykov * Copied and pasted RegressionTest100b4.java to RegressionTest100b6.java with renamed b4->b6 * very very draft * temporary commit for clarification Signed-off-by: atuzhykov * temporary commit for clarification v2 Signed-off-by: atuzhykov * temporary commit for clarification v3 Signed-off-by: atuzhykov * temporary commit for clarification v3 Signed-off-by: atuzhykov * SDLoss after requested changes but with questions in comments Signed-off-by: Andrii Tuzhykov * added requested changes * small fixes Signed-off-by: Andrii Tuzhykov * Fixes Signed-off-by: Alex Black * Javadoc Signed-off-by: Alex Black * Test fix Signed-off-by: Alex Black Co-authored-by: Andrii Tuzhykov Co-authored-by: atuzhykov Co-authored-by: Andrii Tuzhykov --- .../LossFunctionGradientCheck.java | 19 +- .../gradientcheck/sdlosscustom/SDLossMAE.java | 30 +++ .../gradientcheck/sdlosscustom/SDLossMSE.java | 30 +++ .../regressiontest/RegressionTest100b6.java | 2 +- .../java/org/nd4j/linalg/api/ops/BaseOp.java | 5 +- .../linalg/lossfunctions/SameDiffLoss.java | 186 ++++++++++++++++++ .../linalg/workspace/BasicWorkspaceTests.java | 2 +- 7 files changed, 270 insertions(+), 4 deletions(-) create mode 100644 deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/sdlosscustom/SDLossMAE.java create mode 100644 deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/sdlosscustom/SDLossMSE.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/SameDiffLoss.java diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java index 632a85e22..91264f51f 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/LossFunctionGradientCheck.java @@ -19,6 +19,8 @@ package org.deeplearning4j.gradientcheck; import lombok.extern.slf4j.Slf4j; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.TestUtils; +import org.deeplearning4j.gradientcheck.sdlosscustom.SDLossMAE; +import org.deeplearning4j.gradientcheck.sdlosscustom.SDLossMSE; import org.deeplearning4j.nn.api.OptimizationAlgorithm; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; @@ -83,7 +85,8 @@ public class LossFunctionGradientCheck extends BaseDL4JTest { LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(), LossMixtureDensity.builder().gaussians(2).labelWidth(3).build(), new LossMultiLabel(), new LossWasserstein(), - new LossSparseMCXENT() + new LossSparseMCXENT(), + new SDLossMAE(), new SDLossMSE() }; Activation[] outputActivationFn = new Activation[] {Activation.SIGMOID, //xent @@ -119,6 +122,12 @@ public class LossFunctionGradientCheck extends BaseDL4JTest { Activation.TANH, // MultiLabel, doesn't require any 
special activation, but tanh was used in paper Activation.IDENTITY, // Wasserstein Activation.SOFTMAX, //sparse MCXENT + Activation.SOFTMAX, // SDLossMAE + Activation.SIGMOID, // SDLossMAE + Activation.TANH, // SDLossMAE + Activation.SOFTMAX, // SDLossMSE + Activation.SIGMOID, // SDLossMSE + Activation.TANH //SDLossMSE }; int[] nOut = new int[] {1, //xent @@ -154,6 +163,12 @@ public class LossFunctionGradientCheck extends BaseDL4JTest { 10, // MultiLabel 2, // Wasserstein 4, //sparse MCXENT + 3, // SDLossMAE + 3, // SDLossMAE + 3, // SDLossMAE + 3, // SDLossMSE + 3, // SDLossMSE + 3, // SDLossMSE }; int[] minibatchSizes = new int[] {1, 3}; @@ -520,6 +535,8 @@ public class LossFunctionGradientCheck extends BaseDL4JTest { break; case "LossMAE": case "LossMSE": + case "SDLossMAE": + case "SDLossMSE": case "LossL1": case "LossL2": ret[1] = Nd4j.rand(labelsShape).muli(2).subi(1); diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/sdlosscustom/SDLossMAE.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/sdlosscustom/SDLossMAE.java new file mode 100644 index 000000000..dbef14bf2 --- /dev/null +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/sdlosscustom/SDLossMAE.java @@ -0,0 +1,30 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.gradientcheck.sdlosscustom; + +import lombok.EqualsAndHashCode; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.lossfunctions.SameDiffLoss; + +@EqualsAndHashCode(callSuper = false) +public class SDLossMAE extends SameDiffLoss { + + @Override + public SDVariable defineLoss(SameDiff sd, SDVariable layerInput, SDVariable labels) { + return sd.math.abs(labels.sub(layerInput)).mean(1); + } +} diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/sdlosscustom/SDLossMSE.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/sdlosscustom/SDLossMSE.java new file mode 100644 index 000000000..6edce7a49 --- /dev/null +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/gradientcheck/sdlosscustom/SDLossMSE.java @@ -0,0 +1,30 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.gradientcheck.sdlosscustom; + +import lombok.EqualsAndHashCode; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.lossfunctions.*; + +@EqualsAndHashCode(callSuper = false) +public class SDLossMSE extends SameDiffLoss { + + @Override + public SDVariable defineLoss(SameDiff sd, SDVariable layerInput, SDVariable labels) { + return labels.squaredDifference(layerInput).mean(1); + } +} diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java index 22ac01c14..e43e4ca74 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/regressiontest/RegressionTest100b6.java @@ -111,7 +111,7 @@ public class RegressionTest100b6 extends BaseDL4JTest { assertEquals(dtype, net.getLayerWiseConfigurations().getDataType()); assertEquals(dtype, net.params().dataType()); boolean eq = outExp.equalsWithEps(outAct, 0.01); - assertTrue(outExp + " vs " + outAct, eq); + assertTrue("Test for dtype: " + dtypeName + " - " + outExp + " vs " + outAct, eq); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOp.java index 30f3d0bf5..5c45ecf50 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOp.java @@ -28,6 +28,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.exception.ND4JIllegalStateException; @@ -348,7 +349,9 @@ public abstract class BaseOp extends DifferentialFunction implements Op { if (dimensions == null || dimensions.length == 0) dimensions = new int[]{Integer.MAX_VALUE}; - this.dimensionz = Shape.ndArrayDimFromInt(dimensions); + try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { + this.dimensionz = Shape.ndArrayDimFromInt(dimensions); + } } public INDArray dimensions() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/SameDiffLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/SameDiffLoss.java new file mode 100644 index 000000000..2a3a05663 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/SameDiffLoss.java @@ -0,0 +1,186 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.lossfunctions; + +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.activations.IActivation; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.primitives.Pair; +import java.util.HashMap; +import java.util.Map; + +/** + * SameDiff loss function. + * + * This class can be extended to create Deeplearning4j loss functions by defining one single method only: + * {@link #defineLoss(SameDiff, SDVariable, SDVariable)}. This method is used to define the loss function on a + * per example basis - i.e., the output should be an array with shape [minibatch].<br>
+ * <br>
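+ * A complete implementation therefore needs nothing beyond that one method; score and gradient
+ * computation are derived from it automatically. As a minimal sketch (the class name
+ * {@code SDLossL1} is illustrative only):<br>
+ * <pre>{@code
+ * public class SDLossL1 extends SameDiffLoss {
+ *     public SDVariable defineLoss(SameDiff sd, SDVariable layerInput, SDVariable labels) {
+ *         // per-example score: mean absolute error over the feature dimension
+ *         return sd.math.abs(labels.sub(layerInput)).mean(1);
+ *     }
+ * }
+ * }</pre>
+ * <br>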
+ * For example, the mean squared error (MSE) loss function can be defined using:<br>
+ * {@code return labels.squaredDifference(layerInput).mean(1);} + * + */ +public abstract class SameDiffLoss implements ILossFunction { + protected transient SameDiff sd; + protected transient SDVariable scoreVariable; + + protected SameDiffLoss() { + + } + + /** + * Define the loss function.<br>
+ * NOTE: The score on a *per example* basis - should return a SDVariable with shape [minibatch], where out[i] + * is the score for the ith minibatch + * + * @param sd SameDiff instance to define the loss on + * @param layerInput Input to the SameDiff loss function + * @param labels Labels placeholder + * @return The score on a per example basis (SDVariable with shape [minibatch]) + */ + public abstract SDVariable defineLoss(SameDiff sd, SDVariable layerInput, SDVariable labels); + + protected void createSameDiffInstance(DataType dataType){ + sd = SameDiff.create(); + SDVariable layerInput = sd.placeHolder("layerInput", dataType, -1); + SDVariable labels = sd.placeHolder("labels", dataType, -1); + scoreVariable = this.defineLoss(sd, layerInput, labels); + sd.createGradFunction("layerInput"); + } + + /** + * Compute the score (loss function value) for the given inputs. + * + * @param labels Label/expected preOutput + * @param preOutput Output of the model (neural network) + * @param activationFn Activation function that should be applied to preOutput + * @param mask Mask array; may be null + * @param average Whether the score should be averaged (divided by number of rows in labels/preOutput) or not @return Loss function value + */ + @Override + public double computeScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) { + if(sd == null){ + createSameDiffInstance(preOutput.dataType()); + } + + INDArray scoreArr = computeScoreArray(labels, preOutput, activationFn, mask); + + double score = scoreArr.sumNumber().doubleValue(); + if (average) { + score /= scoreArr.size(0); + } + return score; + } + + + /** + * Compute the score (loss function value) for each example individually. + * For input [numExamples,nOut] returns scores as a column vector: [numExamples,1] + * + * @param labels Labels/expected output + * @param preOutput Output of the model (neural network) + * @param activationFn Activation function that should be applied to preOutput + * @param mask @return Loss function value for each example; column vector + */ + @Override + public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) { + if(sd == null){ + createSameDiffInstance(preOutput.dataType()); + } + + Preconditions.checkArgument((labels.size(1) == preOutput.size(1)), "Labels array numColumns (size(1) = %s) does not match output layer number of outputs (nOut = %s)", labels.size(1), preOutput.size(1)); + + INDArray output = activationFn.getActivation(preOutput.dup(), true); + + Map m = new HashMap<>(); + m.put("labels", labels); + m.put("layerInput", output); + + INDArray scoreArr = sd.outputSingle(m,scoreVariable.name()); + + if (mask != null) { + LossUtil.applyMask(scoreArr, mask); + } + return scoreArr; + } + + + /** + * Compute the gradient of the loss function with respect to the inputs: dL/dOutput + * + * @param labels Label/expected output + * @param preOutput Output of the model (neural network), before the activation function is applied + * @param activationFn Activation function that should be applied to preOutput + * @param mask Mask array; may be null + * @return Gradient dL/dPreOut + */ + @Override + public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) { + if(sd == null){ + createSameDiffInstance(preOutput.dataType()); + } + + + Map m = new HashMap<>(); + INDArray output = activationFn.getActivation(preOutput.dup(), true); + m.put("labels", labels); + 
m.put("layerInput", output); + + Map grads = sd.calculateGradients(m, "layerInput"); + + INDArray gradAtActivationOutput = grads.get("layerInput"); + INDArray gradAtInput = activationFn.backprop(preOutput.dup(), gradAtActivationOutput).getFirst(); + + if (mask != null) { + LossUtil.applyMask(gradAtInput, mask); + } + return gradAtInput; + } + + /** + * Compute both the score (loss function value) and gradient. This is equivalent to calling {@link #computeScore(INDArray, INDArray, IActivation, INDArray, boolean)} + * and {@link #computeGradient(INDArray, INDArray, IActivation, INDArray)} individually + * + * @param labels Label/expected output + * @param preOutput Output of the model (neural network) + * @param activationFn Activation function that should be applied to preOutput + * @param mask Mask array; may be null + * @param average Whether the score should be averaged (divided by number of rows in labels/output) or not + * @return The score (loss function value) and gradient + */ + @Override + public Pair computeGradientAndScore(INDArray labels, INDArray preOutput, IActivation activationFn, + INDArray mask, boolean average) { + + Pair GradientAndScore = new Pair<>(); + GradientAndScore.setFirst(this.computeScore(labels, preOutput, activationFn, mask, average)); + GradientAndScore.setSecond(this.computeGradient(labels, preOutput, activationFn, mask)); + + return GradientAndScore; + } + + @Override + public String name() { + return getClass().getSimpleName(); + } +} + + + + diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java index e9b021955..0c43ff9ca 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/workspace/BasicWorkspaceTests.java @@ -683,7 +683,7 @@ public class BasicWorkspaceTests extends BaseNd4jTest { workspace.initializeWorkspace(); - long reqMemory = 12 * Nd4j.sizeOfDataType(arrayCold.dataType()); + long reqMemory = 11 * Nd4j.sizeOfDataType(arrayCold.dataType()); assertEquals(reqMemory + reqMemory % 8, workspace.getCurrentSize()); From 0eca33ad94030d5bbc6779e03107a60910125a0f Mon Sep 17 00:00:00 2001 From: shugeo Date: Fri, 17 Apr 2020 16:52:08 +0300 Subject: [PATCH 10/18] Shugeo cuda solver fix (#383) * Refactored cuSolver handle usage to handle LaunchContext instance properly. Signed-off-by: shugeo * Refactored svd solver usage with LaunchContext instance singleton. 
Signed-off-by: shugeo * add device locks for cuSolver uses Signed-off-by: raver119 Co-authored-by: raver119 --- .../ops/declarable/helpers/cuda/lup.cu | 28 ++++++----- .../ops/declarable/helpers/cuda/svd.cu | 48 ++++++++++--------- 2 files changed, 41 insertions(+), 35 deletions(-) diff --git a/libnd4j/include/ops/declarable/helpers/cuda/lup.cu b/libnd4j/include/ops/declarable/helpers/cuda/lup.cu index 2ca731912..c986260e8 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/lup.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/lup.cu @@ -341,14 +341,16 @@ namespace helpers { static void lup_(LaunchContext *context, NDArray *input, NDArray *compound, NDArray *permutation) { auto stream = context->getCudaStream(); auto n = input->rows(); - cusolverDnHandle_t cusolverH = nullptr; + std::lock_guard lock(*LaunchContext::deviceMutex()); + + cusolverDnHandle_t* cusolverH = (cusolverDnHandle_t*)context->getCusolverHandle(); //nullptr; // create solver handle - cusolverStatus_t status = cusolverDnCreate(&cusolverH); - if (CUSOLVER_STATUS_SUCCESS != status) { - throw cuda_exception::build("Cannot create cuSolver handle", status); - } + cusolverStatus_t status; //cusolverDnCreate(&cusolverH); +// if (CUSOLVER_STATUS_SUCCESS != status) { +// throw cuda_exception::build("Cannot create cuSolver handle", status); +// } // set solver stream - status = cusolverDnSetStream(cusolverH, *stream); + status = cusolverDnSetStream(*cusolverH, *stream); if (CUSOLVER_STATUS_SUCCESS != status) { throw cuda_exception::build("Cannot set up stream for cuda solver", status); } @@ -368,7 +370,7 @@ namespace helpers { // compute internal buffer size double *matrix = reinterpret_cast(input->specialBuffer()); status = cusolverDnDgetrf_bufferSize( - cusolverH, + *cusolverH, n, n, matrix, @@ -386,7 +388,7 @@ namespace helpers { if (permutation == nullptr) { status = cusolverDnDgetrf( - cusolverH, + *cusolverH, n, n, matrix, @@ -404,7 +406,7 @@ namespace helpers { NDArray permutVector('c', {n}, sd::DataType::INT32, context); int* permutationBuf = permutVector.dataBuffer()->specialAsT(); status = cusolverDnDgetrf( - cusolverH, + *cusolverH, n, n, matrix, @@ -440,7 +442,7 @@ namespace helpers { float *d_work = nullptr; status = cusolverDnSgetrf_bufferSize( - cusolverH, + *cusolverH, n, n, matrix, @@ -458,7 +460,7 @@ namespace helpers { if (permutation == nullptr) status = cusolverDnSgetrf( - cusolverH, + *cusolverH, n, n, matrix, @@ -470,7 +472,7 @@ namespace helpers { NDArray permutVector('c', {n}, DataType::INT32, context); int *permutationBuf = reinterpret_cast(permutVector.specialBuffer()); status = cusolverDnSgetrf( - cusolverH, + *cusolverH, n, n, matrix, @@ -504,7 +506,7 @@ namespace helpers { if (err) { throw cuda_exception::build("helpers::lup_: Cannot deallocate memory for solver info buffer", err); } - cusolverDnDestroy(cusolverH); +// cusolverDnDestroy(cusolverH); // NDArray::registerSpecialUse({input}, {input}); input->tickWriteDevice(); } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/svd.cu b/libnd4j/include/ops/declarable/helpers/cuda/svd.cu index 44f924bf0..5c3d2811c 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/svd.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/svd.cu @@ -170,23 +170,25 @@ static void svdQR(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDAr } } + std::lock_guard lock(*LaunchContext::deviceMutex()); + // create cusolverDn handle - cusolverDnHandle_t handle = nullptr; - cusolverStatus_t status = cusolverDnCreate(&handle); - if(status != 
CUSOLVER_STATUS_SUCCESS) - throw cuda_exception::build("svdQR: cuda failed !", status); + cusolverDnHandle_t* handle = (cusolverDnHandle_t*)context->getCusolverHandle(); //nullptr; + //cusolverStatus_t status = cusolverDnCreate(&handle); + if(handle == nullptr) + throw cuda_exception::build("svdQR: cuda failed !", -1); // stream - status = cusolverDnSetStream(handle, *context->getCudaStream()); + auto status = cusolverDnSetStream(*handle, *context->getCudaStream()); if(status != CUSOLVER_STATUS_SUCCESS) throw cuda_exception::build("svdQR: cuda failed !", status); // query working space of SVD int lwork = 0; if(A->dataType() == DataType::DOUBLE) - status = cusolverDnDgesvd_bufferSize(handle, m, n, &lwork); + status = cusolverDnDgesvd_bufferSize(*handle, m, n, &lwork); else if(A->dataType() == DataType::FLOAT32) - status = cusolverDnSgesvd_bufferSize(handle, m, n, &lwork); + status = cusolverDnSgesvd_bufferSize(*handle, m, n, &lwork); else throw std::invalid_argument("svdQR: given data type is unsupported !"); @@ -227,10 +229,10 @@ static void svdQR(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDAr // choose appropriate cuda gemm api depending on data types if(A->dataType() == DataType::DOUBLE) { - status = cusolverDnDgesvd(handle, jobu, jobvt, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast(pVT->getSpecialBuffer()) : nullptr, ldvt, reinterpret_cast(dWork), lwork, reinterpret_cast(rWork), devInfo); + status = cusolverDnDgesvd(*handle, jobu, jobvt, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast(pVT->getSpecialBuffer()) : nullptr, ldvt, reinterpret_cast(dWork), lwork, reinterpret_cast(rWork), devInfo); } else if(A->dataType() == DataType::FLOAT32) { - status = cusolverDnSgesvd(handle, jobu, jobvt, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? reinterpret_cast(pVT->getSpecialBuffer()) : nullptr, ldvt, reinterpret_cast(dWork), lwork, reinterpret_cast(rWork), devInfo); + status = cusolverDnSgesvd(*handle, jobu, jobvt, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : nullptr, ldu, calcUV ? 
reinterpret_cast(pVT->getSpecialBuffer()) : nullptr, ldvt, reinterpret_cast(dWork), lwork, reinterpret_cast(rWork), devInfo); } else throw std::invalid_argument("svdQR: given data type is unsupported !"); @@ -259,8 +261,8 @@ static void svdQR(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDAr if (rWork) cudaFree(rWork); - if(handle) - cusolverDnDestroy(handle); +// if(handle) +// cusolverDnDestroy(handle); // cudaDeviceReset(); } @@ -346,14 +348,16 @@ static void svdJcb(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDA ldv = pV->strideAt(1); } + std::lock_guard lock(*LaunchContext::deviceMutex()); + // create cusolverDn handle - cusolverDnHandle_t handle = nullptr; - cusolverStatus_t status = cusolverDnCreate(&handle); - if(status != CUSOLVER_STATUS_SUCCESS) - throw cuda_exception::build("svdJcb: cuda failed !", status); + cusolverDnHandle_t* handle = (cusolverDnHandle_t*)context->getCusolverHandle(); + //cusolverStatus_t status = cusolverDnCreate(&handle); + if(handle == nullptr) + throw cuda_exception::build("svdJcb: cuda failed !", -1); // stream - status = cusolverDnSetStream(handle, *context->getCudaStream()); + auto status = cusolverDnSetStream(*handle, *context->getCudaStream()); if(status != CUSOLVER_STATUS_SUCCESS) throw cuda_exception::build("svdJcb: cuda failed !", status); @@ -391,9 +395,9 @@ static void svdJcb(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDA // query working space of SVD int lwork = 0; if(A->dataType() == DataType::DOUBLE) - status = cusolverDnDgesvdj_bufferSize(handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldu, calcUV ? reinterpret_cast(pV->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldv, &lwork, gesvdjParams); + status = cusolverDnDgesvdj_bufferSize(*handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldu, calcUV ? reinterpret_cast(pV->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldv, &lwork, gesvdjParams); else if(A->dataType() == DataType::FLOAT32) - status = cusolverDnSgesvdj_bufferSize(handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldu, calcUV ? reinterpret_cast(pV->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldv, &lwork, gesvdjParams); + status = cusolverDnSgesvdj_bufferSize(*handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldu, calcUV ? reinterpret_cast(pV->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldv, &lwork, gesvdjParams); else throw std::invalid_argument("svdJcb: given data type is unsupported !"); @@ -410,10 +414,10 @@ static void svdJcb(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDA // choose appropriate cuda gemm api depending on data types if(A->dataType() == DataType::DOUBLE) { - status = cusolverDnDgesvdj(handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldu, calcUV ? 
reinterpret_cast(pV->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldv, reinterpret_cast(dWork), lwork, devInfo, gesvdjParams); + status = cusolverDnDgesvdj(*handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldu, calcUV ? reinterpret_cast(pV->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldv, reinterpret_cast(dWork), lwork, devInfo, gesvdjParams); } else if(A->dataType() == DataType::FLOAT32) { - status = cusolverDnSgesvdj(handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldu, calcUV ? reinterpret_cast(pV->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldv, reinterpret_cast(dWork), lwork, devInfo, gesvdjParams); + status = cusolverDnSgesvdj(*handle, jobz, econ, m, n, reinterpret_cast(pA->getSpecialBuffer()), lda, reinterpret_cast(pS->getSpecialBuffer()), calcUV ? reinterpret_cast(pU->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldu, calcUV ? reinterpret_cast(pV->getSpecialBuffer()) : reinterpret_cast(nullPtr), ldv, reinterpret_cast(dWork), lwork, devInfo, gesvdjParams); } else throw std::invalid_argument("svdJcb: given data type is unsupported !"); @@ -446,8 +450,8 @@ static void svdJcb(sd::LaunchContext* context, const NDArray* A, NDArray* S, NDA cudaFree(devInfo); if (dWork ) cudaFree(dWork); - if(handle) - cusolverDnDestroy(handle); +// if(handle) +// cusolverDnDestroy(handle); if(gesvdjParams) cusolverDnDestroyGesvdjInfo(gesvdjParams); From 163222e85315a907d6ab9ba9ae33f5ac9127c412 Mon Sep 17 00:00:00 2001 From: raver119 Date: Sun, 19 Apr 2020 13:34:26 +0300 Subject: [PATCH 11/18] tick device write in RandomLauncher Signed-off-by: raver119 --- .../include/helpers/impl/RandomLauncher.cpp | 42 ++++++++++++++++++- 1 file changed, 40 insertions(+), 2 deletions(-) diff --git a/libnd4j/include/helpers/impl/RandomLauncher.cpp b/libnd4j/include/helpers/impl/RandomLauncher.cpp index 8114c2ec4..f7cdd0f3a 100644 --- a/libnd4j/include/helpers/impl/RandomLauncher.cpp +++ b/libnd4j/include/helpers/impl/RandomLauncher.cpp @@ -26,8 +26,6 @@ #include namespace sd { - // FIXME: implement this - void RandomLauncher::applyDropOut(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray *array, double retainProb, NDArray* z) { if (z == nullptr) z = array; @@ -35,8 +33,12 @@ namespace sd { ExtraArguments arguments({retainProb}); PointersManager pm(context, "applyDropOut"); + NDArray::prepareSpecialUse({z}, {array}); + NativeOpExecutioner::execRandom(context, random::DropOut, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), arguments.argumentsAsT(z->dataType())); pm.synchronize(); + + NDArray::registerSpecialUse({z}, {array}); } void RandomLauncher::applyInvertedDropOut(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray *array, double retainProb, NDArray* z) { @@ -46,8 +48,12 @@ namespace sd { ExtraArguments arguments({retainProb}); PointersManager pm(context, "applyInvertedDropOut"); + NDArray::prepareSpecialUse({z}, {array}); + NativeOpExecutioner::execRandom(context, random::DropOutInverted, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), 
arguments.argumentsAsT(z->dataType())); pm.synchronize(); + + NDArray::registerSpecialUse({z}, {array}); } void RandomLauncher::applyAlphaDropOut(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray *array, double retainProb, double alpha, double beta, double alphaPrime, NDArray* z) { @@ -57,63 +63,95 @@ namespace sd { ExtraArguments arguments({retainProb, alpha, beta, alphaPrime}); PointersManager pm(context, "applyAlphaDropOut"); + NDArray::prepareSpecialUse({z}, {array}); + NativeOpExecutioner::execRandom(context, random::AlphaDropOut, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), z->buffer(), z->shapeInfo(), z->specialBuffer(), z->specialShapeInfo(), arguments.argumentsAsT(z->dataType())); pm.synchronize(); + + NDArray::registerSpecialUse({z}, {array}); } void RandomLauncher::fillBernoulli(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, double prob) { ExtraArguments arguments({prob}); PointersManager pm(context, "fillBernoulli"); + NDArray::prepareSpecialUse({array}, {}); + NativeOpExecutioner::execRandom(context, random::BernoulliDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); pm.synchronize(); + + NDArray::registerSpecialUse({array}, {}); } void RandomLauncher::fillUniform(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, double from, double to) { ExtraArguments arguments({from, to}); PointersManager pm(context, "fillUniform"); + NDArray::prepareSpecialUse({array}, {}); + NativeOpExecutioner::execRandom(context, random::UniformDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); pm.synchronize(); + + NDArray::registerSpecialUse({array}, {}); } void RandomLauncher::fillGaussian(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev) { ExtraArguments arguments({mean, stdev}); PointersManager pm(context, "fillGaussian"); + NDArray::prepareSpecialUse({array}, {}); + NativeOpExecutioner::execRandom(context, random::GaussianDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); pm.synchronize(); + + NDArray::registerSpecialUse({array}, {}); } void RandomLauncher::fillExponential(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, double lambda) { ExtraArguments arguments({lambda}); PointersManager pm(context, "fillExponential"); + NDArray::prepareSpecialUse({array}, {}); + NativeOpExecutioner::execRandom(context, random::ExponentialDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); pm.synchronize(); + + NDArray::registerSpecialUse({array}, {}); } void RandomLauncher::fillLogNormal(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev) { ExtraArguments arguments({mean, stdev}); PointersManager pm(context, "fillLogNormal"); + NDArray::prepareSpecialUse({array}, {}); + NativeOpExecutioner::execRandom(context, random::GaussianDistribution, &rng, array->buffer(), array->shapeInfo(), 
array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); pm.synchronize(); + + NDArray::registerSpecialUse({array}, {}); } void RandomLauncher::fillTruncatedNormal(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, double mean, double stdev) { ExtraArguments arguments({mean, stdev}); PointersManager pm(context, "fillTruncatedNormal"); + NDArray::prepareSpecialUse({array}, {}); + NativeOpExecutioner::execRandom(context, random::TruncatedNormalDistribution, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); pm.synchronize(); + + NDArray::registerSpecialUse({array}, {}); } void RandomLauncher::fillBinomial(sd::LaunchContext *context, sd::graph::RandomGenerator& rng, NDArray* array, int trials, double prob) { ExtraArguments arguments({(double) trials, prob}); PointersManager pm(context, "fillBinomial"); + NDArray::prepareSpecialUse({array}, {}); + NativeOpExecutioner::execRandom(context, random::BinomialDistributionEx, &rng, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); pm.synchronize(); + + NDArray::registerSpecialUse({array}, {}); } } From 73aa760c0f91596452c27548183851606a1f504f Mon Sep 17 00:00:00 2001 From: Alex Black Date: Mon, 20 Apr 2020 10:26:00 +1000 Subject: [PATCH 12/18] Timeouts and scala 2.12 for deeplearning4j-nlp-korean workaround (#380) * Increase default timeout on Spark tests Signed-off-by: Alex Black * #8840 disable deeplearning4j-nlp-korean module for scala 2.12 Signed-off-by: Alex Black * Fix for change-scala-versions.sh Signed-off-by: Alex Black * CUDA test fixes + more timeout issues Signed-off-by: Alex Black * More CUDA Signed-off-by: Alex Black * Small fix for cuDNN subsampling + same mode Signed-off-by: Alex Black * Flaky test fix Signed-off-by: Alex Black * Reduce memory requirements for ValidateCuDNN BN test Signed-off-by: Alex Black * Fix slow/ineffirient scalnet tests Signed-off-by: Alex Black * Increase timeouts to avoid failures if CI machines are slower than expected Signed-off-by: Alex Black * Ignore flaky test (issue #8849) and increase timeout for slow CI downloads Signed-off-by: Alex Black --- change-scala-versions.sh | 10 +++++ .../java/org/deeplearning4j/TestUtils.java | 13 +++++- .../iterator/DataSetIteratorTest.java | 2 +- .../gradientcheck/AttentionLayerTest.java | 5 +++ .../gradientcheck/BNGradientCheckTest.java | 5 +++ .../gradientcheck/CNN1DGradientCheckTest.java | 5 +++ .../gradientcheck/CNN3DGradientCheckTest.java | 5 +++ .../gradientcheck/CNNGradientCheckTest.java | 5 +++ .../CapsnetGradientCheckTest.java | 5 +++ .../gradientcheck/DropoutGradientCheck.java | 5 +++ .../GlobalPoolingGradientCheckTests.java | 5 +++ .../gradientcheck/GradientCheckTests.java | 5 +++ .../GradientCheckTestsComputationGraph.java | 5 +++ .../gradientcheck/LRNGradientCheckTests.java | 5 +++ 
.../gradientcheck/LSTMGradientCheckTests.java | 5 +++ .../LossFunctionGradientCheck.java | 5 +++ .../NoBiasGradientCheckTests.java | 5 +++ .../OutputLayerGradientChecks.java | 5 +++ .../gradientcheck/RnnGradientChecks.java | 5 +++ .../UtilLayerGradientChecks.java | 5 +++ .../gradientcheck/VaeGradientCheckTests.java | 5 +++ .../gradientcheck/YoloGradientCheckTests.java | 5 +++ .../TransferLearningMLNTest.java | 15 ++++--- .../org/deeplearning4j/TestDataTypes.java | 44 +++++++++---------- .../org/deeplearning4j/ValidateCuDNN.java | 25 +++++++---- .../convolution/TestConvolution.java | 31 ++++++++----- .../gradientcheck/CNNGradientCheckTest.java | 5 +++ .../gradientcheck/CuDNNGradientChecks.java | 4 ++ .../lstm/ValidateCudnnDropout.java | 5 +++ .../lstm/ValidateCudnnLSTM.java | 7 ++- .../keras/e2e/KerasModelEndToEndTest.java | 2 +- .../subsampling/SubsamplingLayer.java | 4 +- .../SparkSequenceVectorsTest.java | 6 +++ .../models/word2vec/SparkWord2VecTest.java | 6 +++ .../spark/text/BaseSparkTest.java | 5 +++ .../spark/text/TextPipelineTest.java | 5 ++- .../spark/parameterserver/BaseSparkTest.java | 6 +++ .../spark/BaseSparkKryoTest.java | 5 +++ .../deeplearning4j/spark/BaseSparkTest.java | 5 +++ .../dl4j/feedforward/IrisCSVExample.scala | 6 ++- .../dl4j/recurrent/BasicRNNExample.scala | 2 +- 41 files changed, 247 insertions(+), 61 deletions(-) diff --git a/change-scala-versions.sh b/change-scala-versions.sh index 8968abbf3..aace1b05e 100755 --- a/change-scala-versions.sh +++ b/change-scala-versions.sh @@ -88,5 +88,15 @@ find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \ #Scala maven plugin, 2.11 find "$BASEDIR" -name 'pom.xml' -not -path '*target*' \ -exec bash -c "sed_i 's/\(scalaVersion>\)'$FROM_VERSION'<\/scalaVersion>/\1'$TO_VERSION'<\/scalaVersion>/g' {}" \; + +# Disable deeplearning4j-nlp-korean for scala 2.12 - see https://github.com/eclipse/deeplearning4j/issues/8840 +if [ $TO_VERSION = $SCALA_211_VERSION ]; then + #Enable + sed -i 's/ / deeplearning4j-nlp-korean<\/module>/g' deeplearning4j/deeplearning4j-nlp-parent/pom.xml +else + #Disable + sed -i 's/ deeplearning4j-nlp-korean<\/module>/ /g' deeplearning4j/deeplearning4j-nlp-parent/pom.xml +fi + echo "Done updating Scala versions."; diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/TestUtils.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/TestUtils.java index d90ce628b..d54693f73 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/TestUtils.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/TestUtils.java @@ -31,6 +31,7 @@ import org.deeplearning4j.nn.layers.recurrent.LSTM; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.util.ModelSerializer; import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; import org.nd4j.linalg.factory.Nd4j; @@ -124,12 +125,20 @@ public class TestUtils { return randomOneHot(examples, nOut, new Random(12345)); } + public static INDArray randomOneHot(DataType dataType, long examples, long nOut){ + return randomOneHot(dataType, examples, nOut, new Random(12345)); + } + public static INDArray randomOneHot(long examples, long nOut, long rngSeed){ return randomOneHot(examples, nOut, new Random(rngSeed)); } - public static INDArray randomOneHot(long examples, long nOut, Random rng){ - INDArray arr = Nd4j.create(examples, 
List<Class<?>> classesToTest = new ArrayList<>(); classesToTest.add(org.deeplearning4j.nn.layers.normalization.BatchNormalization.class); @@ -185,10 +191,11 @@ //Test ONLY LRN - no other CuDNN functionality (i.e., DL4J impls for everything else) Nd4j.getRandom().setSeed(12345); + int minibatch = 8; int numClasses = 10; //imageHeight,imageWidth,channels - int imageHeight = 240; - int imageWidth = 240; + int imageHeight = 48; + int imageWidth = 48; int channels = 3; IActivation activation = new ActivationIdentity(); MultiLayerConfiguration multiLayerConfiguration = new NeuralNetConfiguration.Builder() @@ -229,8 +236,8 @@ MultiLayerNetwork net = new MultiLayerNetwork(multiLayerConfiguration); net.init(); - int[] fShape = new int[]{32, channels, imageHeight, imageWidth}; - int[] lShape = new int[]{32, numClasses}; + int[] fShape = new int[]{minibatch, channels, imageHeight, imageWidth}; + int[] lShape = new int[]{minibatch, numClasses}; List<Class<?>> classesToTest = new ArrayList<>(); classesToTest.add(org.deeplearning4j.nn.layers.normalization.LocalResponseNormalization.class); diff --git a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/convolution/TestConvolution.java b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/convolution/TestConvolution.java index c89532ff5..67a2958b7 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/convolution/TestConvolution.java +++ b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/convolution/TestConvolution.java @@ -70,6 +70,11 @@ public class TestConvolution extends BaseDL4JTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); + @Override + public long getTimeoutMilliseconds() { + return 240000L; + } + @Test public void testSameModeActivationSizes() { int inH = 3; @@ -117,6 +122,8 @@ public class TestConvolution extends BaseDL4JTest { for (ConvolutionMode c : cm) { for (ConvolutionLayer.AlgoMode a : new ConvolutionLayer.AlgoMode[]{ConvolutionLayer.AlgoMode.NO_WORKSPACE, ConvolutionLayer.AlgoMode.PREFER_FASTEST}) { for (boolean conv : new boolean[]{true, false}) { + String msg = c + " - " + a + " - " + (conv ? "conv" : "subsampling"); + System.out.println(msg); org.deeplearning4j.nn.conf.layers.Layer l; if (conv) { @@ -125,7 +132,9 @@ l = new SubsamplingLayer.Builder().kernelSize(4, 4).stride(2, 2).build(); } - MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().seed(12345) + MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() + .dataType(DataType.DOUBLE) + .seed(12345) .l2(0.0005).updater(new Sgd(0.01)).weightInit(WeightInit.XAVIER).convolutionMode(c).cudnnAlgoMode(a).list() .layer(0, l) .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD) @@ -159,32 +168,32 @@ throw new RuntimeException(); - INDArray in = Nd4j.rand(new int[]{1, 1, 20, 20}); //(20-4+0)/2 +1 = 9 + INDArray in = Nd4j.rand(DataType.DOUBLE, new int[]{1, 1, 20, 20}); //(20-4+0)/2 +1 = 9 INDArray outCudnn = layerCudnn.activate(in, false, LayerWorkspaceMgr.noWorkspaces()); INDArray outStd = layerStandard.activate(in, false, LayerWorkspaceMgr.noWorkspaces()); - assertEquals(outStd, outCudnn); + assertEquals(msg, outStd, outCudnn); //Check backprop: - INDArray epsilon = Nd4j.rand(outStd.shape()); - Pair<Gradient, INDArray> pCudnn = layerCudnn.backpropGradient(epsilon, LayerWorkspaceMgr.noWorkspaces()); - Pair<Gradient, INDArray> pStd = layerStandard.backpropGradient(epsilon, LayerWorkspaceMgr.noWorkspaces()); + INDArray epsilon = Nd4j.rand(DataType.DOUBLE, outStd.shape()); + Pair<Gradient, INDArray> pCudnn = layerCudnn.backpropGradient(epsilon.dup(), LayerWorkspaceMgr.noWorkspaces()); + Pair<Gradient, INDArray> pStd = layerStandard.backpropGradient(epsilon.dup(), LayerWorkspaceMgr.noWorkspaces()); - System.out.println(Arrays.toString(pStd.getSecond().data().asFloat())); - System.out.println(Arrays.toString(pCudnn.getSecond().data().asFloat())); +// System.out.println(Arrays.toString(pStd.getSecond().data().asFloat())); +// System.out.println(Arrays.toString(pCudnn.getSecond().data().asFloat())); INDArray epsOutStd = pStd.getSecond(); INDArray epsOutCudnn = pCudnn.getSecond(); - assertTrue(epsOutStd.equalsWithEps(epsOutCudnn, 1e-4)); + assertTrue(msg, epsOutStd.equalsWithEps(epsOutCudnn, 1e-4)); if (conv) { INDArray gradStd = pStd.getFirst().gradient(); INDArray gradCudnn = pCudnn.getFirst().gradient(); - assertTrue(gradStd.equalsWithEps(gradCudnn, 1e-4)); + assertTrue(msg, gradStd.equalsWithEps(gradCudnn, 1e-4)); } } } @@ -192,7 +201,7 @@ } - @Test @Ignore //AB 2019/05/21 - Ignored to get master passing - issue logged here: https://github.com/deeplearning4j/deeplearning4j/issues/7766 + @Test public void validateXceptionImport() throws Exception { File dir = testDir.newFolder(); File fSource = Resources.asFile("modelimport/keras/examples/xception/xception_tf_keras_2.h5"); diff --git a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java index 1b8c42e14..eb06a70ae 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java +++ b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/gradientcheck/CNNGradientCheckTest.java @@ -61,6 +61,11 @@ public class CNNGradientCheckTest extends BaseDL4JTest { Nd4j.setDataType(DataType.DOUBLE); } + @Override + public long getTimeoutMilliseconds() { + return 180000L; + } + @Test public void testGradientCNNMLN() { //Parameterized
test, testing combinations of: diff --git a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/gradientcheck/CuDNNGradientChecks.java b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/gradientcheck/CuDNNGradientChecks.java index 9e43f042b..a2ab8236f 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/gradientcheck/CuDNNGradientChecks.java +++ b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/gradientcheck/CuDNNGradientChecks.java @@ -77,6 +77,10 @@ public class CuDNNGradientChecks extends BaseDL4JTest { DataTypeUtil.setDTypeForContext(DataType.DOUBLE); } + @Override + public long getTimeoutMilliseconds() { + return 180000L; + } @Test public void testConvolutional() throws Exception { diff --git a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/lstm/ValidateCudnnDropout.java b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/lstm/ValidateCudnnDropout.java index 3edf564b5..c46bb99d9 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/lstm/ValidateCudnnDropout.java +++ b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/lstm/ValidateCudnnDropout.java @@ -30,6 +30,11 @@ import static org.junit.Assert.assertTrue; public class ValidateCudnnDropout extends BaseDL4JTest { + @Override + public long getTimeoutMilliseconds() { + return 180000L; + } + @Test public void testCudnnDropoutSimple() { for (int[] shape : new int[][]{{10, 10}, {5, 2, 5, 2}}) { diff --git a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/lstm/ValidateCudnnLSTM.java b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/lstm/ValidateCudnnLSTM.java index 08b57aa65..6bbb934a5 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/lstm/ValidateCudnnLSTM.java +++ b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/lstm/ValidateCudnnLSTM.java @@ -46,6 +46,11 @@ import static org.junit.Assert.*; */ public class ValidateCudnnLSTM extends BaseDL4JTest { + @Override + public long getTimeoutMilliseconds() { + return 180000L; + } + @Test public void validateImplSimple() throws Exception { @@ -109,7 +114,7 @@ public class ValidateCudnnLSTM extends BaseDL4JTest { mln1.computeGradientAndScore(); mln2.computeGradientAndScore(); - assertEquals(mln1.score(), mln2.score(), 1e-8); + assertEquals(mln1.score(), mln2.score(), 1e-5); Gradient g1 = mln1.gradient(); Gradient g2 = mln2.gradient(); diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java index 7538d39bc..3e1efa365 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/e2e/KerasModelEndToEndTest.java @@ -98,7 +98,7 @@ public class KerasModelEndToEndTest extends BaseDL4JTest { @Override public long getTimeoutMilliseconds() { - return 90000L; + return 180000L; //Most benchmarks should run very quickly; large timeout is to avoid issues with unusually slow download of test resources } @Test(expected = IllegalStateException.class) diff --git 
a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/subsampling/SubsamplingLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/subsampling/SubsamplingLayer.java index b3a7a3c12..b38945e95 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/subsampling/SubsamplingLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/convolution/subsampling/SubsamplingLayer.java @@ -116,10 +116,10 @@ public class SubsamplingLayer extends AbstractLayer<org.deeplearning4j.nn.conf.layers.SubsamplingLayer> protected static List<Sequence<VocabWord>> sequencesCyclic; private JavaSparkContext sc; diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/word2vec/SparkWord2VecTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/word2vec/SparkWord2VecTest.java index f3b3f974a..a7bdfd45b 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/word2vec/SparkWord2VecTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp-java8/src/test/java/org/deeplearning4j/spark/models/word2vec/SparkWord2VecTest.java @@ -48,6 +48,12 @@ import static org.junit.Assert.*; * @author raver119@gmail.com */ public class SparkWord2VecTest extends BaseDL4JTest { + + @Override + public long getTimeoutMilliseconds() { + return 120000L; + } + private static List<String> sentences; private JavaSparkContext sc; diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/BaseSparkTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/BaseSparkTest.java index 475572edd..af39a474c 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/BaseSparkTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/BaseSparkTest.java @@ -34,6 +34,11 @@ import java.util.Map; public abstract class BaseSparkTest extends BaseDL4JTest implements Serializable { protected transient JavaSparkContext sc; + @Override + public long getTimeoutMilliseconds() { + return 120000L; + } + @Before public void before() throws Exception { sc = getContext();
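The recurring three-line hunk above is the core mechanism of this patch: rather than annotating each method with @Test(timeout = ...), a test class overrides a single getter and the shared base class applies it to every test. A plausible sketch of how such a base class can wire the getter into a JUnit 4 rule - illustrative only, not necessarily how BaseDL4JTest/BaseND4JTest actually implement it:

    import org.junit.Rule;
    import org.junit.rules.Timeout;
    import java.util.concurrent.TimeUnit;

    public abstract class BaseTestWithTimeout {
        // Default budget; slow Spark/CUDA suites override this instead of per-method timeouts
        public long getTimeoutMilliseconds() {
            return 90000L;
        }

        // JUnit evaluates the rule for every @Test method in the subclass
        @Rule
        public Timeout timeout = new Timeout(getTimeoutMilliseconds(), TimeUnit.MILLISECONDS);
    }

Overriding one method also makes it easy to raise all timeouts for a whole module at once, which is exactly what the Spark and CUDA hunks in this patch do.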
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TextPipelineTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TextPipelineTest.java index 63c84de7d..9e5ad1d67 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TextPipelineTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TextPipelineTest.java @@ -31,6 +31,7 @@ import org.deeplearning4j.spark.text.functions.CountCumSum; import org.deeplearning4j.spark.text.functions.TextPipeline; import org.deeplearning4j.text.stopwords.StopWords; import org.junit.Before; +import org.junit.Ignore; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.primitives.Counter; @@ -350,7 +351,7 @@ public class TextPipelineTest extends BaseSparkTest { * * @throws Exception */ - @Test + @Test @Ignore //AB 2020/04/19 https://github.com/eclipse/deeplearning4j/issues/8849 public void testZipFunction1() throws Exception { JavaSparkContext sc = getContext(); JavaRDD<String> corpusRDD = getCorpusRDD(sc); @@ -388,7 +389,7 @@ public class TextPipelineTest extends BaseSparkTest { sc.stop(); } - @Test + @Test @Ignore //AB 2020/04/19 https://github.com/eclipse/deeplearning4j/issues/8849 public void testZipFunction2() throws Exception { JavaSparkContext sc = getContext(); JavaRDD<String> corpusRDD = getCorpusRDD(sc); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/BaseSparkTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/BaseSparkTest.java index 50aa564c1..9a28fe351 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/BaseSparkTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/BaseSparkTest.java @@ -53,6 +53,12 @@ public abstract class BaseSparkTest extends BaseDL4JTest implements Serializable protected transient DataSet data; protected transient JavaRDD<DataSet> sparkData; + + @Override + public long getTimeoutMilliseconds() { + return 120000L; + } + @Before public void before() { diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkKryoTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkKryoTest.java index 42fa57a37..f7b4da172 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkKryoTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkKryoTest.java @@ -28,6 +28,11 @@ import java.util.Map; */ public class BaseSparkKryoTest extends BaseSparkTest { + @Override + public long getTimeoutMilliseconds() { + return 120000L; + } + @Override public JavaSparkContext getContext() { if (sc != null) { diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java index 3d1a9755a..be78ec7cd 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/test/java/org/deeplearning4j/spark/BaseSparkTest.java @@ -55,6 +55,11 @@ public abstract class BaseSparkTest extends BaseDL4JTest implements Serializable protected transient DataSet data; protected transient JavaRDD<DataSet> sparkData; + @Override + public long getTimeoutMilliseconds() { + return 120000L; + } + @Before public void before() { diff --git a/scalnet/src/test/scala/org/deeplearning4j/scalnet/examples/dl4j/feedforward/IrisCSVExample.scala b/scalnet/src/test/scala/org/deeplearning4j/scalnet/examples/dl4j/feedforward/IrisCSVExample.scala index 3363c8a29..5f6670364 100644 --- a/scalnet/src/test/scala/org/deeplearning4j/scalnet/examples/dl4j/feedforward/IrisCSVExample.scala +++ b/scalnet/src/test/scala/org/deeplearning4j/scalnet/examples/dl4j/feedforward/IrisCSVExample.scala @@ -21,6 +21,7 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader import org.datavec.api.split.FileSplit import
org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator +import org.deeplearning4j.nn.conf.Updater import org.deeplearning4j.optimize.listeners.ScoreIterationListener import org.deeplearning4j.scalnet.layers.core.Dense import org.deeplearning4j.scalnet.logging.Logging @@ -29,6 +30,7 @@ import org.nd4j.linalg.activations.Activation import org.nd4j.linalg.dataset.api.iterator.DataSetIterator import org.nd4j.linalg.dataset.{ DataSet, SplitTestAndTrain } import org.nd4j.linalg.io.ClassPathResource +import org.nd4j.linalg.learning.config.Adam import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction object IrisCSVExample extends App with Logging { @@ -41,7 +43,7 @@ object IrisCSVExample extends App with Logging { val hiddenSize = 128 val inputSize = 4 val outputSize = 3 - val epochs = 1000 + val epochs = 20 val scoreFrequency = 5 val seed = 1234 @@ -64,7 +66,7 @@ object IrisCSVExample extends App with Logging { model.add(Dense(nOut = hiddenSize, activation = Activation.RELU)) model.add(Dense(nOut = hiddenSize, activation = Activation.RELU)) model.add(Dense(outputSize, activation = Activation.SOFTMAX)) - model.compile(LossFunction.MCXENT) + model.compile(LossFunction.MCXENT, updater = Updater.ADAM) logger.info("Train model...") model.fit(training_data, epochs, List(new ScoreIterationListener(scoreFrequency))) diff --git a/scalnet/src/test/scala/org/deeplearning4j/scalnet/examples/dl4j/recurrent/BasicRNNExample.scala b/scalnet/src/test/scala/org/deeplearning4j/scalnet/examples/dl4j/recurrent/BasicRNNExample.scala index 605ce1e56..dc2b418f1 100644 --- a/scalnet/src/test/scala/org/deeplearning4j/scalnet/examples/dl4j/recurrent/BasicRNNExample.scala +++ b/scalnet/src/test/scala/org/deeplearning4j/scalnet/examples/dl4j/recurrent/BasicRNNExample.scala @@ -34,7 +34,7 @@ object BasicRNNExample extends App with Logging { val learningString = "*Der Cottbuser Postkutscher putzt den Cottbuser Postkutschkasten.".toVector val learningChars = learningString.distinct val hiddenSize = 64 - val epochs = 200 + val epochs = 20 val seed = 1234 val rand = new Random(seed) From 455a5d112d2189f08f038752cf15b47ab8001f84 Mon Sep 17 00:00:00 2001 From: Alexander Stoyakin Date: Mon, 20 Apr 2020 03:27:13 +0300 Subject: [PATCH 13/18] Fixes for codegen generated classes and build improvements (#367) * Input format extended * Deleted redundant code * Added weights format to conv2d config * Refactoring * dl4j base test functionality * Different tests base class per module * Check base class for dl4j-graph subproject tests * Check if test classes extend BaseDL4JTest * Use nd4j-common-tests as transient dependency * Enums and tests added * Added codegenerated methods * Use namespace methods * Replace DifferentialFunctionFactory with codegenerated classes * Fixed linspace * Namespaces regenerated * Namespaces used instead of factory * Regenerated base classes * Input format extended * Added weights format to conv2d config * Refactoring * dl4j base test functionality * Different tests base class per module * Check base class for dl4j-graph subproject tests * Check if test classes extend BaseDL4JTest * Use nd4j-common-tests as transient dependency * Enums and tests added * Added codegenerated methods * Use namespace methods * Replace DifferentialFunctionFactory with codegenerated classes * Fixed linspace * Namespaces regenerated * Regenerated base classes * Regenerated namespaces * Generate nd4j namespaces * INDArrays accepting constructors * 
Generated some ops * Some fixes * SameDiff ops regenerated * Regenerated nd4j ops * externalErrors moved * Compilation fixes * SquaredDifference - strict number of args * Deprecated code cleanup. Proper base class for tests. * Extend test classes with BaseND4JTest * Extend test classes with BaseDL4JTest * Legacy code * DL4J cleanup * Exclude test utils from base class check * Tests fixed * Arbiter tests fix * Test dependency scope fix + pom.xml formatting Signed-off-by: Alex Black * Significant number of fixes Signed-off-by: Alex Black * Another round of fixes Signed-off-by: Alex Black * Another round of fixes Signed-off-by: Alex Black * Few additional fixes Signed-off-by: Alex Black * DataVec missing test scope dependencies Signed-off-by: Alex Black Co-authored-by: Alex Black --- arbiter/arbiter-core/pom.xml | 11 +- .../optimize/AssertTestsExtendBaseClass.java | 49 + .../arbiter/AssertTestsExtendBaseClass.java | 50 + .../server/AssertTestsExtendBaseClass.java | 50 + .../server/MnistDataSetIteratorFactory.java | 3 +- .../server/TestDataFactoryProviderMnist.java | 3 +- arbiter/arbiter-ui/pom.xml | 7 + .../optimize/AssertTestsExtendBaseClass.java | 50 + .../arbiter/optimize/TestBasic.java | 3 +- datavec/datavec-api/pom.xml | 11 +- .../api/util/ndarray/RecordConverter.java | 31 +- .../api/AssertTestsExtendBaseClass.java | 57 + .../impl/CSVLineSequenceRecordReaderTest.java | 3 +- .../CSVMultiSequenceRecordReaderTest.java | 4 +- .../CSVNLinesSequenceRecordReaderTest.java | 3 +- .../reader/impl/CSVRecordReaderTest.java | 3 +- .../impl/CSVSequenceRecordReaderTest.java | 3 +- ...VariableSlidingWindowRecordReaderTest.java | 3 +- .../impl/FileBatchRecordReaderTest.java | 3 +- .../reader/impl/FileRecordReaderTest.java | 3 +- .../impl/JacksonLineRecordReaderTest.java | 3 +- .../reader/impl/JacksonRecordReaderTest.java | 3 +- .../reader/impl/LibSvmRecordReaderTest.java | 3 +- .../records/reader/impl/LineReaderTest.java | 3 +- .../reader/impl/RegexRecordReaderTest.java | 3 +- .../reader/impl/SVMLightRecordReaderTest.java | 3 +- .../impl/TestCollectionRecordReaders.java | 3 +- .../impl/TestConcatenatingRecordReader.java | 3 +- .../reader/impl/TestSerialization.java | 3 +- .../TransformProcessRecordReaderTests.java | 3 +- .../writer/impl/CSVRecordWriterTest.java | 3 +- .../writer/impl/LibSvmRecordWriterTest.java | 3 +- .../writer/impl/SVMLightRecordWriterTest.java | 3 +- .../datavec/api/split/InputSplitTests.java | 3 +- .../split/NumberedFileInputSplitTests.java | 3 +- .../api/split/TestStreamInputSplit.java | 3 +- .../datavec/api/split/TransformSplitTest.java | 3 +- .../api/split/parittion/PartitionerTests.java | 3 +- .../api/transform/TestTransformProcess.java | 3 +- .../transform/condition/TestConditions.java | 3 +- .../api/transform/filter/TestFilters.java | 3 +- .../datavec/api/transform/join/TestJoin.java | 3 +- .../transform/ops/AggregableMultiOpTest.java | 3 +- .../transform/ops/AggregatorImplsTest.java | 3 +- .../api/transform/ops/DispatchOpTest.java | 3 +- .../transform/reduce/TestMultiOpReduce.java | 3 +- .../api/transform/reduce/TestReductions.java | 3 +- .../api/transform/schema/TestJsonYaml.java | 3 +- .../transform/schema/TestSchemaMethods.java | 3 +- .../TestReduceSequenceByWindowFunction.java | 3 +- .../transform/sequence/TestSequenceSplit.java | 3 +- .../sequence/TestWindowFunctions.java | 3 +- .../serde/TestCustomTransformJsonYaml.java | 3 +- .../transform/serde/TestYamlJsonSerde.java | 3 +- .../transform/stringreduce/TestReduce.java | 3 +- .../transform/RegressionTestJson.java | 3 +- 
.../api/transform/transform/TestJsonYaml.java | 3 +- .../transform/transform/TestTransforms.java | 3 +- .../TestNDArrayWritableTransforms.java | 3 +- .../transform/ndarray/TestYamlJsonSerde.java | 3 +- .../parse/ParseDoubleTransformTest.java | 3 +- .../org/datavec/api/transform/ui/TestUI.java | 3 +- .../api/util/ClassPathResourceTest.java | 3 +- .../datavec/api/util/TimeSeriesUtilsTest.java | 3 +- .../api/writable/RecordConverterTest.java | 7 +- .../TestNDArrayWritableAndSerialization.java | 3 +- .../datavec/api/writable/WritableTest.java | 5 +- datavec/datavec-arrow/pom.xml | 6 + .../org/datavec/arrow/ArrowConverterTest.java | 3 +- .../arrow/AssertTestsExtendBaseClass.java | 50 + .../org/datavec/arrow/RecordMapperTest.java | 3 +- ...rowWritableRecordTimeSeriesBatchTests.java | 3 +- .../datavec-data/datavec-data-audio/pom.xml | 7 + .../audio/AssertTestsExtendBaseClass.java | 55 + .../org/datavec/audio/AudioReaderTest.java | 3 +- .../audio/TestFastFourierTransform.java | 3 +- .../datavec-data/datavec-data-codec/pom.xml | 7 + .../reader/AssertTestsExtendBaseClass.java | 50 + .../datavec-data/datavec-data-image/pom.xml | 7 + .../image/AssertTestsExtendBaseClass.java | 50 + datavec/datavec-data/datavec-data-nlp/pom.xml | 7 + .../nlp/AssertTestsExtendBaseClass.java | 50 + datavec/datavec-data/datavec-geo/pom.xml | 9 +- .../transform/AssertTestsExtendBaseClass.java | 49 + datavec/datavec-data/datavec-hadoop/pom.xml | 7 + .../hadoop/AssertTestsExtendBaseClass.java | 48 + datavec/datavec-excel/pom.xml | 7 + .../poi/excel/AssertTestsExtendBaseClass.java | 50 + datavec/datavec-jdbc/pom.xml | 7 + .../reader/AssertTestsExtendBaseClass.java | 49 + datavec/datavec-local/pom.xml | 7 + .../AssertTestsExtendBaseClass.java | 50 + .../transforms/analysis/TestAnalyzeLocal.java | 3 +- datavec/datavec-python/pom.xml | 7 + .../python/AssertTestsExtendBaseClass.java | 50 + .../datavec-spark-inference-client/pom.xml | 7 + .../client/AssertTestsExtendBaseClass.java | 49 + .../datavec-spark-inference-model/pom.xml | 7 + .../spark/transform/CSVSparkTransform.java | 3 +- .../transform/AssertTestsExtendBaseClass.java | 50 + .../datavec-spark-inference-server/pom.xml | 7 + .../transform/AssertTestsExtendBaseClass.java | 50 + datavec/datavec-spark/pom.xml | 6 + .../spark/AssertTestsExtendBaseClass.java | 50 + .../spark/transform/NormalizationTests.java | 6 +- .../deeplearning4j-common-tests/pom.xml | 5 + deeplearning4j/deeplearning4j-core/pom.xml | 14 - .../AssertTestsExtendBaseClass.java | 54 +- .../CompareTrainingImplementations.java | 2 +- .../RecordReaderMultiDataSetIterator.java | 2 +- deeplearning4j/deeplearning4j-graph/pom.xml | 1 - .../graph/AssertTestsExtendedBaseClass.java | 49 + .../tokenizer/AssertTestsExtendBaseClass.java | 52 + .../test/java/AssertTestsExtendBaseClass.java | 53 + .../test/java/AssertTestsExtendBaseClass.java | 49 + .../AssertTestsExtendBaseClass.java | 49 + .../deeplearning4j-nlp/pom.xml | 125 +- .../AssertTestsExtendBaseClass.java | 49 + .../nn/conf/layers/LocallyConnected1D.java | 5 +- .../nn/conf/layers/LocallyConnected2D.java | 5 +- .../layers/samediff/SameDiffGraphVertex.java | 3 +- .../nn/layers/samediff/SameDiffLayer.java | 3 +- .../remote/AssertTestsExtendBaseClass.java | 48 + .../functions/DifferentialFunction.java | 15 +- .../DifferentialFunctionFactory.java | 2659 ----------------- .../nd4j/autodiff/samediff/SDVariable.java | 59 +- .../org/nd4j/autodiff/samediff/SameDiff.java | 45 +- .../samediff/config/OutputConfig.java | 4 +- .../nd4j/autodiff/samediff/ops/SDBaseOps.java | 110 
+- .../nd4j/autodiff/samediff/ops/SDImage.java | 8 +- .../nd4j/autodiff/samediff/ops/SDMath.java | 674 ++++- .../org/nd4j/autodiff/samediff/ops/SDNN.java | 66 +- .../org/nd4j/autodiff/samediff/ops/SDOps.java | 27 +- .../samediff/transform/OpPredicate.java | 1 - .../org/nd4j/autodiff/util/SameDiffUtils.java | 139 + .../org/nd4j/autodiff/util/TrainingUtils.java | 70 - .../src/main/java/org/nd4j/enums/PadMode.java | 29 + .../java/org/nd4j/enums/WeightsFormat.java | 29 + .../linalg/api/ops/BaseBroadcastBoolOp.java | 8 +- .../nd4j/linalg/api/ops/BaseBroadcastOp.java | 9 +- .../linalg/api/ops/BaseIndexAccumulation.java | 6 +- .../org/nd4j/linalg/api/ops/BaseReduceOp.java | 7 +- .../nd4j/linalg/api/ops/BaseScalarBoolOp.java | 3 +- .../org/nd4j/linalg/api/ops/BaseScalarOp.java | 3 +- .../nd4j/linalg/api/ops/BaseTransformOp.java | 11 +- .../java/org/nd4j/linalg/api/ops/NoOp.java | 5 + .../api/ops/impl/broadcast/BiasAdd.java | 2 +- .../impl/controlflow/compat/BaseCompatOp.java | 5 + .../ops/impl/controlflow/compat/Merge.java | 5 + .../impl/controlflow/compat/StopGradient.java | 2 +- .../ops/impl/controlflow/compat/Switch.java | 5 + .../linalg/api/ops/impl/indexaccum/IAMax.java | 2 +- .../linalg/api/ops/impl/indexaccum/IAMin.java | 2 +- .../linalg/api/ops/impl/indexaccum/IMax.java | 2 +- .../linalg/api/ops/impl/indexaccum/IMin.java | 2 +- .../ops/impl/layers/convolution/Conv2D.java | 6 +- .../ops/impl/layers/convolution/DeConv3D.java | 3 +- .../ops/impl/layers/convolution/Im2col.java | 2 +- .../impl/layers/convolution/Upsampling2d.java | 2 +- .../convolution/config/Conv2DConfig.java | 6 +- .../ops/impl/loss/AbsoluteDifferenceLoss.java | 4 +- .../api/ops/impl/loss/CosineDistanceLoss.java | 4 +- .../linalg/api/ops/impl/loss/HingeLoss.java | 4 +- .../linalg/api/ops/impl/loss/HuberLoss.java | 4 +- .../nd4j/linalg/api/ops/impl/loss/L2Loss.java | 2 +- .../linalg/api/ops/impl/loss/LogLoss.java | 4 +- .../api/ops/impl/loss/LogPoissonLoss.java | 10 +- .../loss/MeanPairwiseSquaredErrorLoss.java | 4 +- .../ops/impl/loss/MeanSquaredErrorLoss.java | 4 +- .../impl/loss/SigmoidCrossEntropyLoss.java | 4 +- .../impl/loss/SoftmaxCrossEntropyLoss.java | 4 +- .../SoftmaxCrossEntropyWithLogitsLoss.java | 5 +- ...arseSoftmaxCrossEntropyLossWithLogits.java | 6 +- .../nd4j/linalg/api/ops/impl/reduce/Mmul.java | 2 +- .../linalg/api/ops/impl/reduce/Moments.java | 6 +- .../api/ops/impl/reduce/TensorMmul.java | 18 +- .../linalg/api/ops/impl/reduce/bool/All.java | 2 +- .../linalg/api/ops/impl/reduce/bool/Any.java | 2 +- .../api/ops/impl/reduce/bool/IsInf.java | 2 +- .../api/ops/impl/reduce/bool/IsNaN.java | 2 +- .../api/ops/impl/reduce/custom/LogSumExp.java | 5 +- .../api/ops/impl/reduce/floating/AMean.java | 3 +- .../api/ops/impl/reduce/floating/Entropy.java | 11 +- .../ops/impl/reduce/floating/LogEntropy.java | 4 +- .../api/ops/impl/reduce/floating/Mean.java | 3 +- .../api/ops/impl/reduce/floating/Norm1.java | 3 +- .../api/ops/impl/reduce/floating/Norm2.java | 3 +- .../api/ops/impl/reduce/floating/NormMax.java | 3 +- .../impl/reduce/floating/ShannonEntropy.java | 7 +- .../ops/impl/reduce/floating/SquaredNorm.java | 3 +- .../ops/impl/reduce/longer/CountNonZero.java | 2 +- .../api/ops/impl/reduce/longer/CountZero.java | 2 +- .../linalg/api/ops/impl/reduce/same/AMax.java | 3 +- .../linalg/api/ops/impl/reduce/same/AMin.java | 3 +- .../linalg/api/ops/impl/reduce/same/ASum.java | 3 +- .../linalg/api/ops/impl/reduce/same/Max.java | 3 +- .../linalg/api/ops/impl/reduce/same/Min.java | 3 +- .../linalg/api/ops/impl/reduce/same/Prod.java 
| 3 +- .../linalg/api/ops/impl/reduce/same/Sum.java | 3 +- .../api/ops/impl/reduce3/CosineDistance.java | 4 +- .../ops/impl/reduce3/CosineSimilarity.java | 12 +- .../nd4j/linalg/api/ops/impl/reduce3/Dot.java | 3 +- .../ops/impl/reduce3/EuclideanDistance.java | 5 +- .../api/ops/impl/reduce3/JaccardDistance.java | 17 +- .../ops/impl/reduce3/ManhattanDistance.java | 5 +- .../linalg/api/ops/impl/scalar/LeakyReLU.java | 3 +- .../linalg/api/ops/impl/scalar/PRelu.java | 3 +- .../nd4j/linalg/api/ops/impl/scalar/Pow.java | 5 +- .../api/ops/impl/scalar/RectifiedLinear.java | 3 +- .../linalg/api/ops/impl/scalar/Relu6.java | 3 +- .../linalg/api/ops/impl/scalar/ScalarAdd.java | 6 +- .../impl/scalar/ScalarReverseDivision.java | 5 +- .../impl/scalar/ScalarReverseSubtraction.java | 5 +- .../linalg/api/ops/impl/scalar/ScalarSet.java | 2 +- .../nd4j/linalg/api/ops/impl/scalar/Step.java | 2 +- .../api/ops/impl/scatter/ScatterAdd.java | 4 +- .../api/ops/impl/scatter/ScatterDiv.java | 10 +- .../api/ops/impl/scatter/ScatterMax.java | 6 +- .../api/ops/impl/scatter/ScatterMin.java | 6 +- .../api/ops/impl/scatter/ScatterMul.java | 8 +- .../api/ops/impl/scatter/ScatterSub.java | 4 +- .../api/ops/impl/scatter/ScatterUpdate.java | 8 +- .../linalg/api/ops/impl/shape/Linspace.java | 2 +- .../linalg/api/ops/impl/shape/Permute.java | 4 +- .../linalg/api/ops/impl/shape/Reshape.java | 4 +- .../api/ops/impl/shape/SequenceMask.java | 2 +- .../linalg/api/ops/impl/shape/ShapeN.java | 2 +- .../nd4j/linalg/api/ops/impl/shape/Slice.java | 5 +- .../nd4j/linalg/api/ops/impl/shape/Stack.java | 2 +- .../api/ops/impl/shape/StridedSlice.java | 9 +- .../nd4j/linalg/api/ops/impl/shape/Tile.java | 5 +- .../impl/summarystats/StandardDeviation.java | 3 +- .../api/ops/impl/summarystats/Variance.java | 3 +- .../linalg/api/ops/impl/transforms/Angle.java | 2 +- .../linalg/api/ops/impl/transforms/Pad.java | 27 + .../api/ops/impl/transforms/any/IsMax.java | 2 +- .../ops/impl/transforms/bool/BooleanNot.java | 2 +- .../ops/impl/transforms/bool/IsFinite.java | 2 +- .../api/ops/impl/transforms/bool/IsInf.java | 2 +- .../api/ops/impl/transforms/bool/IsNaN.java | 2 +- .../ops/impl/transforms/clip/ClipByNorm.java | 2 +- .../ops/impl/transforms/clip/ClipByValue.java | 4 +- .../api/ops/impl/transforms/custom/ATan2.java | 12 +- .../ops/impl/transforms/custom/Assign.java | 2 +- .../ops/impl/transforms/custom/CumProd.java | 3 +- .../ops/impl/transforms/custom/CumSum.java | 3 +- .../custom/DotProductAttention.java | 3 +- .../transforms/custom/DynamicPartition.java | 3 +- .../impl/transforms/custom/DynamicStitch.java | 2 +- .../transforms/custom/InvertPermutation.java | 5 +- .../ops/impl/transforms/custom/LayerNorm.java | 10 +- .../impl/transforms/custom/LogSoftMax.java | 7 +- .../impl/transforms/custom/MatrixSetDiag.java | 4 +- .../custom/MultiHeadDotProductAttention.java | 2 +- .../api/ops/impl/transforms/custom/Pow.java | 4 +- .../ops/impl/transforms/custom/Reverse.java | 4 +- .../transforms/custom/ReverseSequence.java | 4 +- .../ops/impl/transforms/custom/SoftMax.java | 4 +- .../impl/transforms/custom/Standardize.java | 3 +- .../impl/transforms/custom/ThresholdRelu.java | 3 +- .../api/ops/impl/transforms/custom/Trace.java | 10 +- .../transforms/custom/segment/SegmentMax.java | 3 +- .../custom/segment/SegmentMean.java | 3 +- .../transforms/custom/segment/SegmentMin.java | 3 +- .../custom/segment/SegmentProd.java | 3 +- .../transforms/custom/segment/SegmentSum.java | 3 +- .../api/ops/impl/transforms/dtype/Cast.java | 2 +- 
.../ops/impl/transforms/floating/RSqrt.java | 2 +- .../transforms/gradient/SELUDerivative.java | 10 +- .../transforms/gradient/TanhDerivative.java | 2 +- .../transforms/pairwise/arithmetic/AddOp.java | 12 +- .../transforms/pairwise/arithmetic/DivOp.java | 13 +- .../pairwise/arithmetic/FModOp.java | 3 +- .../pairwise/arithmetic/FloorDivOp.java | 8 +- .../pairwise/arithmetic/FloorModOp.java | 8 +- .../pairwise/arithmetic/MergeAddOp.java | 2 +- .../transforms/pairwise/arithmetic/ModOp.java | 15 +- .../transforms/pairwise/arithmetic/MulOp.java | 14 +- .../pairwise/arithmetic/RDivOp.java | 13 +- .../pairwise/arithmetic/RSubOp.java | 17 +- .../pairwise/arithmetic/RealDivOp.java | 3 +- .../arithmetic/SquaredDifferenceOp.java | 18 +- .../transforms/pairwise/arithmetic/SubOp.java | 14 +- .../pairwise/arithmetic/TruncateDivOp.java | 4 +- .../impl/transforms/pairwise/bool/Not.java | 2 +- .../api/ops/impl/transforms/same/AMax.java | 3 +- .../api/ops/impl/transforms/same/AMin.java | 3 +- .../api/ops/impl/transforms/same/Abs.java | 2 +- .../api/ops/impl/transforms/same/Ceil.java | 2 +- .../api/ops/impl/transforms/same/Cube.java | 3 +- .../api/ops/impl/transforms/same/Max.java | 6 +- .../api/ops/impl/transforms/same/Min.java | 9 +- .../ops/impl/transforms/same/Negative.java | 2 +- .../ops/impl/transforms/same/Reciprocal.java | 2 +- .../api/ops/impl/transforms/same/Round.java | 2 +- .../api/ops/impl/transforms/same/Square.java | 6 +- .../segment/UnsortedSegmentMax.java | 3 +- .../segment/UnsortedSegmentMean.java | 3 +- .../segment/UnsortedSegmentMin.java | 3 +- .../segment/UnsortedSegmentProd.java | 3 +- .../segment/UnsortedSegmentSqrtN.java | 3 +- .../segment/UnsortedSegmentSum.java | 3 +- .../api/ops/impl/transforms/strict/ACos.java | 7 +- .../api/ops/impl/transforms/strict/ASinh.java | 4 +- .../api/ops/impl/transforms/strict/ATan.java | 5 +- .../api/ops/impl/transforms/strict/Cos.java | 2 +- .../api/ops/impl/transforms/strict/Cosh.java | 2 +- .../api/ops/impl/transforms/strict/ELU.java | 3 +- .../api/ops/impl/transforms/strict/Exp.java | 2 +- .../api/ops/impl/transforms/strict/Expm1.java | 2 +- .../api/ops/impl/transforms/strict/GELU.java | 2 +- .../impl/transforms/strict/HardSigmoid.java | 3 +- .../ops/impl/transforms/strict/HardTanh.java | 3 +- .../api/ops/impl/transforms/strict/Log.java | 6 +- .../api/ops/impl/transforms/strict/Log1p.java | 3 +- .../impl/transforms/strict/LogSigmoid.java | 7 +- .../api/ops/impl/transforms/strict/Mish.java | 5 +- .../impl/transforms/strict/PreciseGELU.java | 6 +- .../impl/transforms/strict/RationalTanh.java | 7 +- .../impl/transforms/strict/RectifiedTanh.java | 7 +- .../api/ops/impl/transforms/strict/SELU.java | 3 +- .../ops/impl/transforms/strict/Sigmoid.java | 4 +- .../api/ops/impl/transforms/strict/Sin.java | 5 +- .../api/ops/impl/transforms/strict/Sinh.java | 5 +- .../ops/impl/transforms/strict/SoftPlus.java | 5 +- .../ops/impl/transforms/strict/SoftSign.java | 7 +- .../api/ops/impl/transforms/strict/Swish.java | 2 +- .../transforms/strict/SwishDerivative.java | 4 +- .../api/ops/impl/transforms/strict/Tan.java | 2 +- .../api/ops/impl/transforms/strict/Tanh.java | 4 +- .../org/nd4j/linalg/factory/ops/NDBase.java | 53 +- .../org/nd4j/linalg/factory/ops/NDImage.java | 4 +- .../org/nd4j/linalg/factory/ops/NDMath.java | 316 +- .../org/nd4j/linalg/factory/ops/NDNN.java | 31 +- .../nd4j-tests-tensorflow/pom.xml | 13 +- nd4j/nd4j-backends/nd4j-tests/pom.xml | 23 +- .../org/nd4j/AssertTestsExtendBaseClass.java | 68 +- .../java/org/nd4j/autodiff/TestSessions.java | 6 +- 
.../opvalidation/MiscOpValidation.java | 22 +- .../opvalidation/ReductionOpValidation.java | 4 +- .../opvalidation/ShapeOpValidation.java | 14 +- .../opvalidation/TransformOpValidation.java | 25 +- .../samediff/FlatBufferSerdeTest.java | 4 +- .../autodiff/samediff/NameScopeTests.java | 2 +- .../nd4j/autodiff/samediff/SameDiffTests.java | 119 +- .../samediff/SameDiffTrainingTest.java | 6 +- .../test/java/org/nd4j/linalg/Nd4jTestsC.java | 66 + nd4j/nd4j-common-tests/pom.xml | 16 + .../org/nd4j/AbstractAssertTestsClass.java | 82 + nd4j/nd4j-jdbc/nd4j-jdbc-mysql/pom.xml | 9 +- .../pom.xml | 6 +- .../nd4j-parameter-server-status/pom.xml | 6 +- nd4j/nd4j-remote/nd4j-json-server/pom.xml | 342 +-- nd4j/nd4j-serde/nd4j-aeron/pom.xml | 329 +- nd4j/nd4j-serde/nd4j-arrow/pom.xml | 19 +- .../nd4j-camel-routes/nd4j-kafka/pom.xml | 3 +- nd4j/nd4j-serde/nd4j-gson/pom.xml | 28 +- nd4j/nd4j-serde/nd4j-kryo/pom.xml | 36 +- 358 files changed, 4531 insertions(+), 3919 deletions(-) create mode 100644 arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/AssertTestsExtendBaseClass.java create mode 100644 arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/AssertTestsExtendBaseClass.java create mode 100644 arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/AssertTestsExtendBaseClass.java create mode 100644 arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/AssertTestsExtendBaseClass.java create mode 100644 datavec/datavec-api/src/test/java/org/datavec/api/AssertTestsExtendBaseClass.java create mode 100644 datavec/datavec-arrow/src/test/java/org/datavec/arrow/AssertTestsExtendBaseClass.java create mode 100644 datavec/datavec-data/datavec-data-audio/src/test/java/org/datavec/audio/AssertTestsExtendBaseClass.java create mode 100644 datavec/datavec-data/datavec-data-codec/src/test/java/org/datavec/codec/reader/AssertTestsExtendBaseClass.java create mode 100644 datavec/datavec-data/datavec-data-image/src/test/java/org/datavec/image/AssertTestsExtendBaseClass.java create mode 100644 datavec/datavec-data/datavec-data-nlp/src/test/java/org/datavec/nlp/AssertTestsExtendBaseClass.java create mode 100644 datavec/datavec-data/datavec-geo/src/test/java/org/datavec/api/transform/AssertTestsExtendBaseClass.java create mode 100644 datavec/datavec-data/datavec-hadoop/src/test/java/org/datavec/hadoop/AssertTestsExtendBaseClass.java create mode 100644 datavec/datavec-excel/src/test/java/org/datavec/poi/excel/AssertTestsExtendBaseClass.java create mode 100644 datavec/datavec-jdbc/src/test/java/org/datavec/api/records/reader/AssertTestsExtendBaseClass.java create mode 100644 datavec/datavec-local/src/test/java/org/datavec/local/transforms/AssertTestsExtendBaseClass.java create mode 100644 datavec/datavec-python/src/test/java/org/datavec/python/AssertTestsExtendBaseClass.java create mode 100644 datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/test/java/org/datavec/transform/client/AssertTestsExtendBaseClass.java create mode 100644 datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/AssertTestsExtendBaseClass.java create mode 100644 datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/AssertTestsExtendBaseClass.java create mode 100644 datavec/datavec-spark/src/test/java/org/datavec/spark/AssertTestsExtendBaseClass.java create mode 100644 
deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/AssertTestsExtendedBaseClass.java create mode 100644 deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/AssertTestsExtendBaseClass.java create mode 100644 deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/AssertTestsExtendBaseClass.java create mode 100644 deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/src/test/java/AssertTestsExtendBaseClass.java create mode 100644 deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/AssertTestsExtendBaseClass.java create mode 100644 deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/AssertTestsExtendBaseClass.java create mode 100644 deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/AssertTestsExtendBaseClass.java delete mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/util/SameDiffUtils.java delete mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/util/TrainingUtils.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/PadMode.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/WeightsFormat.java create mode 100644 nd4j/nd4j-common-tests/src/main/java/org/nd4j/AbstractAssertTestsClass.java diff --git a/arbiter/arbiter-core/pom.xml b/arbiter/arbiter-core/pom.xml index 6ce3c9c1f..76251f4cd 100644 --- a/arbiter/arbiter-core/pom.xml +++ b/arbiter/arbiter-core/pom.xml @@ -14,7 +14,8 @@ ~ SPDX-License-Identifier: Apache-2.0 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~--> - + arbiter org.deeplearning4j @@ -33,10 +34,10 @@ nd4j-api ${nd4j.version} - - com.google.code.findbugs - * - + + com.google.code.findbugs + * + diff --git a/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/AssertTestsExtendBaseClass.java b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..75a64d05f --- /dev/null +++ b/arbiter/arbiter-core/src/test/java/org/deeplearning4j/arbiter/optimize/AssertTestsExtendBaseClass.java @@ -0,0 +1,49 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.arbiter.optimize; + +import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.BaseDL4JTest; +import org.nd4j.AbstractAssertTestsClass; +import java.util.*; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4jTest - either directly or indirectly. + * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.deeplearning4j.arbiter.optimize"; + } + + @Override + protected Class getBaseClass() { + return BaseDL4JTest.class; + } +} diff --git a/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/AssertTestsExtendBaseClass.java b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..b8d200350 --- /dev/null +++ b/arbiter/arbiter-deeplearning4j/src/test/java/org/deeplearning4j/arbiter/AssertTestsExtendBaseClass.java @@ -0,0 +1,50 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.arbiter; + +import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.BaseDL4JTest; +import org.nd4j.AbstractAssertTestsClass; + +import java.util.*; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4jTest - either directly or indirectly. 
+ * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.deeplearning4j.arbiter"; + } + + @Override + protected Class getBaseClass() { + return BaseDL4JTest.class; + } +} diff --git a/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/AssertTestsExtendBaseClass.java b/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..b305b123b --- /dev/null +++ b/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/AssertTestsExtendBaseClass.java @@ -0,0 +1,50 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.arbiter.server; + +import lombok.extern.slf4j.Slf4j; +import org.deeplearning4j.BaseDL4JTest; +import org.nd4j.AbstractAssertTestsClass; + +import java.util.*; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4jTest - either directly or indirectly. + * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.deeplearning4j.arbiter.server"; + } + + @Override + protected Class getBaseClass() { + return BaseDL4JTest.class; + } +} diff --git a/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/MnistDataSetIteratorFactory.java b/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/MnistDataSetIteratorFactory.java index dbf05d34f..57bef758d 100644 --- a/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/MnistDataSetIteratorFactory.java +++ b/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/MnistDataSetIteratorFactory.java @@ -17,6 +17,7 @@ package org.deeplearning4j.arbiter.server; import lombok.Data; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIteratorFactory; @@ -27,7 +28,7 @@ import java.io.IOException; * Created by agibsonccc on 3/13/17. 
*/ @Data -public class MnistDataSetIteratorFactory implements DataSetIteratorFactory { +public class MnistDataSetIteratorFactory extends BaseDL4JTest implements DataSetIteratorFactory { /** * @return */ diff --git a/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/TestDataFactoryProviderMnist.java b/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/TestDataFactoryProviderMnist.java index e1a7f820e..c4a75ffb4 100644 --- a/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/TestDataFactoryProviderMnist.java +++ b/arbiter/arbiter-server/src/test/java/org/deeplearning4j/arbiter/server/TestDataFactoryProviderMnist.java @@ -17,13 +17,14 @@ package org.deeplearning4j.arbiter.server; import lombok.AllArgsConstructor; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.datasets.iterator.EarlyTerminationDataSetIterator; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.dataset.api.iterator.DataSetIteratorFactory; @AllArgsConstructor -public class TestDataFactoryProviderMnist implements DataSetIteratorFactory { +public class TestDataFactoryProviderMnist extends BaseDL4JTest implements DataSetIteratorFactory { private int batchSize; private int terminationIter; diff --git a/arbiter/arbiter-ui/pom.xml b/arbiter/arbiter-ui/pom.xml index 2067a3fc7..7392392db 100644 --- a/arbiter/arbiter-ui/pom.xml +++ b/arbiter/arbiter-ui/pom.xml @@ -54,6 +54,13 @@ ${dl4j.version} + + org.deeplearning4j + deeplearning4j-common-tests + ${dl4j.version} + test + + ch.qos.logback logback-classic diff --git a/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/AssertTestsExtendBaseClass.java b/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..fee20847c --- /dev/null +++ b/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/AssertTestsExtendBaseClass.java @@ -0,0 +1,50 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.arbiter.optimize; + +import lombok.extern.slf4j.Slf4j; +import org.nd4j.AbstractAssertTestsClass; +import org.deeplearning4j.BaseDL4JTest; + +import java.util.*; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4jTest - either directly or indirectly. 
+ * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set<Class<?>> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.deeplearning4j.arbiter.optimize"; + } + + @Override + protected Class<?> getBaseClass() { + return BaseDL4JTest.class; + } +} diff --git a/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/TestBasic.java b/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/TestBasic.java index 804c6f974..ddf73e455 100644 --- a/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/TestBasic.java +++ b/arbiter/arbiter-ui/src/test/java/org/deeplearning4j/arbiter/optimize/TestBasic.java @@ -16,6 +16,7 @@ package org.deeplearning4j.arbiter.optimize; +import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.api.storage.StatsStorage; import org.deeplearning4j.arbiter.ComputationGraphSpace; import org.deeplearning4j.arbiter.MultiLayerSpace; @@ -70,7 +71,7 @@ import java.util.concurrent.TimeUnit; /** * Created by Alex on 19/07/2017. */ -public class TestBasic { +public class TestBasic extends BaseDL4JTest { @Test @Ignore diff --git a/datavec/datavec-api/pom.xml b/datavec/datavec-api/pom.xml index 10ed3517a..3c3eec86e 100644 --- a/datavec/datavec-api/pom.xml +++ b/datavec/datavec-api/pom.xml @@ -15,7 +15,8 @@ ~ SPDX-License-Identifier: Apache-2.0 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~--> - + <parent> <artifactId>datavec-parent</artifactId> <groupId>org.datavec</groupId> @@ -79,6 +80,14 @@ <version>${nd4j.version}</version> </dependency> + <dependency> + <groupId>org.nd4j</groupId> + <artifactId>nd4j-common-tests</artifactId> + <version>${nd4j.version}</version> + <scope>test</scope> + </dependency> + + <dependency> <groupId>ch.qos.logback</groupId> <artifactId>logback-classic</artifactId> diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/util/ndarray/RecordConverter.java b/datavec/datavec-api/src/main/java/org/datavec/api/util/ndarray/RecordConverter.java index c55d4d3bb..92a1f737b 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/util/ndarray/RecordConverter.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/util/ndarray/RecordConverter.java @@ -47,9 +47,8 @@ public class RecordConverter { * * @return the array */ - @Deprecated - public static INDArray toArray(Collection<? extends Writable> record, int size) { - return toArray(record); + public static INDArray toArray(DataType dataType, Collection<? extends Writable> record, int size) { + return toArray(dataType, record); } /** @@ -78,13 +77,23 @@ public class RecordConverter { /** * Convert a set of records into a matrix + * As per {@link #toMatrix(DataType, List)} but hardcoded to Float datatype * @param records the records to convert * @return the matrix for the records */ public static INDArray toMatrix(List<List<Writable>> records) { + return toMatrix(DataType.FLOAT, records); + } + + /** + * Convert a set of records into a matrix + * @param records the records to convert + * @return the matrix for the records + */ + public static INDArray toMatrix(DataType dataType, List<List<Writable>> records) { List<INDArray> toStack = new ArrayList<>(); for(List<Writable> l : records){ - toStack.add(toArray(l)); + toStack.add(toArray(dataType, l)); } return Nd4j.vstack(toStack);
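To make the effect of the new overloads concrete, a short usage sketch (the API calls mirror the hunk above; the writable values are made up):

    import java.util.Arrays;
    import java.util.List;
    import org.datavec.api.util.ndarray.RecordConverter;
    import org.datavec.api.writable.DoubleWritable;
    import org.datavec.api.writable.Writable;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;

    public class ToMatrixExample {
        public static void main(String[] args) {
            List<List<Writable>> records = Arrays.asList(
                    Arrays.<Writable>asList(new DoubleWritable(1.0), new DoubleWritable(2.0)),
                    Arrays.<Writable>asList(new DoubleWritable(3.0), new DoubleWritable(4.0)));

            INDArray f = RecordConverter.toMatrix(records);                  // 2x2 FLOAT, old behaviour
            INDArray d = RecordConverter.toMatrix(DataType.DOUBLE, records); // 2x2 with explicit dtype
            System.out.println(f.dataType() + " / " + d.dataType());         // FLOAT / DOUBLE
        }
    }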
+ * As per {@link #toArray(DataType, Collection)} but hardcoded to Float datatype * @param record the record to convert * @return the array */ - public static INDArray toArray(Collection record) { + public static INDArray toArray(Collection record){ + return toArray(DataType.FLOAT, record); + } + + /** + * Convert a record to an INDArray. May contain a mix of Writables and row vector NDArrayWritables. + * @param record the record to convert + * @return the array + */ + public static INDArray toArray(DataType dataType, Collection record) { List l; if(record instanceof List){ l = (List)record; @@ -124,7 +143,7 @@ public class RecordConverter { } } - INDArray arr = Nd4j.create(1, length); + INDArray arr = Nd4j.create(dataType, 1, length); int k = 0; for (Writable w : record ) { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/AssertTestsExtendBaseClass.java b/datavec/datavec-api/src/test/java/org/datavec/api/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..43c606123 --- /dev/null +++ b/datavec/datavec-api/src/test/java/org/datavec/api/AssertTestsExtendBaseClass.java @@ -0,0 +1,57 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.datavec.api; + +import lombok.extern.slf4j.Slf4j; +import org.datavec.api.transform.serde.testClasses.CustomCondition; +import org.datavec.api.transform.serde.testClasses.CustomFilter; +import org.datavec.api.transform.serde.testClasses.CustomTransform; +import org.nd4j.AbstractAssertTestsClass; +import org.nd4j.BaseND4JTest; + +import java.util.*; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseND4jTest - either directly or indirectly. 
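The RecordConverter hunks above add DataType-aware overloads and turn the old signatures into FLOAT-hardcoded shortcuts. A minimal usage sketch follows; the method names and the Nd4j.create(dataType, 1, length) row-vector shape come from the hunks themselves, while the generic type parameters (stripped in extraction) and the DoubleWritable-based record are illustrative assumptions:

    import org.datavec.api.util.ndarray.RecordConverter;
    import org.datavec.api.writable.DoubleWritable;
    import org.datavec.api.writable.Writable;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;

    import java.util.Arrays;
    import java.util.List;

    public class RecordConverterSketch {
        public static void main(String[] args) {
            List<Writable> record = Arrays.asList((Writable) new DoubleWritable(1.0), new DoubleWritable(2.0));

            // New overload: the caller picks the dtype of the resulting 1 x N row vector
            INDArray asDouble = RecordConverter.toArray(DataType.DOUBLE, record);

            // Old signature still compiles, but now delegates to the FLOAT-hardcoded path
            INDArray asFloat = RecordConverter.toArray(record);

            // Matrices work the same way: one row per record, stacked via Nd4j.vstack internally
            INDArray matrix = RecordConverter.toMatrix(DataType.FLOAT, Arrays.asList(record, record));
        }
    }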
+ * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + Set> res = new HashSet<>(); + res.add(CustomCondition.class); + res.add(CustomFilter.class); + res.add(CustomTransform.class); + return res; + } + + @Override + protected String getPackageName() { + return "org.datavec.api"; + } + + @Override + protected Class getBaseClass() { + return BaseND4JTest.class; + } +} diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVLineSequenceRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVLineSequenceRecordReaderTest.java index 70a7ffa7b..84d9b259f 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVLineSequenceRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVLineSequenceRecordReaderTest.java @@ -25,6 +25,7 @@ import org.datavec.api.writable.Writable; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.nd4j.BaseND4JTest; import java.io.File; import java.nio.charset.StandardCharsets; @@ -34,7 +35,7 @@ import java.util.List; import static org.junit.Assert.assertEquals; -public class CSVLineSequenceRecordReaderTest { +public class CSVLineSequenceRecordReaderTest extends BaseND4JTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVMultiSequenceRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVMultiSequenceRecordReaderTest.java index 888d8b523..c293d4544 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVMultiSequenceRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVMultiSequenceRecordReaderTest.java @@ -26,6 +26,8 @@ import org.datavec.api.writable.Writable; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.nd4j.BaseND4JTest; +import org.nd4j.linalg.api.ops.impl.controlflow.compat.BaseCompatOp; import java.io.File; import java.nio.charset.StandardCharsets; @@ -37,7 +39,7 @@ import java.util.List; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; -public class CSVMultiSequenceRecordReaderTest { +public class CSVMultiSequenceRecordReaderTest extends BaseND4JTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVNLinesSequenceRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVNLinesSequenceRecordReaderTest.java index 9b84fddc3..9f297d83b 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVNLinesSequenceRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVNLinesSequenceRecordReaderTest.java @@ -24,6 +24,7 @@ import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.split.FileSplit; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.io.ClassPathResource; import 
java.util.ArrayList; @@ -34,7 +35,7 @@ import static org.junit.Assert.assertEquals; /** * Created by Alex on 19/09/2016. */ -public class CSVNLinesSequenceRecordReaderTest { +public class CSVNLinesSequenceRecordReaderTest extends BaseND4JTest { @Test public void testCSVNLinesSequenceRecordReader() throws Exception { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVRecordReaderTest.java index 534cc986e..471dc07c4 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVRecordReaderTest.java @@ -31,6 +31,7 @@ import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.io.ClassPathResource; import java.io.File; @@ -44,7 +45,7 @@ import java.util.NoSuchElementException; import static org.junit.Assert.*; -public class CSVRecordReaderTest { +public class CSVRecordReaderTest extends BaseND4JTest { @Test public void testNext() throws Exception { CSVRecordReader reader = new CSVRecordReader(); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVSequenceRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVSequenceRecordReaderTest.java index fbbd992d1..e0763bbbc 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVSequenceRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVSequenceRecordReaderTest.java @@ -26,6 +26,7 @@ import org.datavec.api.writable.Writable; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.io.ClassPathResource; import java.io.File; @@ -39,7 +40,7 @@ import java.util.List; import static org.junit.Assert.assertEquals; -public class CSVSequenceRecordReaderTest { +public class CSVSequenceRecordReaderTest extends BaseND4JTest { @Rule public TemporaryFolder tempDir = new TemporaryFolder(); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVVariableSlidingWindowRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVVariableSlidingWindowRecordReaderTest.java index 8e60acad9..fe0c94c4c 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVVariableSlidingWindowRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/CSVVariableSlidingWindowRecordReaderTest.java @@ -22,6 +22,7 @@ import org.datavec.api.records.reader.impl.csv.CSVVariableSlidingWindowRecordRea import org.datavec.api.split.FileSplit; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.io.ClassPathResource; import java.util.LinkedList; @@ -34,7 +35,7 @@ import static org.junit.Assert.assertEquals; * * @author Justin Long (crockpotveggies) */ -public class CSVVariableSlidingWindowRecordReaderTest { +public class CSVVariableSlidingWindowRecordReaderTest extends BaseND4JTest { @Test public void testCSVVariableSlidingWindowRecordReader() throws Exception { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileBatchRecordReaderTest.java 
b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileBatchRecordReaderTest.java index c67e32192..d6f03d815 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileBatchRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileBatchRecordReaderTest.java @@ -27,6 +27,7 @@ import org.datavec.api.writable.Writable; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.nd4j.BaseND4JTest; import org.nd4j.api.loader.FileBatch; import java.io.File; @@ -36,7 +37,7 @@ import java.util.List; import static org.junit.Assert.*; -public class FileBatchRecordReaderTest { +public class FileBatchRecordReaderTest extends BaseND4JTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileRecordReaderTest.java index 533f5be66..6bf66880f 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/FileRecordReaderTest.java @@ -23,6 +23,7 @@ import org.datavec.api.split.FileSplit; import org.datavec.api.split.InputSplit; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.io.ClassPathResource; import java.net.URI; @@ -36,7 +37,7 @@ import static org.junit.Assert.assertFalse; /** * Created by nyghtowl on 11/14/15. */ -public class FileRecordReaderTest { +public class FileRecordReaderTest extends BaseND4JTest { @Test public void testReset() throws Exception { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonLineRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonLineRecordReaderTest.java index bfeadef36..2f91579f0 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonLineRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonLineRecordReaderTest.java @@ -27,6 +27,7 @@ import org.datavec.api.writable.Writable; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.shade.jackson.core.JsonFactory; import org.nd4j.shade.jackson.databind.ObjectMapper; @@ -39,7 +40,7 @@ import java.util.List; import static org.junit.Assert.assertEquals; -public class JacksonLineRecordReaderTest { +public class JacksonLineRecordReaderTest extends BaseND4JTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonRecordReaderTest.java index c95de48e7..f1fa8d6b2 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/JacksonRecordReaderTest.java @@ -30,6 +30,7 @@ import org.datavec.api.writable.Writable; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.io.ClassPathResource; import 
org.nd4j.shade.jackson.core.JsonFactory; import org.nd4j.shade.jackson.databind.ObjectMapper; @@ -48,7 +49,7 @@ import static org.junit.Assert.assertFalse; /** * Created by Alex on 11/04/2016. */ -public class JacksonRecordReaderTest { +public class JacksonRecordReaderTest extends BaseND4JTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LibSvmRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LibSvmRecordReaderTest.java index 75871a6b7..5e8ca6546 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LibSvmRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LibSvmRecordReaderTest.java @@ -24,6 +24,7 @@ import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.io.ClassPathResource; import java.io.IOException; @@ -44,7 +45,7 @@ import static org.junit.Assert.assertEquals; * * @author dave@skymind.io */ -public class LibSvmRecordReaderTest { +public class LibSvmRecordReaderTest extends BaseND4JTest { @Test public void testBasicRecord() throws IOException, InterruptedException { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LineReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LineReaderTest.java index 5027357eb..17a41f4d4 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LineReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/LineReaderTest.java @@ -29,6 +29,7 @@ import org.datavec.api.writable.Writable; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.nd4j.BaseND4JTest; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -48,7 +49,7 @@ import static org.junit.Assert.assertEquals; /** * Created by agibsonccc on 11/17/14. */ -public class LineReaderTest { +public class LineReaderTest extends BaseND4JTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/RegexRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/RegexRecordReaderTest.java index 778d14424..539b0f351 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/RegexRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/RegexRecordReaderTest.java @@ -32,6 +32,7 @@ import org.datavec.api.writable.Writable; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.io.ClassPathResource; import java.io.File; @@ -45,7 +46,7 @@ import static org.junit.Assert.assertFalse; /** * Created by Alex on 12/04/2016. 
*/ -public class RegexRecordReaderTest { +public class RegexRecordReaderTest extends BaseND4JTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/SVMLightRecordReaderTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/SVMLightRecordReaderTest.java index 92f8c57e4..25d2959ce 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/SVMLightRecordReaderTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/SVMLightRecordReaderTest.java @@ -24,6 +24,7 @@ import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.io.ClassPathResource; import java.io.IOException; @@ -42,7 +43,7 @@ import static org.junit.Assert.assertEquals; * * @author dave@skymind.io */ -public class SVMLightRecordReaderTest { +public class SVMLightRecordReaderTest extends BaseND4JTest { @Test public void testBasicRecord() throws IOException, InterruptedException { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestCollectionRecordReaders.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestCollectionRecordReaders.java index a06d56400..fa68c4a1f 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestCollectionRecordReaders.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestCollectionRecordReaders.java @@ -23,6 +23,7 @@ import org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordRe import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.ArrayList; import java.util.Arrays; @@ -33,7 +34,7 @@ import static org.junit.Assert.*; /** * Created by Alex on 21/05/2016. 
*/ -public class TestCollectionRecordReaders { +public class TestCollectionRecordReaders extends BaseND4JTest { @Test public void testCollectionSequenceRecordReader() throws Exception { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestConcatenatingRecordReader.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestConcatenatingRecordReader.java index 172d884a3..266ad2edc 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestConcatenatingRecordReader.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestConcatenatingRecordReader.java @@ -20,11 +20,12 @@ import org.datavec.api.records.reader.RecordReader; import org.datavec.api.records.reader.impl.csv.CSVRecordReader; import org.datavec.api.split.FileSplit; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.io.ClassPathResource; import static org.junit.Assert.assertEquals; -public class TestConcatenatingRecordReader { +public class TestConcatenatingRecordReader extends BaseND4JTest { @Test public void test() throws Exception { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestSerialization.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestSerialization.java index c249737a3..91fc22886 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestSerialization.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/TestSerialization.java @@ -34,6 +34,7 @@ import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.shade.jackson.core.JsonFactory; import org.nd4j.shade.jackson.databind.ObjectMapper; @@ -49,7 +50,7 @@ import static org.junit.Assert.assertEquals; * Note however that not all are used/usable with spark (such as Collection[Sequence]RecordReader * and the rest are generally used without being initialized on a particular dataset */ -public class TestSerialization { +public class TestSerialization extends BaseND4JTest { @Test public void testRR() throws Exception { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/transform/TransformProcessRecordReaderTests.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/transform/TransformProcessRecordReaderTests.java index 5daad01b3..ff3ceb9be 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/transform/TransformProcessRecordReaderTests.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/reader/impl/transform/TransformProcessRecordReaderTests.java @@ -27,6 +27,7 @@ import org.datavec.api.writable.LongWritable; import org.datavec.api.writable.Writable; import org.joda.time.DateTimeZone; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.io.ClassPathResource; import java.util.ArrayList; @@ -39,7 +40,7 @@ import static org.junit.Assert.assertTrue; /** * Created by agibsonccc on 3/21/17. 
*/ -public class TransformProcessRecordReaderTests { +public class TransformProcessRecordReaderTests extends BaseND4JTest { @Test public void simpleTransformTest() throws Exception { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/CSVRecordWriterTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/CSVRecordWriterTest.java index 5a165b0ac..c3a8f4181 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/CSVRecordWriterTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/CSVRecordWriterTest.java @@ -24,6 +24,7 @@ import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; import org.junit.Before; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.io.File; import java.util.ArrayList; @@ -34,7 +35,7 @@ import static org.junit.Assert.assertEquals; /** * @author raver119@gmail.com */ -public class CSVRecordWriterTest { +public class CSVRecordWriterTest extends BaseND4JTest { @Before public void setUp() throws Exception { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/LibSvmRecordWriterTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/LibSvmRecordWriterTest.java index 0c7d70b09..91996056d 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/LibSvmRecordWriterTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/LibSvmRecordWriterTest.java @@ -27,6 +27,7 @@ import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.NDArrayWritable; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.io.ClassPathResource; @@ -49,7 +50,7 @@ import static org.junit.Assert.assertEquals; * * @author dave@skymind.io */ -public class LibSvmRecordWriterTest { +public class LibSvmRecordWriterTest extends BaseND4JTest { @Test public void testBasic() throws Exception { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/SVMLightRecordWriterTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/SVMLightRecordWriterTest.java index 7b9b8c203..f057c7d45 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/SVMLightRecordWriterTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/records/writer/impl/SVMLightRecordWriterTest.java @@ -25,6 +25,7 @@ import org.datavec.api.split.partition.NumberOfRecordsPartitioner; import org.datavec.api.writable.*; import org.datavec.api.writable.NDArrayWritable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.io.ClassPathResource; @@ -47,7 +48,7 @@ import static org.junit.Assert.assertEquals; * * @author dave@skymind.io */ -public class SVMLightRecordWriterTest { +public class SVMLightRecordWriterTest extends BaseND4JTest { @Test public void testBasic() throws Exception { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/split/InputSplitTests.java b/datavec/datavec-api/src/test/java/org/datavec/api/split/InputSplitTests.java index e8ce37bd3..59e1feee8 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/split/InputSplitTests.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/split/InputSplitTests.java @@ -16,6 
+16,7 @@ package org.datavec.api.split; +import org.nd4j.BaseND4JTest; import org.nd4j.shade.guava.io.Files; import org.datavec.api.io.filters.BalancedPathFilter; import org.datavec.api.io.filters.RandomPathFilter; @@ -36,7 +37,7 @@ import static org.junit.Assert.assertEquals; * * @author saudet */ -public class InputSplitTests { +public class InputSplitTests extends BaseND4JTest { @Test public void testSample() throws URISyntaxException { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/split/NumberedFileInputSplitTests.java b/datavec/datavec-api/src/test/java/org/datavec/api/split/NumberedFileInputSplitTests.java index 797f546dd..f8be04d47 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/split/NumberedFileInputSplitTests.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/split/NumberedFileInputSplitTests.java @@ -17,13 +17,14 @@ package org.datavec.api.split; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.net.URI; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; -public class NumberedFileInputSplitTests { +public class NumberedFileInputSplitTests extends BaseND4JTest { @Test public void testNumberedFileInputSplitBasic() { String baseString = "/path/to/files/prefix%d.suffix"; diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/split/TestStreamInputSplit.java b/datavec/datavec-api/src/test/java/org/datavec/api/split/TestStreamInputSplit.java index c618c625d..94119015c 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/split/TestStreamInputSplit.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/split/TestStreamInputSplit.java @@ -24,6 +24,7 @@ import org.datavec.api.writable.Writable; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.function.Function; import java.io.File; @@ -40,7 +41,7 @@ import java.util.Random; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; -public class TestStreamInputSplit { +public class TestStreamInputSplit extends BaseND4JTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/split/TransformSplitTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/split/TransformSplitTest.java index 457f07097..ea6b9fea4 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/split/TransformSplitTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/split/TransformSplitTest.java @@ -17,6 +17,7 @@ package org.datavec.api.split; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.net.URI; import java.net.URISyntaxException; @@ -28,7 +29,7 @@ import static org.junit.Assert.assertArrayEquals; /** * @author Ede Meijer */ -public class TransformSplitTest { +public class TransformSplitTest extends BaseND4JTest { @Test public void testTransform() throws URISyntaxException { Collection inputFiles = asList(new URI("file:///foo/bar/../0.csv"), new URI("file:///foo/1.csv")); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/split/parittion/PartitionerTests.java b/datavec/datavec-api/src/test/java/org/datavec/api/split/parittion/PartitionerTests.java index c9fb57eb9..f27f7527f 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/split/parittion/PartitionerTests.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/split/parittion/PartitionerTests.java @@ -16,6 +16,7 @@ package 
org.datavec.api.split.parittion; +import org.nd4j.BaseND4JTest; import org.nd4j.shade.guava.io.Files; import org.datavec.api.conf.Configuration; import org.datavec.api.split.FileSplit; @@ -31,7 +32,7 @@ import static junit.framework.TestCase.assertTrue; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; -public class PartitionerTests { +public class PartitionerTests extends BaseND4JTest { @Test public void testRecordsPerFilePartition() { Partitioner partitioner = new NumberOfRecordsPartitioner(); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/TestTransformProcess.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/TestTransformProcess.java index eeb4be27a..efb9f2b6e 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/TestTransformProcess.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/TestTransformProcess.java @@ -26,12 +26,13 @@ import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.*; import static org.junit.Assert.assertEquals; -public class TestTransformProcess { +public class TestTransformProcess extends BaseND4JTest { @Test public void testExecution(){ diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/condition/TestConditions.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/condition/TestConditions.java index da4e53398..0c69959d6 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/condition/TestConditions.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/condition/TestConditions.java @@ -24,6 +24,7 @@ import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.transform.TestTransforms; import org.datavec.api.writable.*; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.*; @@ -33,7 +34,7 @@ import static org.junit.Assert.assertTrue; /** * Created by Alex on 24/03/2016. */ -public class TestConditions { +public class TestConditions extends BaseND4JTest { @Test public void testIntegerCondition() { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/filter/TestFilters.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/filter/TestFilters.java index 4d96b5b6e..314ee72ff 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/filter/TestFilters.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/filter/TestFilters.java @@ -24,6 +24,7 @@ import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.IntWritable; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.ArrayList; import java.util.Arrays; @@ -37,7 +38,7 @@ import static org.junit.Assert.assertTrue; /** * Created by Alex on 21/03/2016. 
*/ -public class TestFilters { +public class TestFilters extends BaseND4JTest { @Test diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/join/TestJoin.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/join/TestJoin.java index e6ae74185..1d113c6ff 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/join/TestJoin.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/join/TestJoin.java @@ -23,6 +23,7 @@ import org.datavec.api.writable.NullWritable; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.ArrayList; import java.util.Arrays; @@ -33,7 +34,7 @@ import static org.junit.Assert.assertEquals; /** * Created by Alex on 18/04/2016. */ -public class TestJoin { +public class TestJoin extends BaseND4JTest { @Test public void testJoin() { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregableMultiOpTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregableMultiOpTest.java index 059cb618c..57ec54e8a 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregableMultiOpTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregableMultiOpTest.java @@ -18,6 +18,7 @@ package org.datavec.api.transform.ops; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.io.Serializable; import java.util.*; @@ -27,7 +28,7 @@ import static org.junit.Assert.assertTrue; /** * Created by huitseeker on 5/14/17. */ -public class AggregableMultiOpTest { +public class AggregableMultiOpTest extends BaseND4JTest { private List intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregatorImplsTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregatorImplsTest.java index 487926c7a..c722dada4 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregatorImplsTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/AggregatorImplsTest.java @@ -19,6 +19,7 @@ package org.datavec.api.transform.ops; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; +import org.nd4j.BaseND4JTest; import java.util.ArrayList; import java.util.Arrays; @@ -30,7 +31,7 @@ import static org.junit.Assert.assertTrue; /** * Created by huitseeker on 5/14/17. 
*/ -public class AggregatorImplsTest { +public class AggregatorImplsTest extends BaseND4JTest { private List intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); private List stringList = new ArrayList<>(Arrays.asList("arakoa", "abracadabra", "blast", "acceptance")); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/DispatchOpTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/DispatchOpTest.java index 076e4412d..a636e7239 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/DispatchOpTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ops/DispatchOpTest.java @@ -18,6 +18,7 @@ package org.datavec.api.transform.ops; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.ArrayList; import java.util.Arrays; @@ -29,7 +30,7 @@ import static org.junit.Assert.assertTrue; /** * Created by huitseeker on 5/14/17. */ -public class DispatchOpTest { +public class DispatchOpTest extends BaseND4JTest { private List intList = new ArrayList<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); private List stringList = new ArrayList<>(Arrays.asList("arakoa", "abracadabra", "blast", "acceptance")); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/reduce/TestMultiOpReduce.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/reduce/TestMultiOpReduce.java index 1b0a20430..9aef39aa4 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/reduce/TestMultiOpReduce.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/reduce/TestMultiOpReduce.java @@ -29,6 +29,7 @@ import org.datavec.api.transform.ops.IAggregableReduceOp; import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.*; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.*; @@ -38,7 +39,7 @@ import static org.junit.Assert.fail; /** * Created by Alex on 21/03/2016. 
*/ -public class TestMultiOpReduce { +public class TestMultiOpReduce extends BaseND4JTest { @Test public void testMultiOpReducerDouble() { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/reduce/TestReductions.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/reduce/TestReductions.java index f9debfe2c..dc6443630 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/reduce/TestReductions.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/reduce/TestReductions.java @@ -21,13 +21,14 @@ import org.datavec.api.transform.reduce.impl.GeographicMidpointReduction; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.Arrays; import java.util.List; import static org.junit.Assert.assertEquals; -public class TestReductions { +public class TestReductions extends BaseND4JTest { @Test public void testGeographicMidPointReduction(){ diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/schema/TestJsonYaml.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/schema/TestJsonYaml.java index dff90f8b9..8e33b742c 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/schema/TestJsonYaml.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/schema/TestJsonYaml.java @@ -19,13 +19,14 @@ package org.datavec.api.transform.schema; import org.datavec.api.transform.metadata.ColumnMetaData; import org.joda.time.DateTimeZone; import org.junit.Test; +import org.nd4j.BaseND4JTest; import static org.junit.Assert.assertEquals; /** * Created by Alex on 18/07/2016. */ -public class TestJsonYaml { +public class TestJsonYaml extends BaseND4JTest { @Test public void testToFromJsonYaml() { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/schema/TestSchemaMethods.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/schema/TestSchemaMethods.java index 870c10680..6cbcafff4 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/schema/TestSchemaMethods.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/schema/TestSchemaMethods.java @@ -18,13 +18,14 @@ package org.datavec.api.transform.schema; import org.datavec.api.transform.ColumnType; import org.junit.Test; +import org.nd4j.BaseND4JTest; import static org.junit.Assert.assertEquals; /** * Created by Alex on 04/09/2016. */ -public class TestSchemaMethods { +public class TestSchemaMethods extends BaseND4JTest { @Test public void testNumberedColumnAdding() { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestReduceSequenceByWindowFunction.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestReduceSequenceByWindowFunction.java index 0f48eeff3..56c8d3f1e 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestReduceSequenceByWindowFunction.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestReduceSequenceByWindowFunction.java @@ -30,6 +30,7 @@ import org.datavec.api.writable.NullWritable; import org.datavec.api.writable.Writable; import org.joda.time.DateTimeZone; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.ArrayList; import java.util.Arrays; @@ -41,7 +42,7 @@ import static org.junit.Assert.assertEquals; /** * Created by Alex on 16/04/2016. 
*/ -public class TestReduceSequenceByWindowFunction { +public class TestReduceSequenceByWindowFunction extends BaseND4JTest { @Test public void testReduceSequenceByWindowFunction() { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestSequenceSplit.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestSequenceSplit.java index 6695599a5..98dd49587 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestSequenceSplit.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestSequenceSplit.java @@ -24,6 +24,7 @@ import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; import org.joda.time.DateTimeZone; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.ArrayList; import java.util.Arrays; @@ -35,7 +36,7 @@ import static org.junit.Assert.assertEquals; /** * Created by Alex on 19/04/2016. */ -public class TestSequenceSplit { +public class TestSequenceSplit extends BaseND4JTest { @Test public void testSequenceSplitTimeSeparation() { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestWindowFunctions.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestWindowFunctions.java index 99fe9227d..cc12adc53 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestWindowFunctions.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/sequence/TestWindowFunctions.java @@ -26,6 +26,7 @@ import org.datavec.api.writable.LongWritable; import org.datavec.api.writable.Writable; import org.joda.time.DateTimeZone; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.ArrayList; import java.util.Arrays; @@ -37,7 +38,7 @@ import static org.junit.Assert.assertEquals; /** * Created by Alex on 16/04/2016. */ -public class TestWindowFunctions { +public class TestWindowFunctions extends BaseND4JTest { @Test public void testTimeWindowFunction() { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/serde/TestCustomTransformJsonYaml.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/serde/TestCustomTransformJsonYaml.java index 03731f6d6..1da9f48e5 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/serde/TestCustomTransformJsonYaml.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/serde/TestCustomTransformJsonYaml.java @@ -23,13 +23,14 @@ import org.datavec.api.transform.serde.testClasses.CustomCondition; import org.datavec.api.transform.serde.testClasses.CustomFilter; import org.datavec.api.transform.serde.testClasses.CustomTransform; import org.junit.Test; +import org.nd4j.BaseND4JTest; import static org.junit.Assert.assertEquals; /** * Created by Alex on 11/01/2017. 
*/ -public class TestCustomTransformJsonYaml { +public class TestCustomTransformJsonYaml extends BaseND4JTest { @Test public void testCustomTransform() { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/serde/TestYamlJsonSerde.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/serde/TestYamlJsonSerde.java index d09995009..dd6e0941a 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/serde/TestYamlJsonSerde.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/serde/TestYamlJsonSerde.java @@ -61,6 +61,7 @@ import org.datavec.api.writable.comparator.DoubleWritableComparator; import org.joda.time.DateTimeFieldType; import org.joda.time.DateTimeZone; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.*; import java.util.concurrent.TimeUnit; @@ -70,7 +71,7 @@ import static org.junit.Assert.assertEquals; /** * Created by Alex on 20/07/2016. */ -public class TestYamlJsonSerde { +public class TestYamlJsonSerde extends BaseND4JTest { public static YamlSerializer y = new YamlSerializer(); public static JsonSerializer j = new JsonSerializer(); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/stringreduce/TestReduce.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/stringreduce/TestReduce.java index e17650a87..ac69e3397 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/stringreduce/TestReduce.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/stringreduce/TestReduce.java @@ -21,6 +21,7 @@ import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.*; @@ -29,7 +30,7 @@ import static org.junit.Assert.assertEquals; /** * Created by Alex on 21/03/2016. 
*/ -public class TestReduce { +public class TestReduce extends BaseND4JTest { @Test public void testReducerDouble() { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/RegressionTestJson.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/RegressionTestJson.java index 38ec1fda9..daa5c15c8 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/RegressionTestJson.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/RegressionTestJson.java @@ -47,6 +47,7 @@ import org.datavec.api.writable.comparator.LongWritableComparator; import org.joda.time.DateTimeFieldType; import org.joda.time.DateTimeZone; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.io.ClassPathResource; import java.io.File; @@ -58,7 +59,7 @@ import java.util.concurrent.TimeUnit; import static org.junit.Assert.assertEquals; -public class RegressionTestJson { +public class RegressionTestJson extends BaseND4JTest { @Test public void regressionTestJson100a() throws Exception { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestJsonYaml.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestJsonYaml.java index 9f647d365..00c4b745f 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestJsonYaml.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestJsonYaml.java @@ -47,6 +47,7 @@ import org.datavec.api.writable.comparator.LongWritableComparator; import org.joda.time.DateTimeFieldType; import org.joda.time.DateTimeZone; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.*; import java.util.concurrent.TimeUnit; @@ -56,7 +57,7 @@ import static org.junit.Assert.assertEquals; /** * Created by Alex on 18/07/2016. */ -public class TestJsonYaml { +public class TestJsonYaml extends BaseND4JTest { @Test public void testToFromJsonYaml() { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestTransforms.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestTransforms.java index 600ee0b25..1d440913b 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestTransforms.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/TestTransforms.java @@ -56,6 +56,7 @@ import org.joda.time.DateTimeFieldType; import org.joda.time.DateTimeZone; import org.junit.Assert; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -72,7 +73,7 @@ import static org.junit.Assert.*; /** * Created by Alex on 21/03/2016. */ -public class TestTransforms { +public class TestTransforms extends BaseND4JTest { public static Schema getSchema(ColumnType type, String... 
colNames) { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/ndarray/TestNDArrayWritableTransforms.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/ndarray/TestNDArrayWritableTransforms.java index 78d929e65..c6dad8359 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/ndarray/TestNDArrayWritableTransforms.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/ndarray/TestNDArrayWritableTransforms.java @@ -26,6 +26,7 @@ import org.datavec.api.writable.NDArrayWritable; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -39,7 +40,7 @@ import static org.junit.Assert.assertEquals; /** * Created by Alex on 02/06/2017. */ -public class TestNDArrayWritableTransforms { +public class TestNDArrayWritableTransforms extends BaseND4JTest { @Test public void testNDArrayWritableBasic() { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/ndarray/TestYamlJsonSerde.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/ndarray/TestYamlJsonSerde.java index 7eb3efdef..394457443 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/ndarray/TestYamlJsonSerde.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/ndarray/TestYamlJsonSerde.java @@ -27,6 +27,7 @@ import org.datavec.api.transform.schema.Schema; import org.datavec.api.transform.serde.JsonSerializer; import org.datavec.api.transform.serde.YamlSerializer; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.Arrays; import java.util.List; @@ -36,7 +37,7 @@ import static org.junit.Assert.assertEquals; /** * Created by Alex on 20/07/2016. */ -public class TestYamlJsonSerde { +public class TestYamlJsonSerde extends BaseND4JTest { public static YamlSerializer y = new YamlSerializer(); public static JsonSerializer j = new JsonSerializer(); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/parse/ParseDoubleTransformTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/parse/ParseDoubleTransformTest.java index 8ec5233c7..4c2c718ae 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/parse/ParseDoubleTransformTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/transform/parse/ParseDoubleTransformTest.java @@ -20,6 +20,7 @@ import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.Text; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.ArrayList; import java.util.Arrays; @@ -30,7 +31,7 @@ import static org.junit.Assert.assertEquals; /** * Created by agibsonccc on 10/22/16. 
*/ -public class ParseDoubleTransformTest { +public class ParseDoubleTransformTest extends BaseND4JTest { @Test public void testDoubleTransform() { List record = new ArrayList<>(); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ui/TestUI.java b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ui/TestUI.java index 0dffb6dab..64f6a4422 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/transform/ui/TestUI.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/transform/ui/TestUI.java @@ -35,6 +35,7 @@ import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.nd4j.BaseND4JTest; import java.io.File; import java.util.ArrayList; @@ -46,7 +47,7 @@ import static org.junit.Assert.assertEquals; /** * Created by Alex on 25/03/2016. */ -public class TestUI { +public class TestUI extends BaseND4JTest { @Rule public TemporaryFolder testDir = new TemporaryFolder(); diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/util/ClassPathResourceTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/util/ClassPathResourceTest.java index 9b95bbfb4..b68ae43ee 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/util/ClassPathResourceTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/util/ClassPathResourceTest.java @@ -18,6 +18,7 @@ package org.datavec.api.util; import org.junit.Before; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.io.BufferedReader; import java.io.File; @@ -33,7 +34,7 @@ import static org.hamcrest.core.IsEqual.equalTo; /** * @author raver119@gmail.com */ -public class ClassPathResourceTest { +public class ClassPathResourceTest extends BaseND4JTest { private boolean isWindows = false; //File sizes are reported slightly different on Linux vs. 
Windows diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/util/TimeSeriesUtilsTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/util/TimeSeriesUtilsTest.java index 1545938f6..d47ec60d7 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/util/TimeSeriesUtilsTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/util/TimeSeriesUtilsTest.java @@ -20,6 +20,7 @@ import org.datavec.api.timeseries.util.TimeSeriesWritableUtils; import org.datavec.api.writable.DoubleWritable; import org.datavec.api.writable.Writable; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import java.util.ArrayList; @@ -27,7 +28,7 @@ import java.util.List; import static org.junit.Assert.assertArrayEquals; -public class TimeSeriesUtilsTest { +public class TimeSeriesUtilsTest extends BaseND4JTest { @Test public void testTimeSeriesCreation() { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/writable/RecordConverterTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/writable/RecordConverterTest.java index 6dfacdd93..dbc62ed93 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/writable/RecordConverterTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/writable/RecordConverterTest.java @@ -16,6 +16,7 @@ package org.datavec.api.writable; +import org.nd4j.BaseND4JTest; import org.nd4j.shade.guava.collect.Lists; import org.datavec.api.transform.schema.Schema; import org.datavec.api.util.ndarray.RecordConverter; @@ -31,7 +32,7 @@ import java.util.TimeZone; import static org.junit.Assert.assertEquals; -public class RecordConverterTest { +public class RecordConverterTest extends BaseND4JTest { @Test public void toRecords_PassInClassificationDataSet_ExpectNDArrayAndIntWritables() { INDArray feature1 = Nd4j.create(new double[]{4, -5.7, 10, -0.1}, new long[]{1, 4}, DataType.FLOAT); @@ -86,7 +87,7 @@ public class RecordConverterTest { new IntWritable(1)); INDArray exp = Nd4j.create(new double[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 1}, new long[]{1, 10}, DataType.FLOAT); - INDArray act = RecordConverter.toArray(l); + INDArray act = RecordConverter.toArray(DataType.FLOAT, l); assertEquals(exp, act); } @@ -101,7 +102,7 @@ public class RecordConverterTest { {1,2,3,4,5}, {6,7,8,9,10}}).castTo(DataType.FLOAT); - INDArray act = RecordConverter.toMatrix(Arrays.asList(l1,l2)); + INDArray act = RecordConverter.toMatrix(DataType.FLOAT, Arrays.asList(l1,l2)); assertEquals(exp, act); } diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/writable/TestNDArrayWritableAndSerialization.java b/datavec/datavec-api/src/test/java/org/datavec/api/writable/TestNDArrayWritableAndSerialization.java index 81c2f2d73..9242927e1 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/writable/TestNDArrayWritableAndSerialization.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/writable/TestNDArrayWritableAndSerialization.java @@ -18,6 +18,7 @@ package org.datavec.api.writable; import org.datavec.api.transform.metadata.NDArrayMetaData; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -28,7 +29,7 @@ import static org.junit.Assert.*; /** * Created by Alex on 02/06/2017. 
*/ -public class TestNDArrayWritableAndSerialization { +public class TestNDArrayWritableAndSerialization extends BaseND4JTest { @Test public void testIsValid() { diff --git a/datavec/datavec-api/src/test/java/org/datavec/api/writable/WritableTest.java b/datavec/datavec-api/src/test/java/org/datavec/api/writable/WritableTest.java index 93d7ed31b..bd636e62b 100644 --- a/datavec/datavec-api/src/test/java/org/datavec/api/writable/WritableTest.java +++ b/datavec/datavec-api/src/test/java/org/datavec/api/writable/WritableTest.java @@ -18,6 +18,7 @@ package org.datavec.api.writable; import org.datavec.api.writable.batch.NDArrayRecordBatch; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -31,9 +32,7 @@ import java.util.List; import static org.junit.Assert.*; -public class WritableTest { - - +public class WritableTest extends BaseND4JTest { @Test public void testWritableEqualityReflexive() { diff --git a/datavec/datavec-arrow/pom.xml b/datavec/datavec-arrow/pom.xml index 04420a5e9..60409bc53 100644 --- a/datavec/datavec-arrow/pom.xml +++ b/datavec/datavec-arrow/pom.xml @@ -49,6 +49,12 @@ arrow-format ${arrow.version} + + org.nd4j + nd4j-common-tests + ${nd4j.version} + test + diff --git a/datavec/datavec-arrow/src/test/java/org/datavec/arrow/ArrowConverterTest.java b/datavec/datavec-arrow/src/test/java/org/datavec/arrow/ArrowConverterTest.java index 1d8fddc0e..edd036f0a 100644 --- a/datavec/datavec-arrow/src/test/java/org/datavec/arrow/ArrowConverterTest.java +++ b/datavec/datavec-arrow/src/test/java/org/datavec/arrow/ArrowConverterTest.java @@ -40,6 +40,7 @@ import org.datavec.arrow.recordreader.ArrowWritableRecordBatch; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.primitives.Pair; @@ -56,7 +57,7 @@ import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; -public class ArrowConverterTest { +public class ArrowConverterTest extends BaseND4JTest { private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE); diff --git a/datavec/datavec-arrow/src/test/java/org/datavec/arrow/AssertTestsExtendBaseClass.java b/datavec/datavec-arrow/src/test/java/org/datavec/arrow/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..f2cf7ce09 --- /dev/null +++ b/datavec/datavec-arrow/src/test/java/org/datavec/arrow/AssertTestsExtendBaseClass.java @@ -0,0 +1,50 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.datavec.arrow; + +import lombok.extern.slf4j.Slf4j; +import org.nd4j.BaseND4JTest; +import org.nd4j.AbstractAssertTestsClass; + +import java.util.*; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseND4jTest - either directly or indirectly. + * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.datavec.arrow"; + } + + @Override + protected Class getBaseClass() { + return BaseND4JTest.class; + } +} diff --git a/datavec/datavec-arrow/src/test/java/org/datavec/arrow/RecordMapperTest.java b/datavec/datavec-arrow/src/test/java/org/datavec/arrow/RecordMapperTest.java index 390bfdcd9..59ba5a546 100644 --- a/datavec/datavec-arrow/src/test/java/org/datavec/arrow/RecordMapperTest.java +++ b/datavec/datavec-arrow/src/test/java/org/datavec/arrow/RecordMapperTest.java @@ -31,6 +31,7 @@ import org.datavec.api.writable.Writable; import org.datavec.arrow.recordreader.ArrowRecordReader; import org.datavec.arrow.recordreader.ArrowRecordWriter; import org.junit.Test; +import org.nd4j.BaseND4JTest; import org.nd4j.linalg.primitives.Triple; import java.io.File; @@ -41,7 +42,7 @@ import java.util.List; import static org.junit.Assert.assertEquals; -public class RecordMapperTest { +public class RecordMapperTest extends BaseND4JTest { @Test public void testMultiWrite() throws Exception { diff --git a/datavec/datavec-arrow/src/test/java/org/datavec/arrow/recordreader/ArrowWritableRecordTimeSeriesBatchTests.java b/datavec/datavec-arrow/src/test/java/org/datavec/arrow/recordreader/ArrowWritableRecordTimeSeriesBatchTests.java index 6951561cd..e49a9fcc4 100644 --- a/datavec/datavec-arrow/src/test/java/org/datavec/arrow/recordreader/ArrowWritableRecordTimeSeriesBatchTests.java +++ b/datavec/datavec-arrow/src/test/java/org/datavec/arrow/recordreader/ArrowWritableRecordTimeSeriesBatchTests.java @@ -27,6 +27,7 @@ import org.datavec.api.writable.Writable; import org.datavec.arrow.ArrowConverter; import org.junit.Ignore; import org.junit.Test; +import org.nd4j.BaseND4JTest; import java.util.ArrayList; import java.util.Arrays; @@ -35,7 +36,7 @@ import java.util.List; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; -public class ArrowWritableRecordTimeSeriesBatchTests { +public class ArrowWritableRecordTimeSeriesBatchTests extends BaseND4JTest { private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE); diff --git a/datavec/datavec-data/datavec-data-audio/pom.xml b/datavec/datavec-data/datavec-data-audio/pom.xml index 1f99eab7c..3b9674cd9 100644 --- a/datavec/datavec-data/datavec-data-audio/pom.xml +++ b/datavec/datavec-data/datavec-data-audio/pom.xml @@ -57,6 +57,13 @@ with-dependencies + + org.nd4j + nd4j-common-tests + ${nd4j.version} + test + + - + datavec-parent org.datavec @@ -31,6 +32,12 @@ datavec-api ${project.version} + + org.nd4j + nd4j-common-tests + ${nd4j.version} + test + com.maxmind.geoip2 geoip2 diff --git 
a/datavec/datavec-data/datavec-geo/src/test/java/org/datavec/api/transform/AssertTestsExtendBaseClass.java b/datavec/datavec-data/datavec-geo/src/test/java/org/datavec/api/transform/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..7d4a6836c --- /dev/null +++ b/datavec/datavec-data/datavec-geo/src/test/java/org/datavec/api/transform/AssertTestsExtendBaseClass.java @@ -0,0 +1,49 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.datavec.api.transform; + +import lombok.extern.slf4j.Slf4j; +import java.util.*; +import org.nd4j.AbstractAssertTestsClass; +import org.nd4j.BaseND4JTest; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseND4jTest - either directly or indirectly. + * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.datavec.api.transform"; + } + + @Override + protected Class getBaseClass() { + return BaseND4JTest.class; + } +} diff --git a/datavec/datavec-data/datavec-hadoop/pom.xml b/datavec/datavec-data/datavec-hadoop/pom.xml index 7b74ead38..a6c72b968 100644 --- a/datavec/datavec-data/datavec-hadoop/pom.xml +++ b/datavec/datavec-data/datavec-hadoop/pom.xml @@ -60,6 +60,13 @@ + + + org.nd4j + nd4j-common-tests + ${nd4j.version} + test + diff --git a/datavec/datavec-data/datavec-hadoop/src/test/java/org/datavec/hadoop/AssertTestsExtendBaseClass.java b/datavec/datavec-data/datavec-hadoop/src/test/java/org/datavec/hadoop/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..2aaf25041 --- /dev/null +++ b/datavec/datavec-data/datavec-hadoop/src/test/java/org/datavec/hadoop/AssertTestsExtendBaseClass.java @@ -0,0 +1,48 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.datavec.hadoop; + +import lombok.extern.slf4j.Slf4j; +import java.util.*; +import org.nd4j.AbstractAssertTestsClass; +import org.nd4j.BaseND4JTest; +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseND4jTest - either directly or indirectly. + * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.datavec.hadoop"; + } + + @Override + protected Class getBaseClass() { + return BaseND4JTest.class; + } +} diff --git a/datavec/datavec-excel/pom.xml b/datavec/datavec-excel/pom.xml index 00fc890d8..49dc26db8 100644 --- a/datavec/datavec-excel/pom.xml +++ b/datavec/datavec-excel/pom.xml @@ -51,6 +51,13 @@ poi-ooxml ${poi.version} + + + org.nd4j + nd4j-common-tests + ${nd4j.version} + test + diff --git a/datavec/datavec-excel/src/test/java/org/datavec/poi/excel/AssertTestsExtendBaseClass.java b/datavec/datavec-excel/src/test/java/org/datavec/poi/excel/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..1b61f7f6c --- /dev/null +++ b/datavec/datavec-excel/src/test/java/org/datavec/poi/excel/AssertTestsExtendBaseClass.java @@ -0,0 +1,50 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.datavec.poi.excel; + +import lombok.extern.slf4j.Slf4j; +import org.nd4j.AbstractAssertTestsClass; +import org.nd4j.BaseND4JTest; + +import java.util.*; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseND4jTest - either directly or indirectly. 
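All of these new enforcement classes extend org.nd4j.AbstractAssertTestsClass, which lives in the nd4j-common-tests artifact being added to each pom and is not itself part of this diff. Judging from the inline implementation removed from deeplearning4j-core further down, it presumably amounts to a Reflections scan over the configured package, roughly as follows (a sketch under that assumption; the BaseND4JTest supertype is also assumed):

    package org.nd4j;

    import java.lang.reflect.Method;
    import java.util.HashSet;
    import java.util.Set;
    import lombok.extern.slf4j.Slf4j;
    import org.junit.Test;
    import org.reflections.Reflections;
    import org.reflections.scanners.MethodAnnotationsScanner;
    import org.reflections.util.ClasspathHelper;
    import org.reflections.util.ConfigurationBuilder;
    import static org.junit.Assert.assertEquals;

    @Slf4j
    public abstract class AbstractAssertTestsClass extends BaseND4JTest { // supertype assumed

        protected abstract Set<Class<?>> getExclusions();
        protected abstract String getPackageName();
        protected abstract Class<?> getBaseClass();

        @Test
        public void checkTestClasses() {
            // Collect every class in the package that declares at least one @Test method
            Reflections reflections = new Reflections(new ConfigurationBuilder()
                    .setUrls(ClasspathHelper.forPackage(getPackageName()))
                    .setScanners(new MethodAnnotationsScanner()));
            Set<Class<?>> testClasses = new HashSet<>();
            for (Method m : reflections.getMethodsAnnotatedWith(Test.class))
                testClasses.add(m.getDeclaringClass());

            int count = 0;
            for (Class<?> c : testClasses) {
                if (!getBaseClass().isAssignableFrom(c) && !getExclusions().contains(c)) {
                    log.error("Test {} does not extend {} (directly or indirectly)", c, getBaseClass());
                    count++;
                }
            }
            assertEquals("Number of tests not extending the required base class", 0, count);
        }
    }
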
+ * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.datavec.poi.excel"; + } + + @Override + protected Class getBaseClass() { + return BaseND4JTest.class; + } +} diff --git a/datavec/datavec-jdbc/pom.xml b/datavec/datavec-jdbc/pom.xml index bfafd25d0..6ef9b0441 100644 --- a/datavec/datavec-jdbc/pom.xml +++ b/datavec/datavec-jdbc/pom.xml @@ -58,6 +58,13 @@ ${derby.version} test + + + org.nd4j + nd4j-common-tests + ${nd4j.version} + test + diff --git a/datavec/datavec-jdbc/src/test/java/org/datavec/api/records/reader/AssertTestsExtendBaseClass.java b/datavec/datavec-jdbc/src/test/java/org/datavec/api/records/reader/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..1db810b7b --- /dev/null +++ b/datavec/datavec-jdbc/src/test/java/org/datavec/api/records/reader/AssertTestsExtendBaseClass.java @@ -0,0 +1,49 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.datavec.api.records.reader; + +import lombok.extern.slf4j.Slf4j; +import org.nd4j.AbstractAssertTestsClass; +import org.nd4j.BaseND4JTest; +import java.util.*; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseND4jTest - either directly or indirectly. 
+ * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.datavec.api.records.reader"; + } + + @Override + protected Class getBaseClass() { + return BaseND4JTest.class; + } +} diff --git a/datavec/datavec-local/pom.xml b/datavec/datavec-local/pom.xml index 5c2c6f4ac..3adc0e011 100644 --- a/datavec/datavec-local/pom.xml +++ b/datavec/datavec-local/pom.xml @@ -81,6 +81,13 @@ test + + org.nd4j + nd4j-common-tests + ${nd4j.version} + test + + diff --git a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/AssertTestsExtendBaseClass.java b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..991b8466d --- /dev/null +++ b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/AssertTestsExtendBaseClass.java @@ -0,0 +1,50 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.datavec.local.transforms; + +import lombok.extern.slf4j.Slf4j; +import org.nd4j.AbstractAssertTestsClass; +import org.nd4j.BaseND4JTest; + +import java.util.*; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4jTest - either directly or indirectly. 
+ * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.datavec.local.transforms"; + } + + @Override + protected Class getBaseClass() { + return BaseND4JTest.class; + } +} diff --git a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/analysis/TestAnalyzeLocal.java b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/analysis/TestAnalyzeLocal.java index ba7048547..1a46789ad 100644 --- a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/analysis/TestAnalyzeLocal.java +++ b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/analysis/TestAnalyzeLocal.java @@ -28,6 +28,7 @@ import org.datavec.local.transforms.AnalyzeLocal; import org.junit.Rule; import org.junit.Test; import org.junit.rules.TemporaryFolder; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.io.ClassPathResource; @@ -63,7 +64,7 @@ public class TestAnalyzeLocal { list.add(rr.next()); } - INDArray arr = RecordConverter.toMatrix(list); + INDArray arr = RecordConverter.toMatrix(DataType.DOUBLE, list); INDArray mean = arr.mean(0); INDArray std = arr.std(0); diff --git a/datavec/datavec-python/pom.xml b/datavec/datavec-python/pom.xml index 55cf6c5da..526b8238a 100644 --- a/datavec/datavec-python/pom.xml +++ b/datavec/datavec-python/pom.xml @@ -64,6 +64,13 @@ nd4j-native-api ${project.version} + + + org.nd4j + nd4j-common-tests + ${nd4j.version} + test + diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/AssertTestsExtendBaseClass.java b/datavec/datavec-python/src/test/java/org/datavec/python/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..83aa2fe5a --- /dev/null +++ b/datavec/datavec-python/src/test/java/org/datavec/python/AssertTestsExtendBaseClass.java @@ -0,0 +1,50 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.datavec.python; + +import lombok.extern.slf4j.Slf4j; +import org.nd4j.AbstractAssertTestsClass; +import org.nd4j.BaseND4JTest; + +import java.util.*; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseND4jTest - either directly or indirectly. 
+ * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.datavec.python"; + } + + @Override + protected Class getBaseClass() { + return BaseND4JTest.class; + } +} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/pom.xml b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/pom.xml index 10cca8e5a..57e50d127 100644 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/pom.xml +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/pom.xml @@ -51,6 +51,13 @@ datavec-spark-inference-model ${project.parent.version} + + + org.nd4j + nd4j-common-tests + ${nd4j.version} + test + diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/test/java/org/datavec/transform/client/AssertTestsExtendBaseClass.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/test/java/org/datavec/transform/client/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..3bff86e98 --- /dev/null +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/test/java/org/datavec/transform/client/AssertTestsExtendBaseClass.java @@ -0,0 +1,49 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.datavec.transform.client; + +import lombok.extern.slf4j.Slf4j; +import org.nd4j.AbstractAssertTestsClass; +import org.nd4j.BaseND4JTest; +import java.util.*; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4jTest - either directly or indirectly. 
+ * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.datavec.transform.client"; + } + + @Override + protected Class getBaseClass() { + return BaseND4JTest.class; + } +} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/pom.xml b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/pom.xml index 470340dc1..bac20d42e 100644 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/pom.xml +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/pom.xml @@ -45,6 +45,13 @@ datavec-local ${project.version} + + + org.nd4j + nd4j-common-tests + ${nd4j.version} + test + diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/transform/CSVSparkTransform.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/transform/CSVSparkTransform.java index 67d7fe44a..f76e9885f 100644 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/transform/CSVSparkTransform.java +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/transform/CSVSparkTransform.java @@ -33,6 +33,7 @@ import org.datavec.spark.transform.model.Base64NDArrayBody; import org.datavec.spark.transform.model.BatchCSVRecord; import org.datavec.spark.transform.model.SequenceBatchCSVRecord; import org.datavec.spark.transform.model.SingleCSVRecord; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.serde.base64.Nd4jBase64; @@ -91,7 +92,7 @@ public class CSVSparkTransform { transformProcess.getInitialSchema(),record.getValues()), transformProcess.getInitialSchema()); List finalRecord = execute(Arrays.asList(record2),transformProcess).get(0); - INDArray convert = RecordConverter.toArray(finalRecord); + INDArray convert = RecordConverter.toArray(DataType.DOUBLE, finalRecord); return new Base64NDArrayBody(Nd4jBase64.base64String(convert)); } diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/AssertTestsExtendBaseClass.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..4c6f529b9 --- /dev/null +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/AssertTestsExtendBaseClass.java @@ -0,0 +1,50 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.datavec.spark.transform; + +import lombok.extern.slf4j.Slf4j; +import org.nd4j.AbstractAssertTestsClass; +import org.nd4j.BaseND4JTest; + +import java.util.*; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4jTest - either directly or indirectly. + * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.datavec.spark.transform"; + } + + @Override + protected Class getBaseClass() { + return BaseND4JTest.class; + } +} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml index 47951b1aa..0c05f327b 100644 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml @@ -164,6 +164,13 @@ spark-core_2.11 ${spark.version} + + + org.nd4j + nd4j-common-tests + ${nd4j.version} + test + diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/AssertTestsExtendBaseClass.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..4c6f529b9 --- /dev/null +++ b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/AssertTestsExtendBaseClass.java @@ -0,0 +1,50 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.datavec.spark.transform; + +import lombok.extern.slf4j.Slf4j; +import org.nd4j.AbstractAssertTestsClass; +import org.nd4j.BaseND4JTest; + +import java.util.*; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4jTest - either directly or indirectly. 
+ * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.datavec.spark.transform"; + } + + @Override + protected Class getBaseClass() { + return BaseND4JTest.class; + } +} diff --git a/datavec/datavec-spark/pom.xml b/datavec/datavec-spark/pom.xml index 50194c91e..345b774c3 100644 --- a/datavec/datavec-spark/pom.xml +++ b/datavec/datavec-spark/pom.xml @@ -130,6 +130,12 @@ test + + org.nd4j + nd4j-common-tests + ${nd4j.version} + test + diff --git a/datavec/datavec-spark/src/test/java/org/datavec/spark/AssertTestsExtendBaseClass.java b/datavec/datavec-spark/src/test/java/org/datavec/spark/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..9539251e6 --- /dev/null +++ b/datavec/datavec-spark/src/test/java/org/datavec/spark/AssertTestsExtendBaseClass.java @@ -0,0 +1,50 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.datavec.spark; + +import lombok.extern.slf4j.Slf4j; +import org.nd4j.AbstractAssertTestsClass; +import org.nd4j.BaseND4JTest; + +import java.util.*; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4jTest - either directly or indirectly. 
+ * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + */ + +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) + return new HashSet<>(); + } + + @Override + protected String getPackageName() { + return "org.datavec.spark"; + } + + @Override + protected Class getBaseClass() { + return BaseND4JTest.class; + } +} diff --git a/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/NormalizationTests.java b/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/NormalizationTests.java index 5352ec10d..fcc20d661 100644 --- a/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/NormalizationTests.java +++ b/datavec/datavec-spark/src/test/java/org/datavec/spark/transform/NormalizationTests.java @@ -104,7 +104,7 @@ public class NormalizationTests extends BaseSparkTest { } - INDArray arr = RecordConverter.toMatrix(data); + INDArray arr = RecordConverter.toMatrix(DataType.DOUBLE, data); Schema schema = builder.build(); JavaRDD> rdd = sc.parallelize(data); @@ -127,9 +127,9 @@ public class NormalizationTests extends BaseSparkTest { zeroToOne.transform(new DataSet(zeroToOnes, zeroToOnes)); INDArray zeroMeanUnitVarianceDataFrame = - RecordConverter.toMatrix(Normalization.zeromeanUnitVariance(schema, rdd).collect()); + RecordConverter.toMatrix(DataType.DOUBLE, Normalization.zeromeanUnitVariance(schema, rdd).collect()); INDArray zeroMeanUnitVarianceDataFrameZeroToOne = - RecordConverter.toMatrix(Normalization.normalize(schema, rdd).collect()); + RecordConverter.toMatrix(DataType.DOUBLE, Normalization.normalize(schema, rdd).collect()); assertEquals(standardScalered, zeroMeanUnitVarianceDataFrame); assertTrue(zeroToOnes.equalsWithEps(zeroMeanUnitVarianceDataFrameZeroToOne, 1e-1)); diff --git a/deeplearning4j/deeplearning4j-common-tests/pom.xml b/deeplearning4j/deeplearning4j-common-tests/pom.xml index 23e030df3..5a4ba921d 100644 --- a/deeplearning4j/deeplearning4j-common-tests/pom.xml +++ b/deeplearning4j/deeplearning4j-common-tests/pom.xml @@ -37,6 +37,11 @@ nd4j-api ${project.version} + + org.nd4j + nd4j-common-tests + ${nd4j.version} + ch.qos.logback logback-classic diff --git a/deeplearning4j/deeplearning4j-core/pom.xml b/deeplearning4j/deeplearning4j-core/pom.xml index 496bb6b1b..90c88d4c3 100644 --- a/deeplearning4j/deeplearning4j-core/pom.xml +++ b/deeplearning4j/deeplearning4j-core/pom.xml @@ -164,20 +164,6 @@ oshi-core ${oshi.version} - - - - org.reflections - reflections - ${reflections.version} - test - - - com.google.code.findbugs - * - - - diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/AssertTestsExtendBaseClass.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/AssertTestsExtendBaseClass.java index 5f0567094..34d4db39e 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/AssertTestsExtendBaseClass.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/AssertTestsExtendBaseClass.java @@ -17,15 +17,8 @@ package org.deeplearning4j; import lombok.extern.slf4j.Slf4j; import org.junit.Test; -import org.reflections.Reflections; -import org.reflections.scanners.MethodAnnotationsScanner; -import org.reflections.util.ClasspathHelper; -import org.reflections.util.ConfigurationBuilder; - -import java.lang.reflect.Method; import 
java.util.*; - -import static org.junit.Assert.assertEquals; +import org.nd4j.AbstractAssertTestsClass; /** * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) @@ -33,45 +26,24 @@ import static org.junit.Assert.assertEquals; * Other than a small set of exceptions, all tests must extend this * * @author Alex Black + * @author Alexander Stoyakin */ @Slf4j -public class AssertTestsExtendBaseClass extends BaseDL4JTest { +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { @Override - public long getTimeoutMilliseconds() { - return 240000L; + protected Set> getExclusions() { + Set> exclusions = new HashSet<>(); + return exclusions; } - //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) - private static final Set> exclusions = new HashSet<>(); + @Override + protected String getPackageName() { + return "org.deeplearning4j"; + } - @Test - public void checkTestClasses(){ - - Reflections reflections = new Reflections(new ConfigurationBuilder() - .setUrls(ClasspathHelper.forPackage("org.deeplearning4j")) - .setScanners(new MethodAnnotationsScanner())); - Set methods = reflections.getMethodsAnnotatedWith(Test.class); - Set> s = new HashSet<>(); - for(Method m : methods){ - s.add(m.getDeclaringClass()); - } - - List> l = new ArrayList<>(s); - Collections.sort(l, new Comparator>() { - @Override - public int compare(Class aClass, Class t1) { - return aClass.getName().compareTo(t1.getName()); - } - }); - - int count = 0; - for(Class c : l){ - if(!BaseDL4JTest.class.isAssignableFrom(c) && !exclusions.contains(c)){ - log.error("Test {} does not extend BaseDL4JTest (directly or indirectly). All tests must extend this class for proper memory tracking and timeouts", c); - count++; - } - } - assertEquals("Number of tests not extending BaseDL4JTest", 0, count); + @Override + protected Class getBaseClass() { + return BaseDL4JTest.class; } } diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/CompareTrainingImplementations.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/CompareTrainingImplementations.java index cf75700f8..9bcb97b7d 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/CompareTrainingImplementations.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/samediff/CompareTrainingImplementations.java @@ -96,7 +96,7 @@ public class CompareTrainingImplementations extends BaseDL4JTest { SDVariable z1 = a0.mmul(w1).add("prediction", b1); SDVariable a1 = sd.nn().softmax("softmax", z1); - SDVariable diff = sd.f().squaredDifference(a1, label); + SDVariable diff = sd.math().squaredDifference(a1, label); SDVariable lossMse = diff.mean(); lossMse.markAsLoss(); diff --git a/deeplearning4j/deeplearning4j-data/deeplearning4j-datavec-iterators/src/main/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIterator.java b/deeplearning4j/deeplearning4j-data/deeplearning4j-datavec-iterators/src/main/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIterator.java index 2f47c2c8b..92d5d579e 100644 --- a/deeplearning4j/deeplearning4j-data/deeplearning4j-datavec-iterators/src/main/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIterator.java +++ b/deeplearning4j/deeplearning4j-data/deeplearning4j-datavec-iterators/src/main/java/org/deeplearning4j/datasets/datavec/RecordReaderMultiDataSetIterator.java @@ -494,7 +494,7 
@@ public class RecordReaderMultiDataSetIterator implements MultiDataSetIterator, S List c = list.get(i); if (details.entireReader) { //Convert entire reader contents, without modification - INDArray converted = RecordConverter.toArray(c); + INDArray converted = RecordConverter.toArray(Nd4j.defaultFloatingPointType(), c); putExample(arr, converted, i); } else if (details.oneHot) { //Convert a single column to a one-hot representation diff --git a/deeplearning4j/deeplearning4j-graph/pom.xml b/deeplearning4j/deeplearning4j-graph/pom.xml index ebc6740d9..645b4eca2 100644 --- a/deeplearning4j/deeplearning4j-graph/pom.xml +++ b/deeplearning4j/deeplearning4j-graph/pom.xml @@ -57,7 +57,6 @@ ${project.version} test - diff --git a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/AssertTestsExtendedBaseClass.java b/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/AssertTestsExtendedBaseClass.java new file mode 100644 index 000000000..7341a3a2c --- /dev/null +++ b/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/AssertTestsExtendedBaseClass.java @@ -0,0 +1,49 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.graph; + +import lombok.extern.slf4j.Slf4j; +import java.util.*; + +import org.deeplearning4j.BaseDL4JTest; +import org.nd4j.AbstractAssertTestsClass; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4JTest - either directly or indirectly. + * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + * @author Alexander Stoyakin + */ +@Slf4j +public class AssertTestsExtendedBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + Set> exclusions = new HashSet<>(); + return exclusions; + } + + @Override + protected String getPackageName() { + return "org.deeplearning4j.graph"; + } + + @Override + protected Class getBaseClass() {return BaseDL4JTest.class; } +} + diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/AssertTestsExtendBaseClass.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..d7c03956f --- /dev/null +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-chinese/src/test/java/org/deeplearning4j/text/tokenization/tokenizer/AssertTestsExtendBaseClass.java @@ -0,0 +1,52 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. 
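The RecordReaderMultiDataSetIterator hunk above now converts entire-reader content using the globally configured floating-point type rather than a fixed one. What that means at runtime, sketched with the standard Nd4j global-type setter (the example class name is hypothetical):

    import java.util.Arrays;
    import java.util.List;
    import org.datavec.api.util.ndarray.RecordConverter;
    import org.datavec.api.writable.DoubleWritable;
    import org.datavec.api.writable.Writable;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class DefaultFpTypeExample {
        public static void main(String[] args) {
            // The iterator follows whatever the global default is (FLOAT out of the box):
            Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE);
            List<Writable> record = Arrays.<Writable>asList(new DoubleWritable(1.0));
            INDArray arr = RecordConverter.toArray(Nd4j.defaultFloatingPointType(), record);
            System.out.println(arr.dataType()); // DOUBLE after the setter above
        }
    }
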
+ * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.text.tokenization.tokenizer; + +import lombok.extern.slf4j.Slf4j; +import java.util.*; + +import org.deeplearning4j.BaseDL4JTest; +import org.nd4j.AbstractAssertTestsClass; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4JTest - either directly or indirectly. + * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + * @author Alexander Stoyakin + */ +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + Set> exclusions = new HashSet<>(); + return exclusions; + } + + @Override + protected String getPackageName() { + return "org.deeplearning4j"; + } + + @Override + protected Class getBaseClass() { + return BaseDL4JTest.class; + } +} + + diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/AssertTestsExtendBaseClass.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..c767c3e72 --- /dev/null +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-japanese/src/test/java/AssertTestsExtendBaseClass.java @@ -0,0 +1,53 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +import com.atilika.kuromoji.TestUtils; +import com.atilika.kuromoji.ipadic.RandomizedInputTest; +import lombok.extern.slf4j.Slf4j; +import java.util.*; + +import org.deeplearning4j.BaseDL4JTest; +import org.nd4j.AbstractAssertTestsClass; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4JTest - either directly or indirectly. 
+ * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + * @author Alexander Stoyakin + */ +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + Set> exclusions = new HashSet<>(); + exclusions.add(TestUtils.class); + exclusions.add(RandomizedInputTest.class); + return exclusions; + } + + @Override + protected String getPackageName() { + return "org.deeplearning4j"; + } + + @Override + protected Class getBaseClass() { + return BaseDL4JTest.class; + } +} diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/src/test/java/AssertTestsExtendBaseClass.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/src/test/java/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..ccf95a8ea --- /dev/null +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-korean/src/test/java/AssertTestsExtendBaseClass.java @@ -0,0 +1,49 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +import lombok.extern.slf4j.Slf4j; +import java.util.*; + +import org.deeplearning4j.BaseDL4JTest; +import org.nd4j.AbstractAssertTestsClass; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4JTest - either directly or indirectly. + * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + * @author Alexander Stoyakin + */ +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + Set> exclusions = new HashSet<>(); + return exclusions; + } + + @Override + protected String getPackageName() { + return "org.deeplearning4j"; + } + + @Override + protected Class getBaseClass() { + return BaseDL4JTest.class; + } +} diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/AssertTestsExtendBaseClass.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..85b0c39a9 --- /dev/null +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/AssertTestsExtendBaseClass.java @@ -0,0 +1,49 @@ + +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j; + +import lombok.extern.slf4j.Slf4j; +import java.util.*; +import org.nd4j.AbstractAssertTestsClass; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4JTest - either directly or indirectly. + * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + * @author Alexander Stoyakin + */ +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + Set> exclusions = new HashSet<>(); + return exclusions; + } + + @Override + protected String getPackageName() { + return "org.deeplearning4j"; + } + + @Override + protected Class getBaseClass() { + return BaseDL4JTest.class; + } +} diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml index f27fd7a94..668c728ae 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/pom.xml @@ -14,76 +14,77 @@ ~ SPDX-License-Identifier: Apache-2.0 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~--> - - 4.0.0 - - org.deeplearning4j - deeplearning4j-nlp-parent - 1.0.0-SNAPSHOT - + + 4.0.0 + + org.deeplearning4j + deeplearning4j-nlp-parent + 1.0.0-SNAPSHOT + - deeplearning4j-nlp + deeplearning4j-nlp - - - org.nd4j - nd4j-native-api - ${nd4j.version} - + + + org.nd4j + nd4j-native-api + ${nd4j.version} + - - commons-lang - commons-lang - 2.6 - - - org.deeplearning4j - deeplearning4j-core - ${project.version} - + + commons-lang + commons-lang + 2.6 + + + org.deeplearning4j + deeplearning4j-core + ${project.version} + - - org.threadly - threadly - ${threadly.version} - + + org.threadly + threadly + ${threadly.version} + - - junit - junit - test - + + junit + junit + test + - - org.mockito - mockito-core - ${mockito.version} - test - + + org.mockito + mockito-core + ${mockito.version} + test + - - ch.qos.logback - logback-classic - test - - - org.apache.commons - commons-lang3 - ${commonslang.version} - - - com.github.vinhkhuc - jfasttext - 0.4 - + + ch.qos.logback + logback-classic + test + + + org.apache.commons + commons-lang3 + ${commonslang.version} + + + com.github.vinhkhuc + jfasttext + 0.4 + - - org.deeplearning4j - deeplearning4j-common-tests - ${project.version} - test - - + + org.deeplearning4j + deeplearning4j-common-tests + ${project.version} + test + + diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/AssertTestsExtendBaseClass.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..6fb3b0316 --- /dev/null +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/AssertTestsExtendBaseClass.java @@ -0,0 +1,49 @@ +/* 
****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j; + +import lombok.extern.slf4j.Slf4j; +import java.util.*; +import org.nd4j.AbstractAssertTestsClass; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4JTest - either directly or indirectly. + * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + * @author Alexander Stoyakin + */ +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set> getExclusions() { + Set> exclusions = new HashSet<>(); + return exclusions; + } + + @Override + protected String getPackageName() { + return "org.deeplearning4j"; + } + + @Override + protected Class getBaseClass() { + return BaseDL4JTest.class; + } +} + diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java index 60ecbf057..8fedee7b0 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected1D.java @@ -31,6 +31,7 @@ import org.nd4j.autodiff.samediff.SDIndex; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; +import org.nd4j.enums.PadMode; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; @@ -179,10 +180,10 @@ public class LocallyConnected1D extends SameDiffLayer { //NCW format. 
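The pad migration that starts in this LocallyConnected1D hunk (and repeats for LocallyConnected2D below) reflects a SameDiff API change: nn().pad now takes an explicit PadMode enum ahead of the constant value, with PadMode.CONSTANT reproducing the old behaviour. A call-site sketch (shapes illustrative only; class name hypothetical):

    import java.util.Arrays;
    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.enums.PadMode;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.factory.Nd4j;

    public class PadModeExample {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            SDVariable in = sd.constant("in", Nd4j.zeros(DataType.FLOAT, 1, 3, 8)); // NCW
            // Before: sd.nn().pad(in, paddings, 0.0)
            // After: the mode is explicit; CONSTANT matches the old behaviour:
            SDVariable padded = sd.nn().pad(in,
                    sd.constant(Nd4j.createFromArray(new int[][]{{0, 0}, {0, 0}, {1, 1}})),
                    PadMode.CONSTANT, 0.0);
            System.out.println(Arrays.toString(padded.eval().shape())); // [1, 3, 10]
        }
    }
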
if(cm == ConvolutionMode.Same) { layerInput = sameDiff.nn().pad(layerInput, - sameDiff.constant(Nd4j.createFromArray(new int[][]{{0, 0}, {0, 0}, {padding, paddingR}})), 0); + sameDiff.constant(Nd4j.createFromArray(new int[][]{{0, 0}, {0, 0}, {padding, paddingR}})), PadMode.CONSTANT, 0); } else { layerInput = sameDiff.nn().pad(layerInput, - sameDiff.constant(Nd4j.createFromArray(new int[][]{{0, 0}, {0, 0}, {padding, padding}})), 0); + sameDiff.constant(Nd4j.createFromArray(new int[][]{{0, 0}, {0, 0}, {padding, padding}})), PadMode.CONSTANT, 0); } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java index 5044017a0..6fad9ec69 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/LocallyConnected2D.java @@ -32,6 +32,7 @@ import org.nd4j.autodiff.samediff.SDIndex; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; +import org.nd4j.enums.PadMode; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.ndarray.INDArray; @@ -185,10 +186,10 @@ public class LocallyConnected2D extends SameDiffLayer { //NCHW format if(cm == ConvolutionMode.Same){ layerInput = sameDiff.nn().pad(layerInput, - sameDiff.constant(Nd4j.createFromArray(new int[][]{{0,0},{0,0},{padding[0], paddingBr[0]}, {padding[1], paddingBr[1]}})), 0.0); + sameDiff.constant(Nd4j.createFromArray(new int[][]{{0,0},{0,0},{padding[0], paddingBr[0]}, {padding[1], paddingBr[1]}})), PadMode.CONSTANT, 0.0); } else { layerInput = sameDiff.nn().pad(layerInput, - sameDiff.constant(Nd4j.createFromArray(new int[][]{{0,0},{0,0},{padding[0], padding[0]}, {padding[1], padding[1]}})), 0.0); + sameDiff.constant(Nd4j.createFromArray(new int[][]{{0,0},{0,0},{padding[0], padding[0]}, {padding[1], padding[1]}})), PadMode.CONSTANT, 0.0); } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffGraphVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffGraphVertex.java index 712265e05..bcca695df 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffGraphVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffGraphVertex.java @@ -34,6 +34,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.array.SingleThreadArrayHolder; import org.nd4j.autodiff.samediff.internal.InferenceSession; import org.nd4j.autodiff.samediff.internal.SessionMemMgr; +import org.nd4j.autodiff.util.SameDiffUtils; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; @@ -295,7 +296,7 @@ public class SameDiffGraphVertex extends BaseGraphVertex { } //Define the function for external errors: - fn = sameDiff.f().externalErrors(layerOutput); + fn = SameDiffUtils.externalErrors(sameDiff, null, layerOutput); fn.outputVariable(); this.outputKey = outputVar.name(); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java 
b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java index 64c2ea25e..ed355fdaf 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/samediff/SameDiffLayer.java @@ -29,6 +29,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.array.SingleThreadArrayHolder; import org.nd4j.autodiff.samediff.internal.InferenceSession; import org.nd4j.autodiff.samediff.internal.SessionMemMgr; +import org.nd4j.autodiff.util.SameDiffUtils; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; @@ -321,7 +322,7 @@ public class SameDiffLayer extends AbstractLayer { } //Define the function for external errors: - fn = sameDiff.f().externalErrors(layerOutput); + fn = SameDiffUtils.externalErrors(sameDiff, null,layerOutput); fn.outputVariable(); this.outputKey = outputVar.name(); diff --git a/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/AssertTestsExtendBaseClass.java b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/AssertTestsExtendBaseClass.java new file mode 100644 index 000000000..1d8c3d578 --- /dev/null +++ b/deeplearning4j/deeplearning4j-remote/deeplearning4j-json-server/src/test/java/org/deeplearning4j/remote/AssertTestsExtendBaseClass.java @@ -0,0 +1,48 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.remote; + +import lombok.extern.slf4j.Slf4j; +import java.util.*; + +import org.deeplearning4j.BaseDL4JTest; +import org.nd4j.AbstractAssertTestsClass; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4JTest - either directly or indirectly. 
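With sameDiff.f() on its way out, the external-errors hook used by SameDiffGraphVertex and SameDiffLayer above is built through the SameDiffUtils helper instead; the null argument is the optional map of externally supplied gradients. A sketch of the migrated call (method and class names are hypothetical; the return type is assumed from the fn field it replaces):

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.autodiff.util.SameDiffUtils;
    import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;

    public class ExternalErrorsMigration {
        static ExternalErrorsFunction attachHook(SameDiff sd, SDVariable layerOutput) {
            // Before: fn = sd.f().externalErrors(layerOutput);
            ExternalErrorsFunction fn = SameDiffUtils.externalErrors(sd, null, layerOutput);
            fn.outputVariable(); // materialise the op's output, as the layer code does
            return fn;
        }
    }
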
+ * Other than a small set of exceptions, all tests must extend this class. + * + * @author Alexander Stoyakin + */ +@Slf4j +public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { + + @Override + protected Set<Class<?>> getExclusions() { + Set<Class<?>> exclusions = new HashSet<>(); + return exclusions; + } + + @Override + protected String getPackageName() { + return "org.deeplearning4j.remote"; + } + + @Override + protected Class<?> getBaseClass() { return BaseDL4JTest.class; } +} + diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java index 8a629bc66..e21b2d270 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java @@ -475,6 +475,11 @@ public abstract class DifferentialFunction { return outputVariables()[0]; } + public List<SDVariable> outputs(){ + SDVariable[] out = outputVariables(); + return out == null ? null : Arrays.asList(out); + } + public String[] outputVariablesNames(){ SDVariable[] outputVars = outputVariables(); @@ -502,14 +507,6 @@ public abstract class DifferentialFunction { */ public abstract List<SDVariable> doDiff(List<SDVariable> f1); - /** - * Shortcut for the {@link DifferentialFunctionFactory} - * @return - */ - public DifferentialFunctionFactory f() { - return sameDiff.f(); - } - /** - * Return the arguments for a given function @@ -576,7 +573,7 @@ public abstract class DifferentialFunction { copied = true; } - SDVariable gradVar = f().add(grad, vals.get(i)); + SDVariable gradVar = var.getSameDiff().math.add(grad, vals.get(i)); vals.set(i, gradVar); sameDiff.setGradientForVariableName(var.name(), gradVar); } else { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java deleted file mode 100644 index 093e3099b..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java +++ /dev/null @@ -1,2659 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License.
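
Migration note for callers of the removed factory: both API changes in this patch follow the same pattern, replacing instance-level helpers with explicit entry points (an enum argument for pad, a static utility for external errors). The sketch below shows the two new call forms side by side. It is illustrative only: the class name, variable names, and shapes are invented here, and the ExternalErrorsFunction return type of SameDiffUtils.externalErrors is assumed to match the removed factory method it replaces, as the call sites above suggest.

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.autodiff.util.SameDiffUtils;
    import org.nd4j.enums.PadMode;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
    import org.nd4j.linalg.factory.Nd4j;

    public class ApiMigrationSketch {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            // Hypothetical NCW input: batch 1, 3 channels, width 8
            SDVariable in = sd.placeHolder("in", DataType.FLOAT, 1, 3, 8);

            // pad: the mode is now passed explicitly instead of being implied.
            // Old form: sd.nn().pad(in, sd.constant(...), 0.0)
            SDVariable padded = sd.nn().pad(in,
                    sd.constant(Nd4j.createFromArray(new int[][]{{0, 0}, {0, 0}, {1, 1}})),
                    PadMode.CONSTANT, 0.0);

            // external errors: the helper moves from sameDiff.f() to SameDiffUtils.
            // Old form: fn = sd.f().externalErrors(padded)
            ExternalErrorsFunction fn = SameDiffUtils.externalErrors(sd, null, padded);
            fn.outputVariable();
        }
    }

Passing null for the gradients map mirrors the removed one-argument overload, which simply delegated to externalErrors(null, inputs).
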
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.autodiff.functions; - -import java.lang.reflect.Method; -import java.util.Arrays; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import lombok.Data; -import lombok.NonNull; -import lombok.val; -import org.apache.commons.lang3.ArrayUtils; -import org.nd4j.autodiff.loss.LossReduce; -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.enums.DataFormat; -import org.nd4j.base.Preconditions; -import org.nd4j.linalg.api.blas.params.MMulTranspose; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.NoOp; -import org.nd4j.linalg.api.ops.custom.*; -import org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd; -import org.nd4j.linalg.api.ops.impl.broadcast.BiasAddGrad; -import org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter; -import org.nd4j.linalg.api.ops.impl.controlflow.compat.Exit; -import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge; -import org.nd4j.linalg.api.ops.impl.controlflow.compat.NextIteration; -import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch; -import org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches; -import org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex; -import org.nd4j.linalg.api.ops.impl.indexaccum.IAMax; -import org.nd4j.linalg.api.ops.impl.indexaccum.IAMin; -import org.nd4j.linalg.api.ops.impl.indexaccum.IMax; -import org.nd4j.linalg.api.ops.impl.indexaccum.IMin; -import org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex; -import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction; -import org.nd4j.linalg.api.ops.impl.layers.convolution.*; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv3DConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv2DConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.LocalResponseNormalizationConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling2DConfig; -import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Pooling3DConfig; -import org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss; -import org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss; -import org.nd4j.linalg.api.ops.impl.loss.HingeLoss; -import org.nd4j.linalg.api.ops.impl.loss.HuberLoss; -import org.nd4j.linalg.api.ops.impl.loss.L2Loss; -import org.nd4j.linalg.api.ops.impl.loss.LogLoss; -import org.nd4j.linalg.api.ops.impl.loss.LogPoissonLoss; -import org.nd4j.linalg.api.ops.impl.loss.MeanPairwiseSquaredErrorLoss; -import org.nd4j.linalg.api.ops.impl.loss.MeanSquaredErrorLoss; -import org.nd4j.linalg.api.ops.impl.loss.SigmoidCrossEntropyLoss; -import org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyLoss; -import org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyWithLogitsLoss; -import org.nd4j.linalg.api.ops.impl.loss.SparseSoftmaxCrossEntropyLossWithLogits; -import org.nd4j.linalg.api.ops.impl.loss.WeightedCrossEntropyLoss; -import org.nd4j.linalg.api.ops.impl.loss.bp.AbsoluteDifferenceLossBp; -import org.nd4j.linalg.api.ops.impl.loss.bp.CosineDistanceLossBp; -import org.nd4j.linalg.api.ops.impl.loss.bp.HingeLossBp; -import 
org.nd4j.linalg.api.ops.impl.loss.bp.HuberLossBp; -import org.nd4j.linalg.api.ops.impl.loss.bp.LogLossBp; -import org.nd4j.linalg.api.ops.impl.loss.bp.LogPoissonLossBp; -import org.nd4j.linalg.api.ops.impl.loss.bp.MeanPairwiseSquaredErrorLossBp; -import org.nd4j.linalg.api.ops.impl.loss.bp.MeanSquaredErrorLossBp; -import org.nd4j.linalg.api.ops.impl.loss.bp.SigmoidCrossEntropyLossBp; -import org.nd4j.linalg.api.ops.impl.loss.bp.SoftmaxCrossEntropyLossBp; -import org.nd4j.linalg.api.ops.impl.loss.bp.SoftmaxCrossEntropyWithLogitsLossBp; -import org.nd4j.linalg.api.ops.impl.loss.bp.SparseSoftmaxCrossEntropyLossWithLogitsBp; -import org.nd4j.linalg.api.ops.impl.reduce.Mmul; -import org.nd4j.linalg.api.ops.impl.reduce.MmulBp; -import org.nd4j.linalg.api.ops.impl.reduce.Moments; -import org.nd4j.linalg.api.ops.impl.reduce.NormalizeMoments; -import org.nd4j.linalg.api.ops.impl.reduce.TensorMmul; -import org.nd4j.linalg.api.ops.impl.reduce.ZeroFraction; -import org.nd4j.linalg.api.ops.impl.reduce.bool.All; -import org.nd4j.linalg.api.ops.impl.reduce.bool.Any; -import org.nd4j.linalg.api.ops.impl.reduce.bp.*; -import org.nd4j.linalg.api.ops.impl.reduce.custom.BatchMmul; -import org.nd4j.linalg.api.ops.impl.reduce.custom.LogSumExp; -import org.nd4j.linalg.api.ops.impl.reduce.floating.AMean; -import org.nd4j.linalg.api.ops.impl.reduce.floating.Entropy; -import org.nd4j.linalg.api.ops.impl.reduce.floating.LogEntropy; -import org.nd4j.linalg.api.ops.impl.reduce.floating.Mean; -import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm1; -import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2; -import org.nd4j.linalg.api.ops.impl.reduce.floating.NormMax; -import org.nd4j.linalg.api.ops.impl.reduce.floating.ShannonEntropy; -import org.nd4j.linalg.api.ops.impl.reduce.floating.SquaredNorm; -import org.nd4j.linalg.api.ops.impl.reduce.longer.CountNonZero; -import org.nd4j.linalg.api.ops.impl.reduce.longer.CountZero; -import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition; -import org.nd4j.linalg.api.ops.impl.reduce.same.AMax; -import org.nd4j.linalg.api.ops.impl.reduce.same.AMin; -import org.nd4j.linalg.api.ops.impl.reduce.same.ASum; -import org.nd4j.linalg.api.ops.impl.reduce.same.Max; -import org.nd4j.linalg.api.ops.impl.reduce.same.Min; -import org.nd4j.linalg.api.ops.impl.reduce.same.Prod; -import org.nd4j.linalg.api.ops.impl.reduce.same.Sum; -import org.nd4j.linalg.api.ops.impl.reduce3.CosineDistance; -import org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity; -import org.nd4j.linalg.api.ops.impl.reduce3.Dot; -import org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance; -import org.nd4j.linalg.api.ops.impl.reduce3.HammingDistance; -import org.nd4j.linalg.api.ops.impl.reduce3.JaccardDistance; -import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance; -import org.nd4j.linalg.api.ops.impl.scalar.*; -import org.nd4j.linalg.api.ops.impl.scalar.Pow; -import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarEquals; -import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThan; -import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThanOrEqual; -import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThan; -import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThanOrEqual; -import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarNotEquals; -import org.nd4j.linalg.api.ops.impl.scatter.ScatterAdd; -import org.nd4j.linalg.api.ops.impl.scatter.ScatterDiv; -import org.nd4j.linalg.api.ops.impl.scatter.ScatterMax; -import 
org.nd4j.linalg.api.ops.impl.scatter.ScatterMin; -import org.nd4j.linalg.api.ops.impl.scatter.ScatterMul; -import org.nd4j.linalg.api.ops.impl.scatter.ScatterSub; -import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate; -import org.nd4j.linalg.api.ops.impl.shape.*; -import org.nd4j.linalg.api.ops.impl.shape.bp.SliceBp; -import org.nd4j.linalg.api.ops.impl.shape.bp.StridedSliceBp; -import org.nd4j.linalg.api.ops.impl.shape.bp.TileBp; -import org.nd4j.linalg.api.ops.impl.summarystats.StandardDeviation; -import org.nd4j.linalg.api.ops.impl.summarystats.Variance; -import org.nd4j.linalg.api.ops.impl.transforms.Pad; -import org.nd4j.linalg.api.ops.impl.transforms.ReluLayer; -import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax; -import org.nd4j.linalg.api.ops.impl.transforms.bool.IsFinite; -import org.nd4j.linalg.api.ops.impl.transforms.bool.IsInf; -import org.nd4j.linalg.api.ops.impl.transforms.bool.IsNaN; -import org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform; -import org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByNorm; -import org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue; -import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndReplace; -import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet; -import org.nd4j.linalg.api.ops.impl.transforms.custom.*; -import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMax; -import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMean; -import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMin; -import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentProd; -import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentSum; -import org.nd4j.linalg.api.ops.impl.transforms.dtype.Cast; -import org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt; -import org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.DynamicPartitionBp; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhBp; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.LogSoftMaxDerivative; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.PReluBp; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhBp; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhDerivative; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhBp; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhDerivative; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.Relu6Derivative; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.SELUDerivative; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative; -import 
org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.ThresholdReluBp; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.*; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.*; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.And; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Xor; -import org.nd4j.linalg.api.ops.impl.transforms.same.Abs; -import org.nd4j.linalg.api.ops.impl.transforms.same.Ceil; -import org.nd4j.linalg.api.ops.impl.transforms.same.Cube; -import org.nd4j.linalg.api.ops.impl.transforms.same.Floor; -import org.nd4j.linalg.api.ops.impl.transforms.same.Identity; -import org.nd4j.linalg.api.ops.impl.transforms.same.Negative; -import org.nd4j.linalg.api.ops.impl.transforms.same.Reciprocal; -import org.nd4j.linalg.api.ops.impl.transforms.same.Round; -import org.nd4j.linalg.api.ops.impl.transforms.same.Sign; -import org.nd4j.linalg.api.ops.impl.transforms.same.Square; -import org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMax; -import org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMean; -import org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentMin; -import org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentProd; -import org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSqrtN; -import org.nd4j.linalg.api.ops.impl.transforms.segment.UnsortedSegmentSum; -import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentMaxBp; -import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentMeanBp; -import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentMinBp; -import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentProdBp; -import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentSumBp; -import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentMaxBp; -import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentMeanBp; -import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentMinBp; -import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentProdBp; -import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentSqrtNBp; -import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentSumBp; -import org.nd4j.linalg.api.ops.impl.transforms.strict.*; -import org.nd4j.linalg.api.ops.random.custom.*; -import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; -import org.nd4j.linalg.api.ops.random.impl.BinomialDistribution; -import org.nd4j.linalg.api.ops.random.impl.DropOutInverted; -import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution; -import org.nd4j.linalg.api.ops.random.impl.LogNormalDistribution; -import org.nd4j.linalg.api.ops.random.impl.Range; -import org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution; -import org.nd4j.linalg.api.ops.random.impl.UniformDistribution; -import org.nd4j.linalg.api.shape.Shape; -import org.nd4j.linalg.indexing.conditions.Condition; -import org.nd4j.linalg.util.ArrayUtil; - -/** - * - */ -@Data -public class DifferentialFunctionFactory { - - protected SameDiff sameDiff; - private static Map<String, Method> methodNames; - - /** - * @param sameDiff - */ - public DifferentialFunctionFactory(SameDiff sameDiff) { - if (sameDiff != null) { - this.sameDiff = sameDiff; - if (methodNames == null) { - methodNames = new HashMap<>(); - Method[] methods =
getClass().getDeclaredMethods(); - for (Method method : methods) - methodNames.put(method.getName().toLowerCase(), method); - } - } else { - throw new IllegalArgumentException("Input not null value."); - } - - - } - - public SameDiff sameDiff() { - return sameDiff; - } - - - public SDVariable invoke(String name, Object[] args) { - try { - return (SDVariable) methodNames.get(name).invoke(this, args); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - public ExternalErrorsFunction externalErrors(SDVariable... inputs) { - return externalErrors(null, inputs); - } - - public ExternalErrorsFunction externalErrors(Map<String, INDArray> externalGradients, SDVariable... inputs) { - Preconditions.checkArgument(inputs != null && inputs.length > 0, "Require at least one SDVariable to" + - " be specified when using external errors: got %s", inputs); - ExternalErrorsFunction fn = new ExternalErrorsFunction(sameDiff(), Arrays.asList(inputs), externalGradients); - fn.outputVariable(); - return fn; - } - - public SDVariable zerosLike(SDVariable input) { - return zerosLike(null, input); - } - - public SDVariable zerosLike(String name, SDVariable input) { - validateDifferentialFunctionsameDiff(input); - return new ZerosLike(name, sameDiff(), input).outputVariable(); - } - - public SDVariable zerosLike(String name, SDVariable input, DataType dataType) { - validateDifferentialFunctionsameDiff(input); - return new ZerosLike(name, sameDiff(), input, dataType).outputVariable(); - } - - public SDVariable create(String name, SDVariable shape, boolean initialize, DataType dataType) { - return create(name, shape, 'c', initialize, dataType); - } - - public SDVariable create(String name, SDVariable shape, char order, boolean initialize, DataType dataType) { - validateDifferentialFunctionsameDiff(shape); - return new Create(name, sameDiff(), shape, order, initialize, dataType).outputVariable(); - } - - public SDVariable onesLike(String name, SDVariable input, DataType dataType) { - validateDifferentialFunctionsameDiff(input); - return new OnesLike(name, sameDiff(), input, dataType).outputVariable(); - } - - public SDVariable linspace(SDVariable lower, SDVariable upper, SDVariable count, DataType dt) { - return new org.nd4j.linalg.api.ops.impl.shape.Linspace(sameDiff(), lower, upper, count, dt).outputVariable(); - } - - public SDVariable range(double from, double to, double step, DataType dataType) { - return new Range(sameDiff(), from, to, step, dataType).outputVariable(); - } - - public SDVariable range(SDVariable from, SDVariable to, SDVariable step, DataType dataType) { - return new Range(sameDiff(), from, to, step, dataType).outputVariable(); - } - - public SDVariable[] listdiff(SDVariable x, SDVariable y){ - return new ListDiff(sameDiff(), x, y).outputVariables(); - } - - public SDVariable cast(SDVariable toCast, DataType toType){ - return new Cast(sameDiff(), toCast, toType).outputVariable(); - } - - public SDVariable[] meshgrid(boolean cartesian, SDVariable... inputs) { - return new MeshGrid(sameDiff(), cartesian, inputs).outputVariables(); - } - - public SDVariable randomUniform(double min, double max, SDVariable shape, DataType dataType) { - return new DistributionUniform(sameDiff(), shape, min, max, dataType).outputVariable(); - } - - public SDVariable randomUniform(double min, double max, long...
shape) { - return new UniformDistribution(sameDiff(), min, max, shape).outputVariable(); - } - - public SDVariable randomNormal(double mean, double std, SDVariable shape) { - return new RandomNormal(sameDiff(), shape, mean, std).outputVariable(); - } - - public SDVariable randomNormal(double mean, double std, long... shape) { - return new GaussianDistribution(sameDiff(), mean, std, shape).outputVariable(); - } - - public SDVariable randomBernoulli(double p, SDVariable shape) { - return new RandomBernoulli(sameDiff(), shape, p).outputVariable(); - } - - public SDVariable randomBernoulli(double p, long... shape) { - return new BernoulliDistribution(sameDiff(), p, shape).outputVariable(); - } - - public SDVariable randomBinomial(int nTrials, double p, long... shape) { - return new BinomialDistribution(sameDiff(), nTrials, p, shape).outputVariable(); - } - - public SDVariable randomLogNormal(double mean, double stdev, long... shape) { - return new LogNormalDistribution(sameDiff(), mean, stdev, shape).outputVariable(); - } - - public SDVariable randomNormalTruncated(double mean, double stdev, long... shape) { - return new TruncatedNormalDistribution(sameDiff(), mean, stdev, shape).outputVariable(); - } - - public SDVariable randomGamma(SDVariable shape, SDVariable alpha, SDVariable beta, int... seeds) { - return new RandomGamma(sameDiff(), shape, alpha, beta, seeds).outputVariable(); - } - - public SDVariable randomPoisson(SDVariable shape, SDVariable rate, int... seeds) { - return new RandomPoisson(sameDiff(), shape, rate, seeds).outputVariable(); - } - - public SDVariable randomShuffle(SDVariable values, int... seeds) { - return new RandomShuffle(sameDiff(), values, seeds).outputVariable(); - } - - /** - * Exponential distribution: P(x) = lambda * exp(-lambda * x) - * - * @param lambda Must be > 0 - * @param shape Shape of the output - */ - public SDVariable randomExponential(double lambda, SDVariable shape) { - return new RandomExponential(sameDiff(), shape, lambda).outputVariable(); - } - - - public SDVariable pad(SDVariable input, SDVariable padding, Pad.Mode mode, double padValue){ - return new Pad(sameDiff(), input, padding, mode, padValue).outputVariable(); - } - - /** - * Local response normalization operation. - * - * @param input the inputs to lrn - * @param lrnConfig the configuration - * @return - */ - public SDVariable localResponseNormalization(SDVariable input, LocalResponseNormalizationConfig lrnConfig) { - LocalResponseNormalization lrn = LocalResponseNormalization.sameDiffBuilder() - .inputFunctions(new SDVariable[]{input}) - .sameDiff(sameDiff()) - .config(lrnConfig) - .build(); - - return lrn.outputVariable(); - } - - /** - * Conv1d operation. - * - * @param input the inputs to conv1d - * @param weights conv1d weights - * @param conv1DConfig the configuration - * @return - */ - public SDVariable conv1d(SDVariable input, SDVariable weights, Conv1DConfig conv1DConfig) { - Conv1D conv1D = Conv1D.sameDiffBuilder() - .inputFunctions(new SDVariable[]{input, weights}) - .sameDiff(sameDiff()) - .config(conv1DConfig) - .build(); - - return conv1D.outputVariable(); - } - - /** - * Conv1d operation. 
- * - * @param input the inputs to conv1d - * @param weights conv1d weights - * @param bias conv1d bias - * @param conv1DConfig the configuration - * @return - */ - public SDVariable conv1d(SDVariable input, SDVariable weights, SDVariable bias, Conv1DConfig conv1DConfig) { - - SDVariable[] args; - - if(bias == null){ - args = new SDVariable[]{input, weights}; - } else { - args = new SDVariable[]{input, weights, bias}; - } - - Conv1D conv1D = Conv1D.sameDiffBuilder() - .inputFunctions(args) - .sameDiff(sameDiff()) - .config(conv1DConfig) - .build(); - - return conv1D.outputVariable(); - } - - /** - * Conv2d operation. - * - * @param inputs the inputs to conv2d - * @param conv2DConfig the configuration - * @return - */ - public SDVariable conv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) { - Conv2D conv2D = Conv2D.sameDiffBuilder() - .inputFunctions(inputs) - .sameDiff(sameDiff()) - .config(conv2DConfig) - .build(); - - return conv2D.outputVariable(); - } - - public SDVariable upsampling2d(SDVariable input, boolean nchw, int scaleH, int scaleW) { - return new Upsampling2d(sameDiff(), input, nchw, scaleH, scaleW).outputVariable(); - } - - public SDVariable upsampling2dBp(SDVariable input, SDVariable gradient, boolean nchw, int scaleH, int scaleW) { - return new Upsampling2dDerivative(sameDiff(), input, gradient, nchw, scaleH, scaleW).outputVariable(); - } - - - /** - * Average pooling 2d operation. - * - * @param input the inputs to pooling - * @param pooling2DConfig the configuration - * @return - */ - public SDVariable avgPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) { - AvgPooling2D avgPooling2D = AvgPooling2D.sameDiffBuilder() - .input(input) - .sameDiff(sameDiff()) - .config(pooling2DConfig) - .build(); - - return avgPooling2D.outputVariable(); - } - - /** - * Max pooling 2d operation. - * - * @param input the inputs to pooling - * @param pooling2DConfig the configuration - * @return - */ - public SDVariable maxPooling2d(SDVariable input, Pooling2DConfig pooling2DConfig) { - MaxPooling2D maxPooling2D = MaxPooling2D.sameDiffBuilder() - .input(input) - .sameDiff(sameDiff()) - .config(pooling2DConfig) - .build(); - - return maxPooling2D.outputVariable(); - } - - /** - * Avg pooling 3d operation. - * - * @param input the inputs to pooling - * @param pooling3DConfig the configuration - * @return - */ - public SDVariable avgPooling3d(SDVariable input, Pooling3DConfig pooling3DConfig) { - pooling3DConfig.setType(Pooling3D.Pooling3DType.AVG); - return new AvgPooling3D(sameDiff(), input, pooling3DConfig).outputVariable(); - } - - - /** - * Max pooling 3d operation. - * - * @param input the inputs to pooling - * @param pooling3DConfig the configuration - * @return - */ - public SDVariable maxPooling3d(SDVariable input, Pooling3DConfig pooling3DConfig) { - pooling3DConfig.setType(Pooling3D.Pooling3DType.MAX); - return new MaxPooling3D(sameDiff(), input, pooling3DConfig).outputVariable(); - } - - - /** - * Separable Conv2d operation. - * - * @param inputs the inputs to conv2d - * @param conv2DConfig the configuration - * @return - */ - public SDVariable sconv2d(SDVariable[] inputs, Conv2DConfig conv2DConfig) { - SConv2D sconv2D = SConv2D.sameDiffSBuilder() - .inputFunctions(inputs) - .sameDiff(sameDiff()) - .conv2DConfig(conv2DConfig) - .build(); - - return sconv2D.outputVariable(); - } - - - /** - * Depth-wise Conv2d operation. This is just separable convolution with - * only the depth-wise weights specified. 
- * - * @param inputs the inputs to conv2d - * @param depthConv2DConfig the configuration - * @return - */ - public SDVariable depthWiseConv2d(SDVariable[] inputs, Conv2DConfig depthConv2DConfig) { - SConv2D depthWiseConv2D = SConv2D.sameDiffSBuilder() - .inputFunctions(inputs) - .sameDiff(sameDiff()) - .conv2DConfig(depthConv2DConfig) - .build(); - - return depthWiseConv2D.outputVariable(); - } - - - /** - * Deconv2d operation. - * - * @param inputs the inputs to conv2d - * @param deconv2DConfig the configuration - * @return - */ - public SDVariable deconv2d(SDVariable[] inputs, DeConv2DConfig deconv2DConfig) { - DeConv2D deconv2D = DeConv2D.sameDiffBuilder() - .inputs(inputs) - .sameDiff(sameDiff()) - .config(deconv2DConfig) - .build(); - - return deconv2D.outputVariable(); - } - - public SDVariable deconv3d(SDVariable input, SDVariable weights, SDVariable bias, DeConv3DConfig config) { - DeConv3D d = new DeConv3D(sameDiff(), input, weights, bias, config); - return d.outputVariable(); - } - - public SDVariable[] deconv3dDerivative(SDVariable input, SDVariable weights, SDVariable bias, SDVariable grad, DeConv3DConfig config) { - DeConv3DDerivative d = new DeConv3DDerivative(sameDiff(), input, weights, bias, grad, config); - return d.outputVariables(); - } - - /** - * Conv3d operation. - * - * @param inputs the inputs to conv3d - * @param conv3DConfig the configuration - * @return - */ - public SDVariable conv3d(SDVariable[] inputs, Conv3DConfig conv3DConfig) { - Conv3D conv3D = Conv3D.sameDiffBuilder() - .inputFunctions(inputs) - .config(conv3DConfig) - .sameDiff(sameDiff()) - .build(); - - val outputVars = conv3D.outputVariables(); - return outputVars[0]; - } - - - /** - * Batch norm operation. - */ - public SDVariable batchNorm(SDVariable input, SDVariable mean, - SDVariable variance, SDVariable gamma, - SDVariable beta, - boolean applyGamma, boolean applyBeta, - double epsilon, int... axis) { - BatchNorm batchNorm = BatchNorm.builder() - .inputFunctions(new SDVariable[]{input, mean, variance, gamma, beta}) - .applyGamma(applyGamma) - .applyBeta(applyBeta) - .epsilon(epsilon) - .sameDiff(sameDiff()) - .axis(axis) - .build(); - - val outputVars = batchNorm.outputVariables(); - return outputVars[0]; - } - - public SDVariable im2Col(SDVariable input, Conv2DConfig config) { - return new Im2col(sameDiff(), input, config).outputVariable(); - } - - public SDVariable im2ColBp(SDVariable im2colInput, SDVariable gradientAtOutput, Conv2DConfig config) { - return new Im2colBp(sameDiff(), im2colInput, gradientAtOutput, config).outputVariable(); - } - - public SDVariable col2Im(SDVariable input, Conv2DConfig config) { - return new Col2Im(sameDiff(), input, config).outputVariable(); - } - - public SDVariable extractImagePatches(SDVariable input, int kH, int kW, int sH, int sW, int rH, int rW, boolean sameMode){ - return new ExtractImagePatches(sameDiff(), input, new int[]{kH, kW}, new int[]{sH, sW}, new int[]{rH, rW}, sameMode).outputVariable(); - } - - public SDVariable[] moments(SDVariable input, int... 
axes) { - return new Moments(sameDiff(), input, axes).outputVariables(); - } - - public SDVariable[] normalizeMoments(SDVariable counts, SDVariable means, SDVariable variances, double shift) { - return new NormalizeMoments(sameDiff(), counts, means, variances, shift).outputVariables(); - } - - - public SDVariable tile(@NonNull SDVariable iX, @NonNull int[] repeat) { - return new Tile(sameDiff(), iX, repeat).outputVariable(); - } - - public SDVariable tileBp(@NonNull SDVariable in, @NonNull SDVariable grad, @NonNull int[] repeat){ - return new TileBp(sameDiff, in, grad, repeat).outputVariable(); - } - - public SDVariable tile(@NonNull SDVariable iX, @NonNull SDVariable repeat) { - return new Tile(sameDiff(), iX, repeat).outputVariable(); - } - - public SDVariable tileBp(@NonNull SDVariable in, @NonNull SDVariable repeat, @NonNull SDVariable grad){ - return new TileBp(sameDiff, in, repeat, grad).outputVariable(); - } - - public SDVariable dropout(SDVariable input, double p) { - return new DropOutInverted(sameDiff(), input, p).outputVariable(); - } - - - public SDVariable sum(SDVariable i_x, boolean keepDims, int... dimensions) { - return new Sum(sameDiff(), i_x, keepDims, dimensions).outputVariable(); - } - - public SDVariable sumBp(SDVariable i_x, SDVariable grad, boolean keepDims, int... dimensions) { - return new SumBp(sameDiff(), i_x, grad, keepDims, dimensions).outputVariable(); - } - - - public SDVariable prod(SDVariable i_x, boolean keepDims, int... dimensions) { - return new Prod(sameDiff(), i_x, keepDims, dimensions).outputVariable(); - } - - public SDVariable prodBp(SDVariable preReduceInput, SDVariable grad, boolean keepDims, int... dimensions) { - return new ProdBp(sameDiff(), preReduceInput, grad, keepDims, dimensions).outputVariable(); - } - - public SDVariable mean(SDVariable in, boolean keepDims, int... dimensions) { - return new Mean(sameDiff(), in, keepDims, dimensions).outputVariable(); - } - - public SDVariable meanBp(SDVariable in, SDVariable grad, boolean keepDims, int... dimensions) { - return new MeanBp(sameDiff(), in, grad, keepDims, dimensions).outputVariable(); - } - - - public SDVariable std(SDVariable i_x, boolean biasCorrected, boolean keepDims, int... dimensions) { - return new StandardDeviation(sameDiff(), i_x, biasCorrected, keepDims, dimensions).outputVariable(); - } - - public SDVariable stdBp(SDVariable stdInput, SDVariable gradient, boolean biasCorrected, boolean keepDims, int... dimensions) { - return new StandardDeviationBp(sameDiff(), stdInput, gradient, biasCorrected, keepDims, dimensions).outputVariable(); - } - - - public SDVariable variance(SDVariable i_x, boolean biasCorrected, boolean keepDims, int... dimensions) { - return new Variance(sameDiff(), i_x, biasCorrected, keepDims, dimensions).outputVariable(); - } - - public SDVariable varianceBp(SDVariable stdInput, SDVariable gradient, boolean biasCorrected, boolean keepDims, int... dimensions) { - return new VarianceBp(sameDiff(), stdInput, gradient, biasCorrected, keepDims, dimensions).outputVariable(); - } - - public SDVariable standardize(SDVariable i_x, int... dimensions) { - return new Standardize(sameDiff(), i_x, dimensions).outputVariable(); - } - - public SDVariable standardizeBp(SDVariable stdInput, SDVariable gradient, int... dimensions) { - return new StandardizeBp(sameDiff(), stdInput, gradient, dimensions).outputVariable(); - } - - public SDVariable layerNorm(SDVariable input, SDVariable gain, SDVariable bias, boolean channelsFirst, int... 
dimensions) { - return new LayerNorm(sameDiff(), input, gain, bias, channelsFirst, dimensions).outputVariable(); - } - - public SDVariable[] layerNormBp(SDVariable input, SDVariable gain, SDVariable bias, SDVariable gradient, boolean channelsFirst, int... dimensions) { - return new LayerNormBp(sameDiff(), input, gain, bias, gradient, channelsFirst, dimensions).outputVariables(); - } - - public SDVariable layerNorm(SDVariable input, SDVariable gain, boolean channelsFirst, int... dimensions) { - return new LayerNorm(sameDiff(), input, gain, channelsFirst, dimensions).outputVariable(); - } - - public SDVariable[] layerNormBp(SDVariable input, SDVariable gain, SDVariable gradient, boolean channelsFirst, int... dimensions) { - return new LayerNormBp(sameDiff(), input, gain, gradient, channelsFirst, dimensions).outputVariables(); - } - - public SDVariable squaredNorm(SDVariable input, boolean keepDims, int... dimensions) { - return new SquaredNorm(sameDiff(), input, keepDims, dimensions).outputVariable(); - } - - public SDVariable squaredNormBp(SDVariable preReduceInput, SDVariable gradient, boolean keepDims, int... dimensions) { - return new SquaredNormBp(sameDiff(), preReduceInput, gradient, keepDims, dimensions).outputVariable(); - } - - public SDVariable entropy(SDVariable in, int... dimensions) { - return new Entropy(sameDiff(), in, dimensions).outputVariable(); - } - - public SDVariable logEntropy(SDVariable in, int... dimensions) { - return new LogEntropy(sameDiff(), in, dimensions).outputVariable(); - } - - public SDVariable shannonEntropy(SDVariable in, int... dimensions){ - return new ShannonEntropy(sameDiff(), in, dimensions).outputVariable(); - } - - public SDVariable countNonZero(SDVariable input, int... dimensions) { - return new CountNonZero(sameDiff(), input, dimensions).outputVariable(); - } - - public SDVariable countZero(SDVariable input, int... dimensions) { - return new CountZero(sameDiff(), input, dimensions).outputVariable(); - } - - public SDVariable zeroFraction(SDVariable input) { - return new ZeroFraction(sameDiff(), input).outputVariable(); - } - - public SDVariable scalarMax(SDVariable in, Number num) { - return new ScalarMax(sameDiff(), in, num).outputVariable(); - } - - public SDVariable scalarMin(SDVariable in, Number num) { - return new ScalarMin(sameDiff(), in, num).outputVariable(); - } - - public SDVariable scalarSet(SDVariable in, Number num) { - return new ScalarSet(sameDiff(), in, num).outputVariable(); - } - - public SDVariable scalarFloorMod(SDVariable in, Number num) { - return new ScalarFMod(sameDiff(), in, num).outputVariable(); - } - - public SDVariable max(SDVariable i_x, boolean keepDims, int... dimensions) { - return new Max(sameDiff(), i_x, keepDims, dimensions).outputVariable(); - } - - public SDVariable max(SDVariable first, SDVariable second) { - return new org.nd4j.linalg.api.ops.impl.transforms.custom.Max(sameDiff(), first, second) - .outputVariable(); - } - - public SDVariable maxBp(SDVariable i_x, SDVariable grad, boolean keepDims, int... dimensions) { - return new MaxBp(sameDiff(), i_x, grad, keepDims, dimensions).outputVariable(); - } - - - public SDVariable min(SDVariable i_x, boolean keepDims, int... dimensions) { - return new Min(sameDiff(), i_x, keepDims, dimensions).outputVariable(); - } - - public SDVariable minBp(SDVariable i_x, SDVariable grad, boolean keepDims, int... 
dimensions) { - return new MinBp(sameDiff(), i_x, grad, keepDims, dimensions).outputVariable(); - } - - public SDVariable min(SDVariable first, SDVariable second) { - return new org.nd4j.linalg.api.ops.impl.transforms.custom.Min(sameDiff(), first, second) - .outputVariable(); - } - - public SDVariable amax(SDVariable in, int... dimensions) { - return new AMax(sameDiff(), in, dimensions).outputVariable(); - } - - public SDVariable amin(SDVariable in, int... dimensions) { - return new AMin(sameDiff(), in, dimensions).outputVariable(); - } - - public SDVariable amean(SDVariable in, int... dimensions) { - return new AMean(sameDiff(), in, dimensions).outputVariable(); - } - - public SDVariable asum(SDVariable in, int... dimensions) { - return new ASum(sameDiff(), in, dimensions).outputVariable(); - } - - public SDVariable argmax(SDVariable in, boolean keepDims, int... dimensions) { - return new IMax(sameDiff(), in, keepDims, dimensions).outputVariable(); - } - - public SDVariable argmin(SDVariable in, boolean keepDims, int... dimensions) { - return new IMin(sameDiff(), in, keepDims, dimensions).outputVariable(); - } - - public SDVariable iamax(SDVariable in, boolean keepDims, int... dimensions) { - return new IAMax(sameDiff(), in, keepDims, dimensions).outputVariable(); - } - - public SDVariable iamin(SDVariable in, boolean keepDims, int... dimensions) { - return new IAMin(sameDiff(), in, keepDims, dimensions).outputVariable(); - } - - public SDVariable firstIndex(SDVariable in, Condition condition, boolean keepDims, int... dimensions) { - return new FirstIndex(sameDiff(), in, condition, keepDims, dimensions).outputVariable(); - } - - public SDVariable lastIndex(SDVariable in, Condition condition, boolean keepDims, int... dimensions) { - return new LastIndex(sameDiff(), in, condition, keepDims, dimensions).outputVariable(); - } - - /** - * Returns a count of the number of elements that satisfy the condition - * - * @param in Input - * @param condition Condition - * @return Number of elements that the condition is satisfied for - */ - public SDVariable matchConditionCount(SDVariable in, Condition condition, boolean keepDims, int... dimensions) { - return new MatchCondition(sameDiff(), in, condition, keepDims, dimensions).outputVariable(); - } - - /** - * Returns a boolean mask of equal shape to the input, where the condition is satisfied - * - * @param in Input - * @param condition Condition - * @return Boolean mask - */ - public SDVariable matchCondition(SDVariable in, Condition condition) { - return new MatchConditionTransform(sameDiff(), in, condition).outputVariable(); - } - - public SDVariable cumsum(SDVariable in, boolean exclusive, boolean reverse, int... axis) { - return new CumSum(sameDiff(), in, exclusive, reverse, axis).outputVariable(); - } - - public SDVariable cumsumBp(SDVariable in, SDVariable grad, boolean exclusive, boolean reverse, int... axis) { - return new CumSumBp(sameDiff(), in, grad, exclusive, reverse, axis).outputVariable(); - } - - public SDVariable cumprod(SDVariable in, boolean exclusive, boolean reverse, int... axis) { - return new CumProd(sameDiff(), in, exclusive, reverse, axis).outputVariable(); - } - - public SDVariable cumprodBp(SDVariable in, SDVariable grad, boolean exclusive, boolean reverse, int... 
axis) { - return new CumProdBp(sameDiff(), in, grad, exclusive, reverse, axis).outputVariable(); - } - - public SDVariable biasAdd(SDVariable input, SDVariable bias, boolean nchw) { - return new BiasAdd(sameDiff(), input, bias, nchw).outputVariable(); - } - - public SDVariable[] biasAddBp(SDVariable input, SDVariable bias, SDVariable grad, boolean nchw) { - return new BiasAddGrad(sameDiff(), input, bias, grad, nchw).outputVariables(); - } - - public SDVariable norm1(SDVariable i_x, boolean keepDims, int... dimensions) { - return new Norm1(sameDiff(), i_x, keepDims, dimensions).outputVariable(); - } - - public SDVariable norm1Bp(SDVariable preReduceIn, SDVariable grad, boolean keepDims, int... dimensions) { - return new Norm1Bp(sameDiff(), preReduceIn, grad, keepDims, dimensions).outputVariable(); - } - - public SDVariable norm2(SDVariable i_x, boolean keepDims, int... dimensions) { - return new Norm2(sameDiff(), i_x, keepDims, dimensions).outputVariable(); - } - - public SDVariable norm2Bp(SDVariable preReduceIn, SDVariable grad, boolean keepDims, int... dimensions) { - return new Norm2Bp(sameDiff(), preReduceIn, grad, keepDims, dimensions).outputVariable(); - } - - public SDVariable normmax(SDVariable i_x, boolean keepDims, int... dimensions) { - return new NormMax(sameDiff(), i_x, keepDims, dimensions).outputVariable(); - } - - public SDVariable normmaxBp(SDVariable preReduceIn, SDVariable grad, boolean keepDims, int... dimensions) { - return new NormMaxBp(sameDiff(), preReduceIn, grad, keepDims, dimensions).outputVariable(); - } - - public SDVariable reductionShape(SDVariable shape, SDVariable axis, boolean keepDim){ - return new ReductionShape(sameDiff(), shape, axis, keepDim).outputVariable(); - } - - /** - * Add 1s as required to the array make an array possible to be broadcast with the original (pre-reduce) array. - *

- * Example: if doing [a,b,c].sum(1), result is [a,c]. To 'undo' this in a way that can be auto-broadcast, - * we want to expand as required - i.e., [a,c] -> [a,1,c] which can be auto-broadcast with the original [a,b,c]. - * This is typically only used with reduction operations backprop. - * - * @param origRank Rank of the original array, before the reduction was executed - * @param reduceDims Dimensions that the original array was reduced from - * @param toExpand Array to add 1s to the shape to (such that it can be - * @return Reshaped array. - */ - public SDVariable reductionBroadcastableWithOrigShape(int origRank, int[] reduceDims, SDVariable toExpand) { - if (Shape.isWholeArray(origRank, reduceDims)) { - //Output is [1,1] which is already broadcastable - return toExpand; - } else if (origRank == 2 && reduceDims.length == 1) { - //In this case: [a,b] -> [1,b] or [a,b] -> [a,1] - //both are already broadcastable - return toExpand; - } else { - //Example: [a,b,c].sum(1) -> [a,c]... want [a,1,c] - for (int d : reduceDims) { - toExpand = sameDiff().expandDims(toExpand, d); - } - return toExpand; - } - } - - public SDVariable reductionBroadcastableWithOrigShape(SDVariable origInput, SDVariable axis, SDVariable toExpand) { - SDVariable shape = origInput.shape(); - SDVariable reduceShape = reductionShape(shape, axis, true); - SDVariable reshaped = toExpand.reshape(reduceShape); - return reshaped; - } - - - public SDVariable gradientBackwardsMarker(SDVariable iX) { - return new GradientBackwardsMarker(sameDiff(), iX, sameDiff.scalar(iX.name() + "-pairgrad", 1.0)).outputVariable(); - } - - public SDVariable abs(SDVariable iX) { - return new Abs(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable neg(SDVariable iX) { - return new Negative(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable cos(SDVariable iX) { - return new Cos(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable sin(SDVariable iX) { - return new Sin(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable tan(SDVariable iX) { - return new Tan(sameDiff(), iX, false).outputVariable(); - - } - - - public SDVariable permute(SDVariable iX, int... dimensions) { - return new Permute(sameDiff(), iX, dimensions).outputVariable(); - } - - public SDVariable permute(SDVariable in, SDVariable dimensions) { - return new Permute(sameDiff(), in, dimensions).outputVariable(); - } - - public SDVariable noop(SDVariable input) { - return new NoOp(sameDiff(), input).outputVariable(); - } - - public SDVariable identity(SDVariable input) { - return new Identity(sameDiff(), input).outputVariable(); - } - - public SDVariable all(SDVariable input, int... dimensions) { - return new All(sameDiff(), input, dimensions).outputVariable(); - } - - public SDVariable any(SDVariable input, int... 
dimensions) { - return new Any(sameDiff(), input, dimensions).outputVariable(); - } - - public SDVariable invertPermutation(SDVariable input, boolean inPlace) { - return new InvertPermutation(sameDiff(), input, inPlace).outputVariable(); - } - - public SDVariable transpose(SDVariable iX) { - return new Transpose(sameDiff(), iX).outputVariable(); - } - - - public SDVariable acos(SDVariable iX) { - return new ACos(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable asin(SDVariable iX) { - return new ASin(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable atan(SDVariable iX) { - return new ATan(sameDiff(), iX, false).outputVariable(); - - } - - public SDVariable atan2(SDVariable y, SDVariable x) { - return new ATan2(sameDiff(), y, x).outputVariable(); - } - - - public SDVariable cosh(SDVariable iX) { - return new Cosh(sameDiff(), iX, false).outputVariable(); - - } - - - public SDVariable sinh(SDVariable iX) { - return new Sinh(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable tanh(SDVariable iX) { - return new Tanh(sameDiff(), iX, false).outputVariable(); - } - - public SDVariable tanhRational(SDVariable in) { - return new RationalTanh(sameDiff(), in, false).outputVariable(); - } - - public SDVariable tanhRectified(SDVariable in) { - return new RectifiedTanh(sameDiff(), in, false).outputVariable(); - } - - public SDVariable tanhDerivative(SDVariable iX, SDVariable wrt) { - return new org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative(sameDiff(), iX, wrt).outputVariable(); - } - - public SDVariable tanhRationalBp(SDVariable in, SDVariable epsilon) { - return new RationalTanhBp(sameDiff(), in, epsilon).outputVariable(); - } - - public SDVariable tanhRectifiedBp(SDVariable in, SDVariable epsilon) { - return new RectifiedTanhBp(sameDiff(), in, epsilon).outputVariable(); - } - - /** - * Use {@link #tanhRationalBp(SDVariable, SDVariable)} - */ - @Deprecated - public SDVariable tanhRationalDerivative(SDVariable in) { - return new RationalTanhDerivative(sameDiff(), in, false).outputVariable(); - } - - /** - * Use {@link #tanhRectifiedBp(SDVariable, SDVariable)} - */ - @Deprecated - public SDVariable tanhRectifiedDerivative(SDVariable in) { - return new RectifiedTanhDerivative(sameDiff(), in, false).outputVariable(); - } - - public SDVariable step(SDVariable in, double cutoff) { - return new Step(sameDiff(), in, false, cutoff).outputVariable(); - } - - - public SDVariable acosh(SDVariable iX) { - return new ACosh(sameDiff(), iX).outputVariable(); - } - - - public SDVariable asinh(SDVariable iX) { - return new ASinh(sameDiff(), iX).outputVariable(); - } - - - public SDVariable atanh(SDVariable iX) { - return new ATanh(sameDiff(), iX).outputVariable(); - } - - - public SDVariable exp(SDVariable iX) { - return new Exp(sameDiff(), iX, false).outputVariable(); - } - - public SDVariable expm1(SDVariable iX) { - return new Expm1(sameDiff(), iX, false).outputVariable(); - } - - public SDVariable rsqrt(SDVariable iX) { - return new RSqrt(sameDiff(), iX, false).outputVariable(); - } - - public SDVariable log(SDVariable iX) { - return new Log(sameDiff(), iX, false).outputVariable(); - } - - public SDVariable log(SDVariable in, double base) { - return new LogX(sameDiff(), in, base).outputVariable(); - } - - public SDVariable log1p(SDVariable iX) { - return new Log1p(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable isFinite(SDVariable ix) { - return new IsFinite(sameDiff(), ix, false).outputVariable(); - } - - public SDVariable 
isInfinite(SDVariable ix) { - return new IsInf(sameDiff(), ix, false).outputVariable(); - } - - public SDVariable isNaN(SDVariable ix) { - return new IsNaN(sameDiff(), ix, false).outputVariable(); - } - - public SDVariable isMax(SDVariable ix) { - return new IsMax(sameDiff(), ix).outputVariable(); - } - - public SDVariable replaceWhere(SDVariable to, SDVariable from, Condition condition) { - return new CompareAndReplace(sameDiff(), to, from, condition).outputVariable(); - } - - public SDVariable replaceWhere(SDVariable to, Number set, Condition condition) { - return new CompareAndSet(sameDiff(), to, set, condition).outputVariable(); - } - - public SDVariable round(SDVariable ix) { - return new Round(sameDiff(), ix, false).outputVariable(); - } - - public SDVariable or(SDVariable iX, SDVariable i_y) { - return new Or(sameDiff(), iX, i_y).outputVariable(); - } - - public SDVariable and(SDVariable ix, SDVariable iy) { - return new And(sameDiff(), ix, iy).outputVariable(); - } - - public SDVariable xor(SDVariable ix, SDVariable iy) { - return new Xor(sameDiff(), ix, iy).outputVariable(); - } - - public SDVariable shift(SDVariable ix, SDVariable shift) { - return new ShiftBits(sameDiff(), ix, shift).outputVariable(); - } - - public SDVariable rshift(SDVariable ix, SDVariable shift) { - return new RShiftBits(sameDiff(), ix, shift).outputVariable(); - } - - public SDVariable rotl(SDVariable ix, SDVariable shift) { - return new CyclicShiftBits(sameDiff(), ix, shift).outputVariable(); - } - - public SDVariable rotr(SDVariable ix, SDVariable shift) { - return new CyclicRShiftBits(sameDiff(), ix, shift).outputVariable(); - } - - public SDVariable bitwiseHammingDist(SDVariable x, SDVariable y) { - return new BitsHammingDistance(sameDiff(), x, y).outputVariable(); - } - - public SDVariable bitwiseAnd(SDVariable x, SDVariable y){ - return new BitwiseAnd(sameDiff(), x, y).outputVariable(); - } - - public SDVariable bitwiseOr(SDVariable x, SDVariable y){ - return new BitwiseOr(sameDiff(), x, y).outputVariable(); - } - - public SDVariable bitwiseXor(SDVariable x, SDVariable y){ - return new BitwiseXor(sameDiff(), x, y).outputVariable(); - } - - public SDVariable eq(SDVariable iX, SDVariable i_y) { - return new EqualTo(sameDiff(), new SDVariable[]{iX, i_y}, false).outputVariable(); - } - - - public SDVariable neq(SDVariable iX, double i_y) { - return new ScalarNotEquals(sameDiff(), iX, i_y).outputVariable(); - } - - - public SDVariable neqi(SDVariable iX, double i_y) { - return new ScalarNotEquals(sameDiff(), iX, i_y, true).outputVariable(); - } - - - public SDVariable neqi(SDVariable iX, SDVariable i_y) { - return new NotEqualTo(sameDiff(), new SDVariable[]{iX, i_y}, true).outputVariable(); - } - - public SDVariable neq(SDVariable iX, SDVariable i_y) { - return new NotEqualTo(sameDiff(), new SDVariable[]{iX, i_y}, false).outputVariable(); - } - - public SDVariable pow(SDVariable iX, double i_y) { - return new Pow(sameDiff(), iX, false, i_y).outputVariable(); - } - - public SDVariable pow(SDVariable x, SDVariable y){ - return new org.nd4j.linalg.api.ops.impl.transforms.custom.Pow(sameDiff(), x, y).outputVariable(); - } - - public SDVariable sqrt(SDVariable iX) { - return new Sqrt(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable square(SDVariable iX) { - return new Square(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable cube(SDVariable iX) { - return new Cube(sameDiff(), iX, false).outputVariable(); - } - - public SDVariable cubeBp(SDVariable in, SDVariable epsilon) { 
- return new CubeBp(sameDiff(), in, epsilon).outputVariable(); - } - - /** - * @deprecated Use {@link #cubeBp(SDVariable, SDVariable)} - */ - @Deprecated - public SDVariable cubeDerivative(SDVariable iX) { - return new CubeDerivative(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable floor(SDVariable iX) { - return new Floor(sameDiff(), iX, false).outputVariable(); - } - - public SDVariable floorDiv(SDVariable x, SDVariable y) { - return new FloorDivOp(sameDiff(), x, y).outputVariable(); - } - - public List<SDVariable> floorDivBp(SDVariable x, SDVariable y, SDVariable grad) { - return Arrays.asList(new FloorDivBpOp(sameDiff(), x, y, grad).outputVariables()); - } - - public SDVariable floorMod(SDVariable x, SDVariable y) { - return new FloorModOp(sameDiff(), x, y).outputVariable(); - } - - public List<SDVariable> floorModBp(SDVariable x, SDVariable y, SDVariable grad) { - return Arrays.asList(new FloorModBpOp(sameDiff(), x, y, grad).outputVariables()); - } - - public SDVariable ceil(SDVariable x) { - return new Ceil(sameDiff(), x).outputVariable(); - } - - public SDVariable clipByValue(SDVariable x, double clipValueMin, double clipValueMax) { - return new ClipByValue(sameDiff(), x, clipValueMin, clipValueMax).outputVariable(); - } - - public SDVariable clipByNorm(SDVariable x, double clipValue) { - return new ClipByNorm(sameDiff(), x, clipValue).outputVariable(); - } - - public SDVariable clipByNorm(SDVariable x, double clipValue, int... dimensions) { - return new ClipByNorm(sameDiff(), x, clipValue, dimensions).outputVariable(); - } - - public SDVariable relu(SDVariable iX, double cutoff) { - return new RectifiedLinear(sameDiff(), iX, false, cutoff).outputVariable(); - } - - public SDVariable reluDerivative(SDVariable input, SDVariable grad){ - return new RectifiedLinearDerivative(sameDiff(), input, grad).outputVariable(); - } - - public SDVariable thresholdRelu(SDVariable in, SDVariable epsilon, double cutoff){ - return new ThresholdRelu(sameDiff(), in, cutoff).outputVariable(); - } - - public SDVariable thresholdReluBp(SDVariable in, SDVariable epsilon, double cutoff){ - return new ThresholdReluBp(sameDiff(), in, epsilon, cutoff).outputVariable(); - } - - public SDVariable relu6(SDVariable iX, double cutoff) { - return new Relu6(sameDiff(), iX, false, cutoff).outputVariable(); - } - - public SDVariable relu6Derivative(SDVariable iX, SDVariable wrt, double cutoff) { - return new Relu6Derivative(sameDiff(), iX, wrt, cutoff).outputVariable(); - } - - public SDVariable softmax(SDVariable iX) { - return new SoftMax(sameDiff(), new SDVariable[]{iX}).outputVariable(); - } - - public SDVariable softmax(SDVariable iX, int dimension) { - return new SoftMax(sameDiff(), new SDVariable[]{iX}, dimension).outputVariable(); - } - - - public SDVariable hardTanh(SDVariable iX) { - return new HardTanh(sameDiff(), iX, false).outputVariable(); - } - - public SDVariable hardTanhBp(SDVariable in, SDVariable epsilon) { - return new HardTanhBp(sameDiff(), in, epsilon).outputVariable(); - } - - /** - * @deprecated Use {@link #hardTanhBp(SDVariable, SDVariable)} - */ - @Deprecated - public SDVariable hardTanhDerivative(SDVariable iX) { - return new HardTanhDerivative(sameDiff(), iX, false).outputVariable(); - } - - public SDVariable hardSigmoid(SDVariable in) { - return new HardSigmoid(sameDiff(), in, false).outputVariable(); - } - - public SDVariable hardSigmoidBp(SDVariable in, SDVariable epsilon){ - return new HardSigmoidBp(sameDiff(), in, epsilon).outputVariable(); - } - - public SDVariable sigmoid(SDVariable iX) {
- return new Sigmoid(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable sigmoidDerivative(SDVariable iX, SDVariable wrt) { - return new SigmoidDerivative(sameDiff(), iX, wrt).outputVariable(); - } - - - public SDVariable logSigmoid(SDVariable iX) { - return new LogSigmoid(sameDiff(), iX, false).outputVariable(); - } - - public SDVariable powDerivative(SDVariable iX, double pow) { - return new PowDerivative(sameDiff(), iX, false, pow).outputVariable(); - } - - public SDVariable[] powBp(SDVariable x, SDVariable pow, SDVariable gradient) { - return new PowBp(sameDiff(), x, pow, gradient).outputVariables(); - } - - public SDVariable mishDerivative(SDVariable iX) { - return new MishDerivative(sameDiff(), iX, false).outputVariable(); - } - - public SDVariable swish(SDVariable iX) { - return new Swish(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable swishDerivative(SDVariable iX) { - return new SwishDerivative(sameDiff(), iX, false).outputVariable(); - } - - public SDVariable gelu(SDVariable iX, boolean precise) { - if (precise) - return new PreciseGELU(sameDiff(), iX, false, precise).outputVariable(); - else - return new GELU(sameDiff(), iX, false, precise).outputVariable(); - } - - public SDVariable geluDerivative(SDVariable iX, boolean precise) { - if (precise) - return new PreciseGELUDerivative(sameDiff(), iX, false, precise).outputVariable(); - else - return new GELUDerivative(sameDiff(), iX, false).outputVariable(); - } - - public SDVariable sign(SDVariable iX) { - return new Sign(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable expandDims(SDVariable iX, int axis) { - return new ExpandDims(sameDiff(), new SDVariable[]{iX}, axis).outputVariable(); - } - - public SDVariable squeeze(SDVariable iX, int... 
axis) { - return new Squeeze(sameDiff(), iX, axis).outputVariable(); - } - - public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, DataType dataType) { - return new ConfusionMatrix(sameDiff(), labels, pred, dataType).outputVariable(); - } - - public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, Integer numClasses) { - return new ConfusionMatrix(sameDiff(), labels, pred, numClasses).outputVariable(); - } - - public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, SDVariable weights) { - return new ConfusionMatrix(sameDiff(), labels, pred, weights).outputVariable(); - } - - public SDVariable confusionMatrix(SDVariable labels, SDVariable pred, Integer numClasses, SDVariable weights) { - return new ConfusionMatrix(sameDiff(), labels, pred, numClasses, weights).outputVariable(); - } - - public SDVariable matrixDeterminant(SDVariable in){ - return new MatrixDeterminant(sameDiff(), in, false).outputVariable(); - } - - public SDVariable matrixInverse(SDVariable in){ - return new MatrixInverse(sameDiff(), in, false).outputVariable(); - } - - public SDVariable onehot(SDVariable indices, int depth, int axis, double on, double off, DataType dataType) { - return new OneHot(sameDiff(), indices, depth, axis, on, off, dataType).outputVariable(); - } - - public SDVariable onehot(SDVariable indices, int depth) { - return new OneHot(sameDiff(), indices, depth).outputVariable(); - } - - public SDVariable reciprocal(SDVariable a) { - return new Reciprocal(sameDiff(), a).outputVariable(); - } - - - public SDVariable repeat(SDVariable iX, int axis) { - return new Repeat(sameDiff(), new SDVariable[]{iX}, axis).outputVariable(); - - } - - public SDVariable stack(SDVariable[] values, int axis) { - return new Stack(sameDiff(), values, axis).outputVariable(); - } - - public SDVariable parallel_stack(SDVariable[] values) { - return new ParallelStack(sameDiff(), values).outputVariable(); - } - - public SDVariable[] unstack(SDVariable value, int axis) { - return new Unstack(sameDiff(), value, axis).outputVariables(); - } - - public SDVariable[] unstack(SDVariable value, int axis, int num) { - return new Unstack(sameDiff(), value, axis, num).outputVariables(); - } - - public SDVariable assign(SDVariable x, SDVariable y) { - return new Assign(sameDiff(), x, y).outputVariable(); - } - - public SDVariable assign(SDVariable x, Number num) { - return new ScalarSet(sameDiff(), x, num).outputVariable(); - } - - - public SDVariable softsign(SDVariable iX) { - return new SoftSign(sameDiff(), iX, false).outputVariable(); - - } - - public SDVariable softsignBp(SDVariable in, SDVariable epsilon) { - return new SoftSignBp(sameDiff(), in, epsilon).outputVariable(); - } - - /** - * @deprecated Use {@link #softsignBp(SDVariable, SDVariable)} - */ - @Deprecated - public SDVariable softsignDerivative(SDVariable iX) { - return new SoftSignDerivative(sameDiff(), iX, false).outputVariable(); - } - - - public SDVariable softplus(SDVariable iX) { - return new SoftPlus(sameDiff(), iX, false).outputVariable(); - - } - - - public SDVariable elu(SDVariable iX) { - return new ELU(sameDiff(), iX).outputVariable(); - - } - - public SDVariable eluBp(SDVariable in, SDVariable epsilon, double alpha) { - return new EluBp(sameDiff(), in, epsilon, alpha).outputVariable(); - } - - - public SDVariable leakyRelu(SDVariable iX, double alpha) { - return new LeakyReLU(sameDiff(), iX, false, alpha).outputVariable(); - - } - - public SDVariable leakyReluBp(SDVariable in, SDVariable epsilon, double cutoff) { - return 
new LeakyReLUBp(sameDiff(), in, epsilon, cutoff).outputVariable(); - } - - /** - * @deprecated Use {@link #leakyReluBp(SDVariable, SDVariable, double)} - */ - @Deprecated - public SDVariable leakyReluDerivative(SDVariable iX, double cutoff) { - return new LeakyReLUDerivative(sameDiff(), iX, false, cutoff).outputVariable(); - } - - public SDVariable prelu(SDVariable x, SDVariable alpha, int... sharedAxes){ - return new PRelu(sameDiff(), x, alpha, sharedAxes).outputVariable(); - } - - public SDVariable[] preluBp(SDVariable in, SDVariable alpha, SDVariable epsilon, int... sharedAxes){ - return new PReluBp(sameDiff(), in, alpha, epsilon, sharedAxes).outputVariables(); - } - - public SDVariable reshape(SDVariable iX, int[] shape) { - return new Reshape(sameDiff(), iX, ArrayUtil.toLongArray(shape)).outputVariable(); - } - - public SDVariable reshape(SDVariable iX, long[] shape) { - return new Reshape(sameDiff(), iX, shape).outputVariable(); - } - - public SDVariable reshape(SDVariable iX, SDVariable shape) { - return new Reshape(sameDiff(), iX, shape).outputVariable(); - } - - public SDVariable reverse(SDVariable x, int... dimensions) { - return new Reverse(sameDiff(), x, dimensions).outputVariable(); - } - - public SDVariable reverseSequence(SDVariable x, SDVariable seq_lengths, int seq_dim, int batch_dim) { - return new ReverseSequence(sameDiff(), x, seq_lengths, seq_dim, batch_dim).outputVariable(); - } - - public SDVariable reverseSequence(SDVariable x, SDVariable seq_lengths) { - return new ReverseSequence(sameDiff(), x, seq_lengths).outputVariable(); - } - - public SDVariable sequenceMask(SDVariable lengths, SDVariable maxLen, DataType dataType) { - return new SequenceMask(sameDiff(), lengths, maxLen, dataType).outputVariable(); - } - - public SDVariable sequenceMask(SDVariable lengths, int maxLen, DataType dataType) { - return new SequenceMask(sameDiff(), lengths, maxLen, dataType).outputVariable(); - } - - public SDVariable sequenceMask(SDVariable lengths, DataType dataType) { - return new SequenceMask(sameDiff(), lengths, dataType).outputVariable(); - } - - public SDVariable concat(int dimension, SDVariable... inputs) { - return new Concat(sameDiff(), dimension, inputs).outputVariable(); - } - - public SDVariable fill(SDVariable shape, DataType dataType, double value) { - return new Fill(sameDiff(), shape, dataType, value).outputVariable(); - } - - public SDVariable dot(SDVariable x, SDVariable y, int... dimensions) { - return new Dot(sameDiff(), x, y, dimensions).outputVariable(); - } - - public SDVariable[] dotBp(SDVariable in1, SDVariable in2, SDVariable grad, boolean keepDims, int... dimensions) { - return new DotBp(sameDiff(), in1, in2, grad, keepDims, dimensions).outputVariables(); - } - - public SDVariable cosineSimilarity(SDVariable iX, SDVariable i_y, int... dimensions) { - return new CosineSimilarity(sameDiff(), iX, i_y, dimensions).outputVariable(); - } - - public SDVariable cosineDistance(SDVariable ix, SDVariable iy, int... dimensions) { - return new CosineDistance(sameDiff(), ix, iy, dimensions).outputVariable(); - } - - - public SDVariable euclideanDistance(SDVariable iX, SDVariable i_y, int... dimensions) { - return new EuclideanDistance(sameDiff(), iX, i_y, dimensions).outputVariable(); - } - - - public SDVariable manhattanDistance(SDVariable iX, SDVariable i_y, int... dimensions) { - return new ManhattanDistance(sameDiff(), iX, i_y, dimensions).outputVariable(); - } - - public SDVariable hammingDistance(SDVariable ix, SDVariable iy, int... 
dimensions) { - return new HammingDistance(sameDiff(), ix, iy, dimensions).outputVariable(); - } - - public SDVariable jaccardDistance(SDVariable ix, SDVariable iy, int... dimensions) { - return new JaccardDistance(sameDiff(), ix, iy, dimensions).outputVariable(); - } - - public SDVariable weightedCrossEntropyWithLogits(SDVariable targets, SDVariable inputs, SDVariable weights) { - return new WeightedCrossEntropyLoss(sameDiff(), targets, inputs, weights).outputVariable(); - } - - public SDVariable lossL2(SDVariable var){ - return new L2Loss(sameDiff(), var).outputVariable(); - } - - public SDVariable lossAbsoluteDifference(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce){ - return new AbsoluteDifferenceLoss(sameDiff(), lossReduce, predictions, weights, label).outputVariable(); - } - - public SDVariable[] lossAbsoluteDifferenceBP(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce){ - return new AbsoluteDifferenceLossBp(sameDiff(), lossReduce, predictions, weights, label).outputVariables(); - } - - public SDVariable lossCosineDistance(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce, int dimension){ - return new CosineDistanceLoss(sameDiff(), lossReduce, predictions, weights, label, dimension).outputVariable(); - } - - public SDVariable[] lossCosineDistanceBp(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce, int dimension){ - return new CosineDistanceLossBp(sameDiff(), lossReduce, predictions, weights, label, dimension).outputVariables(); - } - - public SDVariable lossHinge(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce){ - return new HingeLoss(sameDiff(), lossReduce, predictions, weights, label).outputVariable(); - } - - public SDVariable[] lossHingeBp(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce){ - return new HingeLossBp(sameDiff(), lossReduce, predictions, weights, label).outputVariables(); - } - - public SDVariable lossHuber(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce, double delta){ - return new HuberLoss(sameDiff(), lossReduce, predictions, weights, label, delta).outputVariable(); - } - - public SDVariable[] lossHuberBp(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce, double delta){ - return new HuberLossBp(sameDiff(), lossReduce, predictions, weights, label, delta).outputVariables(); - } - - public SDVariable lossLog(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce, double epsilon){ - return new LogLoss(sameDiff(), lossReduce, predictions, weights, label, epsilon).outputVariable(); - } - - public SDVariable[] lossLogBp(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce, double epsilon){ - return new LogLossBp(sameDiff(), lossReduce, predictions, weights, label, epsilon).outputVariables(); - } - - public SDVariable lossLogPoisson(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce){ - return new LogPoissonLoss(sameDiff(), lossReduce, predictions, weights, label).outputVariable(); - } - - public SDVariable[] lossLogPoissonBp(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce){ - return new LogPoissonLossBp(sameDiff(), lossReduce, predictions, weights, label).outputVariables(); - } - - public SDVariable lossLogPoissonFull(SDVariable label, SDVariable predictions, 
SDVariable weights, LossReduce lossReduce){ - return new LogPoissonLoss(sameDiff(), lossReduce, predictions, weights, label, true).outputVariable(); - } - - public SDVariable[] lossLogPoissonFullBp(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce){ - return new LogPoissonLossBp(sameDiff(), lossReduce, predictions, weights, label, true).outputVariables(); - } - - public SDVariable lossMeanPairwiseSquaredError(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce){ - return new MeanPairwiseSquaredErrorLoss(sameDiff(), lossReduce, predictions, weights, label).outputVariable(); - } - - public SDVariable[] lossMeanPairwiseSquaredErrorBp(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce){ - return new MeanPairwiseSquaredErrorLossBp(sameDiff(), lossReduce, predictions, weights, label).outputVariables(); - } - - public SDVariable lossMeanSquaredError(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce){ - return new MeanSquaredErrorLoss(sameDiff(), lossReduce, predictions, weights, label).outputVariable(); - } - - public SDVariable[] lossMeanSquaredErrorBp(SDVariable label, SDVariable predictions, SDVariable weights, LossReduce lossReduce){ - return new MeanSquaredErrorLossBp(sameDiff(), lossReduce, predictions, weights, label).outputVariables(); - } - - public SDVariable lossSigmoidCrossEntropy(SDVariable labels, SDVariable logits, SDVariable weights, LossReduce lossReduce, double labelSmoothing) { - return new SigmoidCrossEntropyLoss(sameDiff(), lossReduce, logits, weights, labels, labelSmoothing).outputVariable(); - } - - public SDVariable[] lossSigmoidCrossEntropyBp(SDVariable labels, SDVariable logits, SDVariable weights, LossReduce lossReduce, double labelSmoothing) { - return new SigmoidCrossEntropyLossBp(sameDiff(), lossReduce, logits, weights, labels, labelSmoothing).outputVariables(); - } - - public SDVariable lossSoftmaxCrossEntropy(SDVariable labels, SDVariable logits, SDVariable weights, LossReduce lossReduce, double labelSmoothing) { - return new SoftmaxCrossEntropyLoss(sameDiff(), lossReduce, logits, weights, labels, labelSmoothing).outputVariable(); - } - - public SDVariable[] lossSoftmaxCrossEntropyBp(SDVariable labels, SDVariable logits, SDVariable weights, LossReduce lossReduce, double labelSmoothing) { - return new SoftmaxCrossEntropyLossBp(sameDiff(), lossReduce, logits, weights, labels, labelSmoothing).outputVariables(); - } - - public SDVariable lossSoftmaxCrossEntropyWithLogits(SDVariable labels, SDVariable logits, int classDim) { - return new SoftmaxCrossEntropyWithLogitsLoss(sameDiff(), logits, labels, classDim).outputVariable(); - } - - public SDVariable[] lossSoftmaxCrossEntropyWithLogitsBp(SDVariable labels, SDVariable logits, int classDim) { - return new SoftmaxCrossEntropyWithLogitsLossBp(sameDiff(), logits, labels, classDim).outputVariables(); - } - - public SDVariable lossSparseSoftmaxCrossEntropy(SDVariable logits, SDVariable labels){ - return new SparseSoftmaxCrossEntropyLossWithLogits(sameDiff(), logits, labels).outputVariable(); - } - - public SDVariable[] lossSparseSoftmaxCrossEntropyBp(SDVariable logits, SDVariable labels){ - return new SparseSoftmaxCrossEntropyLossWithLogitsBp(sameDiff(), logits, labels).outputVariables(); - } - - - public SDVariable xwPlusB(SDVariable input, SDVariable weights, SDVariable bias) { - return new XwPlusB(sameDiff(), input, weights, bias).outputVariable(); - } - - public SDVariable 
reluLayer(SDVariable input, SDVariable weights, SDVariable bias) { - return new ReluLayer(sameDiff(), input, weights, bias).outputVariable(); - } - - public SDVariable mmul(SDVariable x, - SDVariable y, - MMulTranspose mMulTranspose) { - validateDifferentialFunctionsameDiff(x); - validateDifferentialFunctionsameDiff(y); - return new Mmul(sameDiff(), x, y, mMulTranspose).outputVariable(); - } - - - public SDVariable mmul(SDVariable x, - SDVariable y) { - return mmul(x, y, MMulTranspose.allFalse()); - } - - public List mmulBp(SDVariable x, SDVariable y, SDVariable eps, MMulTranspose mt) { - return Arrays.asList(new MmulBp(sameDiff(), x, y, eps, mt).outputVariables()); - } - - public SDVariable[] batchMmul(SDVariable[] matricesA, - SDVariable[] matricesB) { - return batchMmul(matricesA, matricesB, false, false); - } - - - public SDVariable[] batchMmul(SDVariable[] matricesA, - SDVariable[] matricesB, - boolean transposeA, - boolean transposeB) { - return batchMmul(ArrayUtils.addAll(matricesA, matricesB), transposeA, transposeB); - } - - - public SDVariable[] batchMmul(SDVariable[] matrices, - boolean transposeA, - boolean transposeB) { - return new BatchMmul(sameDiff(), matrices, transposeA, transposeB).outputVariables(); - } - - - public SDVariable tensorMmul(SDVariable x, - SDVariable y, - int[][] dimensions) { - validateDifferentialFunctionsameDiff(x); - validateDifferentialFunctionsameDiff(y); - return new TensorMmul(sameDiff(), x, y, dimensions).outputVariable(); - } - - public SDVariable dotProductAttention(SDVariable queries, SDVariable keys, SDVariable values, SDVariable mask, boolean scaled) { - return new DotProductAttention(sameDiff(), queries, keys, values, mask, scaled, false).outputVariable(); - } - - public List dotProductAttention(SDVariable queries, SDVariable keys, SDVariable values, SDVariable mask, boolean scaled, boolean withWeights) { - return Arrays.asList(new DotProductAttention(sameDiff(), queries, keys, values, mask, scaled, withWeights).outputVariables()); - } - - public List dotProductAttentionBp(SDVariable queries, SDVariable keys, SDVariable values, SDVariable gradient, SDVariable mask, boolean scaled) { - return Arrays.asList(new DotProductAttentionBp(sameDiff(), queries, keys, values, gradient, mask, scaled).outputVariables()); - } - - public SDVariable multiHeadDotProductAttention(SDVariable queries, SDVariable keys, SDVariable values, SDVariable Wq, SDVariable Wk, SDVariable Wv, SDVariable Wo, SDVariable mask, boolean scaled) { - return new MultiHeadDotProductAttention(sameDiff(), queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, false).outputVariable(); - } - - public List multiHeadDotProductAttention(SDVariable queries, SDVariable keys, SDVariable values,SDVariable Wq, SDVariable Wk, SDVariable Wv, SDVariable Wo, SDVariable mask, boolean scaled, boolean withWeights) { - return Arrays.asList(new MultiHeadDotProductAttention(sameDiff(), queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, withWeights).outputVariables()); - } - - public List multiHeadDotProductAttentionBp(SDVariable queries, SDVariable keys, SDVariable values,SDVariable Wq, SDVariable Wk, SDVariable Wv, SDVariable Wo, SDVariable gradient, SDVariable mask, boolean scaled) { - return Arrays.asList(new MultiHeadDotProductAttentionBp(sameDiff(), queries, keys, values, Wq, Wk, Wv, Wo, gradient, mask, scaled).outputVariables()); - } - - public SDVariable softmaxDerivative(SDVariable functionInput, SDVariable wrt, Integer dimension) { - validateDifferentialFunctionsameDiff(functionInput); - return 
new SoftmaxBp(sameDiff(), functionInput, wrt, dimension).outputVariable(); - } - - - public SDVariable logSoftmax(SDVariable i_v) { - validateDifferentialFunctionsameDiff(i_v); - return new LogSoftMax(sameDiff(), i_v).outputVariable(); - - } - - - public SDVariable logSoftmax(SDVariable i_v, int dimension) { - validateDifferentialFunctionsameDiff(i_v); - return new LogSoftMax(sameDiff(), i_v, dimension).outputVariable(); - - } - - - public SDVariable logSoftmaxDerivative(SDVariable arg, SDVariable wrt) { - validateDifferentialFunctionsameDiff(arg); - return new LogSoftMaxDerivative(sameDiff(), arg, wrt).outputVariable(); - } - - - public SDVariable logSoftmaxDerivative(SDVariable arg, SDVariable wrt, int dimension) { - validateDifferentialFunctionsameDiff(arg); - return new LogSoftMaxDerivative(sameDiff(), arg, wrt, dimension).outputVariable(); - } - - public SDVariable logSumExp(SDVariable arg, boolean keepDims, int... dimension) { - return new LogSumExp(sameDiff(), arg, keepDims, dimension).outputVariable(); - } - - - public SDVariable selu(SDVariable arg) { - validateDifferentialFunctionsameDiff(arg); - return new SELU(sameDiff(), arg, false).outputVariable(); - } - - public SDVariable seluBp(SDVariable in, SDVariable epsilon) { - validateDifferentialFunctionsameDiff(in); - return new SeluBp(sameDiff(), in, epsilon).outputVariable(); - } - - /** - * @deprecated Use {@link #seluBp(SDVariable, SDVariable)} - */ - @Deprecated - public SDVariable seluDerivative(SDVariable arg) { - validateDifferentialFunctionsameDiff(arg); - return new SELUDerivative(sameDiff(), arg, false).outputVariable(); - } - - - public SDVariable rsub(SDVariable differentialFunction, SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new RSubOp(sameDiff(), differentialFunction, i_v).outputVariable(); - } - - public List rsubBp(SDVariable x, SDVariable y, SDVariable grad) { - return Arrays.asList(new RSubBpOp(sameDiff(), x, y, grad).outputVariables()); - } - - - public SDVariable rdiv(SDVariable differentialFunction, SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new RDivOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, false).outputVariable(); - } - - public List rdivBp(SDVariable x, SDVariable y, SDVariable grad) { - return Arrays.asList(new RDivBpOp(sameDiff(), x, y, grad).outputVariables()); - } - - - public SDVariable rdivi(SDVariable differentialFunction, SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new RDivOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, true).outputVariable(); - } - - - public SDVariable rsubi(SDVariable differentialFunction, SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new RSubOp(sameDiff(), differentialFunction, i_v).outputVariable(); - } - - public SDVariable add(SDVariable differentialFunction, SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new AddOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, false).outputVariable(); - - } - - public SDVariable mergeAdd(SDVariable... differentialFunctions) { - for (SDVariable df : differentialFunctions) - validateDifferentialFunctionsameDiff(df); - return new MergeAddOp(sameDiff(), differentialFunctions, false).outputVariable(); - } - - public SDVariable mergeMax(SDVariable... 
differentialFunctions) { - for (SDVariable df : differentialFunctions) - validateDifferentialFunctionsameDiff(df); - return new MergeMax(sameDiff(), differentialFunctions).outputVariable(); - } - - public SDVariable mergeAvg(SDVariable... differentialFunctions) { - for (SDVariable df : differentialFunctions) - validateDifferentialFunctionsameDiff(df); - return new MergeAvg(sameDiff(), differentialFunctions).outputVariable(); - } - - public SDVariable diag(SDVariable sdVariable) { - validateDifferentialFunctionsameDiff(sdVariable); - return new Diag(sameDiff(), new SDVariable[]{sdVariable}, false).outputVariable(); - } - - public SDVariable diagPart(SDVariable sdVariable) { - validateDifferentialFunctionsameDiff(sdVariable); - return new DiagPart(sameDiff(), new SDVariable[]{sdVariable}, false).outputVariable(); - } - - public SDVariable setDiag(SDVariable in, SDVariable diag) { - return new MatrixSetDiag(sameDiff(), in, diag, false).outputVariable(); - } - - - public SDVariable batchToSpace(SDVariable differentialFunction, int[] blocks, int[][] crops) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new BatchToSpace(sameDiff(), new SDVariable[]{differentialFunction}, blocks, crops, false) - .outputVariable(); - } - - public SDVariable spaceToBatch(SDVariable differentialFunction, int[] blocks, int[][] padding) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new SpaceToBatch(sameDiff(), new SDVariable[]{differentialFunction}, blocks, padding, false) - .outputVariable(); - } - - public SDVariable depthToSpace(SDVariable differentialFunction, int blocksSize, DataFormat dataFormat) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new DepthToSpace(sameDiff(), new SDVariable[]{differentialFunction}, blocksSize, dataFormat) - .outputVariable(); - } - - public SDVariable spaceToDepth(SDVariable differentialFunction, int blocksSize, DataFormat dataFormat) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new SpaceToDepth(sameDiff(), new SDVariable[]{differentialFunction}, blocksSize, dataFormat) - .outputVariable(); - } - - public SDVariable[] dynamicPartition(SDVariable differentialFunction, SDVariable partitions, int numPartitions) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new DynamicPartition(sameDiff(), differentialFunction, partitions, numPartitions) - .outputVariables(); - } - - public SDVariable[] dynamicPartitionBp(SDVariable input, SDVariable partitions, SDVariable[] grads, int numPartitions){ - return new DynamicPartitionBp(sameDiff(), input, partitions, grads, numPartitions).outputVariables(); - } - - public SDVariable dynamicStitch(SDVariable[] indices, SDVariable[] differentialFunctions) { - for (SDVariable df : differentialFunctions) - validateDifferentialFunctionsameDiff(df); - - return new DynamicStitch(sameDiff(), indices, differentialFunctions).outputVariable(); - } - - public SDVariable segmentMax(SDVariable data, SDVariable segmentIds){ - return new SegmentMax(sameDiff(), data, segmentIds).outputVariable(); - } - - public SDVariable[] segmentMaxBp(SDVariable data, SDVariable segmentIds, SDVariable gradient){ - return new SegmentMaxBp(sameDiff(), data, segmentIds, gradient).outputVariables(); - } - - public SDVariable segmentMin(SDVariable data, SDVariable segmentIds){ - return new SegmentMin(sameDiff(), data, segmentIds).outputVariable(); - } - - public SDVariable[] segmentMinBp(SDVariable data, SDVariable segmentIds, SDVariable gradient){ - return new 
SegmentMinBp(sameDiff(), data, segmentIds, gradient).outputVariables(); - } - - public SDVariable segmentMean(SDVariable data, SDVariable segmentIds){ - return new SegmentMean(sameDiff(), data, segmentIds).outputVariable(); - } - - public SDVariable[] segmentMeanBp(SDVariable data, SDVariable segmentIds, SDVariable gradient){ - return new SegmentMeanBp(sameDiff(), data, segmentIds, gradient).outputVariables(); - } - - public SDVariable segmentProd(SDVariable data, SDVariable segmentIds){ - return new SegmentProd(sameDiff(), data, segmentIds).outputVariable(); - } - - public SDVariable[] segmentProdBp(SDVariable data, SDVariable segmentIds, SDVariable gradient){ - return new SegmentProdBp(sameDiff(), data, segmentIds, gradient).outputVariables(); - } - - public SDVariable segmentSum(SDVariable data, SDVariable segmentIds){ - return new SegmentSum(sameDiff(), data, segmentIds).outputVariable(); - } - - public SDVariable[] segmentSumBp(SDVariable data, SDVariable segmentIds, SDVariable gradient){ - return new SegmentSumBp(sameDiff(), data, segmentIds, gradient).outputVariables(); - } - - - public SDVariable unsortedSegmentMax(SDVariable data, SDVariable segmentIds, int numSegments){ - return new UnsortedSegmentMax(sameDiff(), data, segmentIds, numSegments).outputVariable(); - } - - public SDVariable[] unsortedSegmentMaxBp(SDVariable data, SDVariable segmentIds, SDVariable gradient, int numSegments){ - return new UnsortedSegmentMaxBp(sameDiff(), data, segmentIds, gradient, numSegments).outputVariables(); - } - - public SDVariable unsortedSegmentMin(SDVariable data, SDVariable segmentIds, int numSegments){ - return new UnsortedSegmentMin(sameDiff(), data, segmentIds, numSegments).outputVariable(); - } - - public SDVariable[] unsortedSegmentMinBp(SDVariable data, SDVariable segmentIds, SDVariable gradient, int numSegments){ - return new UnsortedSegmentMinBp(sameDiff(), data, segmentIds, gradient, numSegments).outputVariables(); - } - - public SDVariable unsortedSegmentMean(SDVariable data, SDVariable segmentIds, int numSegments){ - return new UnsortedSegmentMean(sameDiff(), data, segmentIds, numSegments).outputVariable(); - } - - public SDVariable[] unsortedSegmentMeanBp(SDVariable data, SDVariable segmentIds, SDVariable gradient, int numSegments){ - return new UnsortedSegmentMeanBp(sameDiff(), data, segmentIds, gradient, numSegments).outputVariables(); - } - - public SDVariable unsortedSegmentProd(SDVariable data, SDVariable segmentIds, int numSegments){ - return new UnsortedSegmentProd(sameDiff(), data, segmentIds, numSegments).outputVariable(); - } - - public SDVariable[] unsortedSegmentProdBp(SDVariable data, SDVariable segmentIds, SDVariable gradient, int numSegments){ - return new UnsortedSegmentProdBp(sameDiff(), data, segmentIds, gradient, numSegments).outputVariables(); - } - - public SDVariable unsortedSegmentSum(SDVariable data, SDVariable segmentIds, int numSegments){ - return new UnsortedSegmentSum(sameDiff(), data, segmentIds, numSegments).outputVariable(); - } - - public SDVariable[] unsortedSegmentSumBp(SDVariable data, SDVariable segmentIds, SDVariable gradient, int numSegments){ - return new UnsortedSegmentSumBp(sameDiff(), data, segmentIds, gradient, numSegments).outputVariables(); - } - - public SDVariable unsortedSegmentSqrtN(SDVariable data, SDVariable segmentIds, int numSegments){ - return new UnsortedSegmentSqrtN(sameDiff(), data, segmentIds, numSegments).outputVariable(); - } - - public SDVariable[] unsortedSegmentSqrtNBp(SDVariable data, SDVariable segmentIds, 
SDVariable gradient, int numSegments){ - return new UnsortedSegmentSqrtNBp(sameDiff(), data, segmentIds, gradient, numSegments).outputVariables(); - } - - - - - public SDVariable dilation2D(SDVariable df, SDVariable weights, int[] strides, - int[] rates, boolean isSameMode) { - validateDifferentialFunctionsameDiff(df); - return new Dilation2D(sameDiff(), new SDVariable[]{df, weights}, strides, rates, isSameMode, false) - .outputVariable(); - } - - public SDVariable shape(SDVariable df) { - validateDifferentialFunctionsameDiff(df); - return new org.nd4j.linalg.api.ops.impl.shape.Shape(sameDiff(), df, false).outputVariable(); - } - - public SDVariable size(SDVariable in) { - return new Size(sameDiff(), in).outputVariable(); - } - - public SDVariable sizeAt(SDVariable in, int dimension){ - return new SizeAt(sameDiff(), in, dimension).outputVariable(); - } - - public SDVariable rank(SDVariable df) { - return new Rank(sameDiff(), df, false).outputVariable(); - } - - public SDVariable gather(SDVariable df, int[] indices, int axis) { - validateDifferentialFunctionsameDiff(df); - return new Gather(sameDiff(), df, indices, axis, false).outputVariable(); - } - - public SDVariable gather(SDVariable df, SDVariable indices, int axis) { - validateDifferentialFunctionsameDiff(df); - return new Gather(sameDiff(), df, indices, axis, false).outputVariable(); - } - - public SDVariable gatherNd(SDVariable df, SDVariable indices) { - validateDifferentialFunctionsameDiff(df); - return new GatherNd(sameDiff(), df, indices).outputVariable(); - } - - public SDVariable trace(SDVariable in){ - return new Trace(sameDiff(), in).outputVariable(); - } - - public SDVariable cross(SDVariable a, SDVariable b) { - validateDifferentialFunctionsameDiff(a); - return new Cross(sameDiff(), new SDVariable[]{a, b}).outputVariable(); - } - - public SDVariable erf(SDVariable differentialFunction) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new Erf(sameDiff(), differentialFunction, false).outputVariable(); - } - - public SDVariable erfc(SDVariable differentialFunction) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new Erfc(sameDiff(), differentialFunction, false).outputVariable(); - } - - public SDVariable addi(SDVariable differentialFunction, SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new AddOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, true).outputVariable(); - } - - public List addBp(SDVariable x, SDVariable y, SDVariable grad) { - SDVariable[] ret = new AddBpOp(sameDiff(), x, y, grad).outputVariables(); - return Arrays.asList(ret); - } - - - public SDVariable sub(SDVariable differentialFunction, SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new SubOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, false).outputVariable(); - } - - public SDVariable squaredDifference(SDVariable differentialFunction, SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new SquaredDifferenceOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, false) - .outputVariable(); - } - - - public List subBp(SDVariable x, SDVariable y, SDVariable grad) { - return Arrays.asList(new SubBpOp(sameDiff(), x, y, grad).outputVariables()); - } - - - public SDVariable subi(SDVariable differentialFunction, SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new SubOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, 
true).outputVariable(); - - } - - - public SDVariable mul(SDVariable differentialFunction, SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new MulOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, false).outputVariable(); - } - - public List mulBp(SDVariable x, SDVariable y, SDVariable grad) { - return Arrays.asList(new MulBpOp(sameDiff(), x, y, grad).outputVariables()); - } - - public List modBp(SDVariable x, SDVariable y, SDVariable grad) { - return Arrays.asList(new ModBpOp(sameDiff(), x, y, grad).outputVariables()); - } - - - public SDVariable muli(SDVariable differentialFunction, SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new MulOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, true).outputVariable(); - } - - public SDVariable mod(SDVariable differentialFunction, SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new ModOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, false).outputVariable(); - } - - public SDVariable div(SDVariable differentialFunction, SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new DivOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, false).outputVariable(); - } - - public SDVariable truncatedDiv(SDVariable differentialFunction, SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new TruncateDivOp(sameDiff(), differentialFunction, i_v, false).outputVariable(); - } - - public List divBp(SDVariable x, SDVariable y, SDVariable grad) { - return Arrays.asList(new DivBpOp(sameDiff(), x, y, grad).outputVariables()); - } - - - public SDVariable divi(SDVariable differentialFunction, SDVariable i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new DivOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, true).outputVariable(); - } - - - public SDVariable rsub(SDVariable differentialFunction, double i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new ScalarReverseSubtraction(sameDiff(), differentialFunction, i_v).outputVariable(); - - } - - - public SDVariable rdiv(SDVariable differentialFunction, double i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new ScalarReverseDivision(sameDiff(), differentialFunction, i_v).outputVariable(); - - } - - - public SDVariable rdivi(SDVariable differentialFunction, double i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new ScalarReverseDivision(sameDiff(), differentialFunction, i_v, true).outputVariable(); - } - - - public SDVariable rsubi(SDVariable differentialFunction, double i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new ScalarReverseSubtraction(sameDiff(), differentialFunction, i_v, true).outputVariable(); - - } - - - public SDVariable add(SDVariable differentialFunction, double i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new ScalarAdd(sameDiff(), differentialFunction, i_v, false).outputVariable(); - } - - - public SDVariable addi(SDVariable differentialFunction, double i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new ScalarAdd(sameDiff(), differentialFunction, i_v, true).outputVariable(); - } - - - public SDVariable sub(SDVariable differentialFunction, double i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new ScalarSubtraction(sameDiff(), 
differentialFunction, i_v).outputVariable(); - } - - - public SDVariable subi(SDVariable differentialFunction, double i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new ScalarSubtraction(sameDiff(), differentialFunction, i_v, true).outputVariable(); - - } - - - public SDVariable mul(SDVariable differentialFunction, double i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new ScalarMultiplication(sameDiff(), differentialFunction, i_v).outputVariable(); - - } - - - public SDVariable muli(SDVariable differentialFunction, double i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new ScalarMultiplication(sameDiff(), differentialFunction, i_v, true).outputVariable(); - - } - - - public SDVariable div(SDVariable differentialFunction, double i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new ScalarDivision(sameDiff(), differentialFunction, i_v).outputVariable(); - } - - - public SDVariable divi(SDVariable differentialFunction, double i_v) { - validateDifferentialFunctionsameDiff(differentialFunction); - return new ScalarDivision(sameDiff(), differentialFunction, i_v, true).outputVariable(); - } - - - public SDVariable gt(SDVariable functionInput, SDVariable functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - validateDifferentialFunctionsameDiff(functionInput1); - return new GreaterThan(sameDiff(), new SDVariable[]{functionInput, functionInput1}, false).outputVariable(); - } - - - public SDVariable lt(SDVariable functionInput, SDVariable functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - validateDifferentialFunctionsameDiff(functionInput1); - return new LessThan(sameDiff(), new SDVariable[]{functionInput, functionInput1}, false).outputVariable(); - } - - - public SDVariable gti(SDVariable functionInput, SDVariable functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - validateDifferentialFunctionsameDiff(functionInput1); - return new GreaterThan(sameDiff(), new SDVariable[]{functionInput, functionInput1}, true).outputVariable(); - } - - - public SDVariable lti(SDVariable functionInput, SDVariable functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - validateDifferentialFunctionsameDiff(functionInput1); - return new LessThan(sameDiff(), new SDVariable[]{functionInput, functionInput1}, true).outputVariable(); - } - - - public SDVariable gte(SDVariable functionInput, SDVariable functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - validateDifferentialFunctionsameDiff(functionInput1); - return new GreaterThanOrEqual(sameDiff(), new SDVariable[]{functionInput, functionInput1}, false).outputVariable(); - } - - - public SDVariable lte(SDVariable functionInput, SDVariable functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - validateDifferentialFunctionsameDiff(functionInput1); - return new LessThanOrEqual(sameDiff(), new SDVariable[]{functionInput, functionInput1}, false).outputVariable(); - } - - - public SDVariable gtei(SDVariable functionInput, SDVariable functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - validateDifferentialFunctionsameDiff(functionInput1); - return new GreaterThanOrEqual(sameDiff(), new SDVariable[]{functionInput, functionInput1}, true).outputVariable(); - } - - - public SDVariable ltOrEqi(SDVariable functionInput, SDVariable functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - 
validateDifferentialFunctionsameDiff(functionInput1); - return new LessThanOrEqual(sameDiff(), new SDVariable[]{functionInput, functionInput1}, true).outputVariable(); - } - - - public SDVariable gt(SDVariable functionInput, double functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - return new ScalarGreaterThan(sameDiff(), functionInput, functionInput1, false).outputVariable(); - } - - - public SDVariable lt(SDVariable functionInput, double functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - return new ScalarLessThan(sameDiff(), functionInput, functionInput1, false).outputVariable(); - } - - - public SDVariable gti(SDVariable functionInput, double functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - return new ScalarGreaterThan(sameDiff(), functionInput, functionInput1, true).outputVariable(); - } - - - public SDVariable lti(SDVariable functionInput, double functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - return new ScalarLessThan(sameDiff(), functionInput, functionInput1, true).outputVariable(); - } - - - public SDVariable gte(SDVariable functionInput, double functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - return new ScalarGreaterThanOrEqual(sameDiff(), functionInput, functionInput1, false).outputVariable(); - } - - - public SDVariable lte(SDVariable functionInput, double functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - return new ScalarLessThanOrEqual(sameDiff(), functionInput, functionInput1, false).outputVariable(); - } - - - public SDVariable gtei(SDVariable functionInput, double functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - return new ScalarGreaterThanOrEqual(sameDiff(), functionInput, functionInput1, true).outputVariable(); - } - - - public SDVariable ltei(SDVariable functionInput, double functionInput1) { - validateDifferentialFunctionsameDiff(functionInput); - return new ScalarLessThanOrEqual(sameDiff(), functionInput, functionInput1, true).outputVariable(); - } - - - public SDVariable eq(SDVariable iX, double i_y) { - return new ScalarEquals(sameDiff(), iX, i_y).outputVariable(); - } - - public SDVariable eqi(SDVariable iX, double i_y) { - return new ScalarEquals(sameDiff(), iX, i_y, true).outputVariable(); - } - - public SDVariable isNonDecreasing(SDVariable iX) { - validateDifferentialFunctionsameDiff(iX); - return new IsNonDecreasing(sameDiff(), new SDVariable[]{iX}, false).outputVariable(); - } - - public SDVariable isStrictlyIncreasing(SDVariable iX) { - validateDifferentialFunctionsameDiff(iX); - return new IsStrictlyIncreasing(sameDiff(), new SDVariable[]{iX}, false).outputVariable(); - } - - public SDVariable isNumericTensor(SDVariable iX) { - validateDifferentialFunctionsameDiff(iX); - return new IsNumericTensor(sameDiff(), new SDVariable[]{iX}, false).outputVariable(); - } - - public SDVariable slice(SDVariable input, int[] begin, int[] size) { - return new Slice(sameDiff(), input, begin, size).outputVariable(); - } - - public SDVariable slice(SDVariable input, SDVariable begin, SDVariable size) { - return new Slice(sameDiff(), input, begin, size).outputVariable(); - } - - public SDVariable sliceBp(SDVariable input, SDVariable gradient, int[] begin, int[] size) { - return new SliceBp(sameDiff(), input, gradient, begin, size).outputVariable(); - } - - public SDVariable sliceBp(SDVariable input, SDVariable gradient, SDVariable begin, SDVariable size) { - return new SliceBp(sameDiff(), input, 
gradient, begin, size).outputVariable(); - } - - - public SDVariable stridedSlice(SDVariable input, int[] begin, int[] end, int[] strides) { - return new StridedSlice(sameDiff(), input, begin, end, strides).outputVariable(); - } - - public SDVariable stridedSlice(SDVariable input, long[] begin, long[] end, long[] strides) { - return new StridedSlice(sameDiff(), input, begin, end, strides).outputVariable(); - } - - - public SDVariable stridedSlice(SDVariable in, int[] begin, int[] end, int[] strides, int beginMask, - int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) { - return new StridedSlice(sameDiff(), in, begin, end, strides, beginMask, endMask, ellipsisMask, - newAxisMask, shrinkAxisMask).outputVariable(); - } - - public SDVariable stridedSlice(SDVariable in, long[] begin, long[] end, long[] strides, int beginMask, - int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) { - return new StridedSlice(sameDiff(), in, begin, end, strides, beginMask, endMask, ellipsisMask, - newAxisMask, shrinkAxisMask).outputVariable(); - } - - public SDVariable stridedSliceBp(SDVariable in, SDVariable grad, long[] begin, long[] end, long[] strides, int beginMask, - int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) { - return new StridedSliceBp(sameDiff(), in, grad, begin, end, strides, beginMask, endMask, ellipsisMask, - newAxisMask, shrinkAxisMask).outputVariable(); - } - - public SDVariable stridedSliceBp(SDVariable in, SDVariable grad, SDVariable begin, SDVariable end, SDVariable strides, int beginMask, - int endMask, int ellipsisMask, int newAxisMask, int shrinkAxisMask) { - return new StridedSliceBp(sameDiff(), in, grad, begin, end, strides, beginMask, endMask, ellipsisMask, - newAxisMask, shrinkAxisMask).outputVariable(); - } - - public SDVariable scatterAdd(SDVariable ref, SDVariable indices, SDVariable updates) { - return new ScatterAdd(sameDiff(), ref, indices, updates).outputVariable(); - } - - public SDVariable scatterSub(SDVariable ref, SDVariable indices, SDVariable updates) { - return new ScatterSub(sameDiff(), ref, indices, updates).outputVariable(); - } - - public SDVariable scatterMul(SDVariable ref, SDVariable indices, SDVariable updates) { - return new ScatterMul(sameDiff(), ref, indices, updates).outputVariable(); - } - - public SDVariable scatterDiv(SDVariable ref, SDVariable indices, SDVariable updates) { - return new ScatterDiv(sameDiff(), ref, indices, updates).outputVariable(); - } - - public SDVariable scatterMax(SDVariable ref, SDVariable indices, SDVariable updates) { - return new ScatterMax(sameDiff(), ref, indices, updates).outputVariable(); - } - - public SDVariable scatterMin(SDVariable ref, SDVariable indices, SDVariable updates) { - return new ScatterMin(sameDiff(), ref, indices, updates).outputVariable(); - } - - public SDVariable scatterUpdate(SDVariable ref, SDVariable indices, SDVariable updates) { - return new ScatterUpdate(sameDiff(), ref, indices, updates).outputVariable(); - } - - - public SDVariable merge(SDVariable... 
inputs){ - return new Merge(sameDiff(), inputs).outputVariable(); - } - - public SDVariable[] switchOp(SDVariable input, SDVariable predicate){ - return new Switch(sameDiff(), input, predicate).outputVariables(); - } - - - public void validateDifferentialFunctionsameDiff( - SDVariable function) { - - Preconditions.checkState(function != null, "Passed in function was null."); - Preconditions.checkState(function.getSameDiff() == sameDiff); - - Preconditions.checkState(function.getSameDiff() == this.getSameDiff(), - "Function applications must be contained " + - "in same sameDiff. The left %s must match this function %s", function, this); - Preconditions.checkState(sameDiff == this.getSameDiff(), "Function applications must be " + - "contained in same sameDiff. The left %s must match this function ", function, this); - } - - - public void validateDifferentialFunctionGraph(SDVariable function) { - Preconditions.checkState(function.getSameDiff() == this.getSameDiff(), - "Function applications must be contained in same graph. The left %s must match this function %s", - function, this); - - } - - - /** - * @param func - * @param input - * @return - */ - public SDVariable doRepeat(SDVariable func, - SDVariable input) { - validateDifferentialFunctionsameDiff(func); - validateDifferentialFunctionsameDiff(input); - - return tile(func, ArrayUtil.toInts(input.getShape())); - } - - public SDVariable enter(SDVariable x, String frameName){ - return new Enter(sameDiff, frameName, x).outputVariable(); - } - - public SDVariable enter(SDVariable x, String frameName, boolean isConstant){ - return new Enter(sameDiff, frameName, x, isConstant).outputVariable(); - } - - public SDVariable exit(SDVariable x){ - return new Exit(sameDiff, x).outputVariable(); - } - - public SDVariable nextIteration(SDVariable x){ - return new NextIteration(sameDiff, x).outputVariable(); - } - - public SDVariable adjustContrast(SDVariable in, SDVariable factor) { - return new AdjustContrast(sameDiff, in, factor).outputVariable(); - } - - public SDVariable adjustContrastV2(SDVariable in, SDVariable factor) { - return new AdjustContrastV2(sameDiff, in, factor).outputVariable(); - } - - public SDVariable bitCast(SDVariable in, SDVariable dataType) { - return new BitCast(sameDiff, in, dataType).outputVariable(); - } - - public SDVariable compareAndBitpack(SDVariable threshold) { - return new CompareAndBitpack(sameDiff, threshold).outputVariable(); - } - - public SDVariable divideNoNan(SDVariable in1, SDVariable in2) { - return new DivideNoNan(sameDiff, in1, in2).outputVariable(); - } - - public SDVariable drawBoundingBoxes(SDVariable boxes, SDVariable colors) { - return new DrawBoundingBoxes(sameDiff, boxes, colors).outputVariable(); - } - - public SDVariable fakeQuantWithMinMaxVarsPerChannel(SDVariable x, SDVariable min, SDVariable max, - int num_bits, boolean narrow) { - return new FakeQuantWithMinMaxVarsPerChannel(sameDiff,x,min,max,num_bits,narrow).outputVariable(); - } - - public SDVariable betainc( SDVariable a, SDVariable b, SDVariable x) { - return new BetaInc(sameDiff, a, b, x).outputVariable(); - } - - public SDVariable[] fusedBatchNorm(SDVariable x, SDVariable scale, SDVariable offset, - SDVariable dataFormat, SDVariable isTraining) { - return new FusedBatchNorm(sameDiff,x,scale,offset,dataFormat,isTraining).outputVariables(); - } - - public SDVariable matrixBandPart(SDVariable input, SDVariable minLower, SDVariable maxUpper) { - return new MatrixBandPart(sameDiff,input,minLower,maxUpper).outputVariable(); - } - - public 
SDVariable[] maxPoolWithArgmax(SDVariable x, Pooling2DConfig pooling2DConfig) { - return new MaxPoolWithArgmax(sameDiff, x, pooling2DConfig).outputVariables(); - } - - public SDVariable polygamma(SDVariable n, SDVariable x) { - return new Polygamma(sameDiff, n,x).outputVariable(); - } - - public SDVariable roll(SDVariable input, int shift) { - return new Roll(sameDiff, input, shift).outputVariable(); - } - - public SDVariable toggleBits(SDVariable x) { - return new ToggleBits(sameDiff, x).outputVariable(); - } - - - public String toString() { - return "DifferentialFunctionFactory{methodNames=" + methodNames + "}"; - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java index 3b29e6ccb..5ee0801d0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java @@ -253,7 +253,7 @@ public class SDVariable implements Serializable { * @return Negated variable */ public SDVariable neg(){ - return sameDiff.f().neg(this); + return sameDiff.math.neg(this); } /** @@ -579,7 +579,7 @@ public class SDVariable implements Serializable { * @return Output variable */ public SDVariable add(String varName, double scalar) { - val function = sameDiff.f().add(this,scalar); + val function = sameDiff.math.add(this,scalar); return sameDiff.updateVariableNameAndReference(function,varName); } @@ -600,7 +600,7 @@ public class SDVariable implements Serializable { * @return Output (result) SDVariable */ public SDVariable add(String name, SDVariable x) { - val result = sameDiff.f().add(this, x); + val result = sameDiff.math.add(this, x); return sameDiff.updateVariableNameAndReference(result, name); } @@ -636,7 +636,7 @@ public class SDVariable implements Serializable { * @return Output variable */ public SDVariable sub(String varName, double scalar) { - val result = sameDiff.f().sub(this, scalar); + val result = sameDiff.math.sub(this, scalar); return sameDiff.updateVariableNameAndReference(result, varName); } @@ -657,7 +657,7 @@ public class SDVariable implements Serializable { * @return Output (result) SDVariable */ public SDVariable sub(String name, SDVariable x) { - val result = sameDiff.f().sub(this,x); + val result = sameDiff.math.sub(this,x); return sameDiff.updateVariableNameAndReference(result,name); } @@ -693,7 +693,7 @@ public class SDVariable implements Serializable { * @return Output variable */ public SDVariable div(String varName, double scalar) { - val function = sameDiff.f().div(this,scalar); + val function = sameDiff.math.div(this,scalar); return sameDiff.updateVariableNameAndReference(function,varName); } @@ -714,7 +714,7 @@ public class SDVariable implements Serializable { * @return Output (result) SDVariable */ public SDVariable div(String name, SDVariable x) { - val result = sameDiff.f().div(this, x); + val result = sameDiff.math.div(this, x); return sameDiff.updateVariableNameAndReference(result, name); } @@ -728,7 +728,7 @@ public class SDVariable implements Serializable { * @return Output (result) SDVariable */ public SDVariable fdiv(String name, SDVariable x) { - val result = sameDiff.f().floorDiv(this, x); + val result = sameDiff.math.floorDiv(this, x); return sameDiff.updateVariableNameAndReference(result, name); } @@ -742,7 +742,7 @@ public class SDVariable implements 
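With DifferentialFunctionFactory removed, ops are built through the public SameDiff op namespaces rather than the internal f() factory. A minimal sketch of the replacement pattern (illustrative only, not part of this patch; variable names are invented):

    SameDiff sd = SameDiff.create();
    SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 10);
    SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 10, 4));
    SDVariable z = sd.math.add(sd.mmul(in, w), 1.0);   // was: sd.f().add(...)
    SDVariable out = sd.nn.relu(z, 0.0);               // activation ops live under sd.nn

The deleted factory methods survive as ops; only the entry point moves from f() to the per-category namespaces (math, nn, loss, ...), as the SDVariable and SameDiff hunks below show.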
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java
index 3b29e6ccb..5ee0801d0 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java
@@ -253,7 +253,7 @@ public class SDVariable implements Serializable {
      * @return Negated variable
      */
     public SDVariable neg(){
-        return sameDiff.f().neg(this);
+        return sameDiff.math.neg(this);
     }
 
     /**
@@ -579,7 +579,7 @@ public class SDVariable implements Serializable {
      * @return Output variable
      */
     public SDVariable add(String varName, double scalar) {
-        val function = sameDiff.f().add(this,scalar);
+        val function = sameDiff.math.add(this,scalar);
         return sameDiff.updateVariableNameAndReference(function,varName);
     }
 
@@ -600,7 +600,7 @@ public class SDVariable implements Serializable {
      * @return Output (result) SDVariable
      */
     public SDVariable add(String name, SDVariable x) {
-        val result = sameDiff.f().add(this, x);
+        val result = sameDiff.math.add(this, x);
         return sameDiff.updateVariableNameAndReference(result, name);
     }
 
@@ -636,7 +636,7 @@ public class SDVariable implements Serializable {
      * @return Output variable
      */
     public SDVariable sub(String varName, double scalar) {
-        val result = sameDiff.f().sub(this, scalar);
+        val result = sameDiff.math.sub(this, scalar);
         return sameDiff.updateVariableNameAndReference(result, varName);
     }
 
@@ -657,7 +657,7 @@ public class SDVariable implements Serializable {
      * @return Output (result) SDVariable
     */
     public SDVariable sub(String name, SDVariable x) {
-        val result = sameDiff.f().sub(this,x);
+        val result = sameDiff.math.sub(this,x);
         return sameDiff.updateVariableNameAndReference(result,name);
     }
 
@@ -693,7 +693,7 @@ public class SDVariable implements Serializable {
      * @return Output variable
     */
     public SDVariable div(String varName, double scalar) {
-        val function = sameDiff.f().div(this,scalar);
+        val function = sameDiff.math.div(this,scalar);
         return sameDiff.updateVariableNameAndReference(function,varName);
     }
 
@@ -714,7 +714,7 @@ public class SDVariable implements Serializable {
      * @return Output (result) SDVariable
     */
     public SDVariable div(String name, SDVariable x) {
-        val result = sameDiff.f().div(this, x);
+        val result = sameDiff.math.div(this, x);
         return sameDiff.updateVariableNameAndReference(result, name);
     }
 
@@ -728,7 +728,7 @@ public class SDVariable implements Serializable {
      * @return Output (result) SDVariable
     */
     public SDVariable fdiv(String name, SDVariable x) {
-        val result = sameDiff.f().floorDiv(this, x);
+        val result = sameDiff.math.floorDiv(this, x);
         return sameDiff.updateVariableNameAndReference(result, name);
     }
 
@@ -742,7 +742,7 @@ public class SDVariable implements Serializable {
      * @return Output (result) SDVariable
     */
     public SDVariable mod(String name, SDVariable x) {
-        val result = sameDiff.f().mod(this, x);
+        val result = sameDiff.math.mod(this, x);
         return sameDiff.updateVariableNameAndReference(result, name);
     }
 
@@ -762,7 +762,7 @@ public class SDVariable implements Serializable {
      * @return Output variable
     */
     public SDVariable mul(String varName, double scalar) {
-        val function = sameDiff.f().mul(this, scalar);
+        val function = sameDiff.math.mul(this, scalar);
         return sameDiff.updateVariableNameAndReference(function,varName);
     }
 
@@ -784,7 +784,7 @@ public class SDVariable implements Serializable {
      * @return Output (result) SDVariable
     */
     public SDVariable mul(String name, SDVariable x) {
-        val result = sameDiff.f().mul(this, x);
+        val result = sameDiff.math.mul(this, x);
         return sameDiff.updateVariableNameAndReference(result,name);
     }
 
@@ -820,7 +820,7 @@ public class SDVariable implements Serializable {
      * @return Output variable
     */
     public SDVariable pow(String varName, double scalar) {
-        SDVariable ret = sameDiff.f().pow(this, scalar);
+        SDVariable ret = sameDiff.math.pow(this, scalar);
         return sameDiff.updateVariableNameAndReference(ret, varName);
     }
 
@@ -840,7 +840,7 @@ public class SDVariable implements Serializable {
      * @return Output variable
     */
     public SDVariable rsub(String varName, double scalar) {
-        val function = sameDiff.f().rsub(this,scalar);
+        val function = sameDiff.math.rsub(this,scalar);
         return sameDiff.updateVariableNameAndReference(function,varName);
     }
 
@@ -861,7 +861,7 @@ public class SDVariable implements Serializable {
      * @return Output (result) SDVariable
     */
     public SDVariable rsub(String name, SDVariable x) {
-        val result = sameDiff.f().rsub(this,x);
+        val result = sameDiff.math.rsub(this,x);
         return sameDiff.updateVariableNameAndReference(result,name);
     }
 
@@ -881,7 +881,7 @@ public class SDVariable implements Serializable {
      * @return Output variable
     */
     public SDVariable rdiv(String varName, double scalar) {
-        val function = sameDiff.f().rdiv(this, scalar);
+        val function = sameDiff.math.rdiv(this, scalar);
         return sameDiff.updateVariableNameAndReference(function, varName);
     }
 
@@ -902,34 +902,11 @@ public class SDVariable implements Serializable {
      * @return Output (result) SDVariable
     */
     public SDVariable rdiv(String name, SDVariable x) {
-        val result = sameDiff.f().rdiv(this,x);
+        val result = sameDiff.math.rdiv(this,x);
         return sameDiff.updateVariableNameAndReference(result,name);
     }
 
-    /**
-     *
-     * @param sameDiffVariable
-     * @return
-     */
-    public SDVariable truncatedDiv(SDVariable sameDiffVariable) {
-        return truncatedDiv(null,sameDiffVariable);
-
-    }
-
-
-
-    /**
-     *
-     * @param sameDiffVariable
-     * @return
-     */
-    public SDVariable truncatedDiv(String varName, SDVariable sameDiffVariable) {
-        val function = sameDiff.f().truncatedDiv(this, sameDiffVariable);
-        return sameDiff.updateVariableNameAndReference(function,varName);
-
-    }
-
     /**
      * See {@link #squaredDifference(String, SDVariable)}
     */
@@ -943,7 +920,7 @@ public class SDVariable implements Serializable {
      * @return squared difference between variables
     */
     public SDVariable squaredDifference(String name, SDVariable x) {
-        val result = sameDiff.f().squaredDifference(this, x);
+        val result = sameDiff.math().squaredDifference(this, x);
         return sameDiff.updateVariableNameAndReference(result, name);
     }
 
@@ -1431,7 +1408,7 @@ public class SDVariable implements Serializable {
     }
 
     public SDVariable permute(SDVariable dimensions){
-        return sameDiff.permute(null, this, dimensions);
+        return sameDiff.permute( this, dimensions);
     }
 
     /**
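The SDVariable hunks above reroute the fluent arithmetic helpers from the internal factory to the math namespace; the public API is otherwise unchanged, except that truncatedDiv is dropped outright rather than redirected. A short usage sketch under that assumption (illustrative names, not part of the patch):

    SDVariable a = sd.var("a", Nd4j.rand(DataType.FLOAT, 3, 3));
    SDVariable b = sd.var("b", Nd4j.rand(DataType.FLOAT, 3, 3));
    SDVariable c = a.mul("c", b);            // now delegates to sameDiff.math.mul(...)
    SDVariable d = a.squaredDifference(b);   // now delegates to sameDiff.math().squaredDifference(...)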
* @@ -917,7 +906,6 @@ public class SameDiff extends SDBaseOps { private SameDiff() { super(null); super.sd = this; - functionFactory = new DifferentialFunctionFactory(this); sameDiffFunctionInstances = new LinkedHashMap<>(); fieldVariableResolutionMapping = HashBasedTable.create(); } @@ -5945,7 +5933,7 @@ public class SameDiff extends SDBaseOps { if(switches.containsKey(argument.name())) return switches.get(argument.name())[1]; - SDVariable[] s = f().switchOp(argument, pred); + SDVariable[] s = switchOp(argument, pred); switches.put(argument.name(), s); return s[1]; } @@ -5955,7 +5943,7 @@ public class SameDiff extends SDBaseOps { this.removeArgumentInterceptor(); if(declared.contains(trueOut.name())) { - SDVariable[] s = f().switchOp(trueOut, pred); + SDVariable[] s = switchOp(trueOut, pred); switches.put(trueOut.name(), s); trueOut = s[1]; } @@ -5975,7 +5963,7 @@ public class SameDiff extends SDBaseOps { if(switches.containsKey(argument.name())) return switches.get(argument.name())[0]; - SDVariable[] s = f().switchOp(argument, pred); + SDVariable[] s = switchOp(argument, pred); switches.put(argument.name(), s); return s[0]; } @@ -5985,13 +5973,13 @@ public class SameDiff extends SDBaseOps { this.removeArgumentInterceptor(); if(declared2.contains(falseOut.name())) { - SDVariable[] s = f().switchOp(falseOut, pred); + SDVariable[] s = switchOp(falseOut, pred); switches.put(falseOut.name(), s); falseOut = s[0]; } falseScope.close(); - SDVariable output = f().merge(trueOut, falseOut); + SDVariable output = merge(trueOut, falseOut); ifScope.close(); @@ -6042,11 +6030,9 @@ public class SameDiff extends SDBaseOps { SDVariable[] entered = new SDVariable[loopVars.length]; for(int i = 0 ; i < loopVars.length ; i++){ - entered[i] = f().enter(loopVars[i], frameName); + entered[i] = new Enter(this, frameName, loopVars[i]).outputVariable(); } - //counter = SD.f().enter(counter, frameName); - SDVariable[] merged = new SDVariable[loopVars.length]; Merge[] mergeOps = new Merge[loopVars.length]; for(int i = 0 ; i < loopVars.length ; i++){ @@ -6072,19 +6058,16 @@ public class SameDiff extends SDBaseOps { SDVariable[] trueSwitches = new SDVariable[loopVars.length]; SDVariable[] exits = new SDVariable[loopVars.length]; for(int i = 0 ; i < loopVars.length ; i++){ - SDVariable[] s = f().switchOp(merged[i], cond_result); + SDVariable[] s = switchOp(merged[i], cond_result); trueSwitches[i] = s[1]; alreadyEntered.add(s[1].name()); - exits[i] = f().exit(s[0]); + exits[i] = new Exit(this, s[0]).outputVariable(); } - //SDVariable[] cs = SD.f().switchOp(counter, cond_result); - //SDVariable counterExit = SD.f().exit(cs[0]); - //counter = cs[1]; - final Set declared = Sets.newHashSet(this.variableMap().keySet()); final Map done = new HashMap<>(); + final SameDiff sd = this; this.addArgumentInterceptor(new ArgumentInterceptor() { @Override public SDVariable intercept(SDVariable argument) { @@ -6098,7 +6081,7 @@ public class SameDiff extends SDBaseOps { if(done.containsKey(argument.name())) return done.get(argument.name()); - SDVariable e = f().enter(argument, frameName, true); + SDVariable e = new Enter(sd, frameName, argument, true).outputVariable(); done.put(argument.name(), e); return e; } @@ -6112,7 +6095,7 @@ public class SameDiff extends SDBaseOps { //counter.add(1); for(int i = 0 ; i < loopVars.length ; i++){ - SDVariable n = f().nextIteration(outs[i]); + SDVariable n = new NextIteration(this, outs[i]).outputVariable(); mergeOps[i].replaceArg(1,n); } diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/OutputConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/OutputConfig.java index 5cafa09aa..193229ff9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/OutputConfig.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/config/OutputConfig.java @@ -27,7 +27,7 @@ import lombok.Setter; import org.nd4j.autodiff.listeners.Listener; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.autodiff.util.TrainingUtils; +import org.nd4j.autodiff.util.SameDiffUtils; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.adapter.MultiDataSetIteratorAdapter; @@ -165,7 +165,7 @@ public class OutputConfig { Preconditions.checkState(outputs.size() == 1, "Can only use execSingleBatches() when exactly one output is specified, there were %s", outputs.size()); - return TrainingUtils + return SameDiffUtils .getSingleOutput(sd.outputBatches(data, listeners, outputs.toArray(new String[0])), outputs.get(0)); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java index 3b53e5b65..157ec1fbb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java @@ -37,12 +37,11 @@ public class SDBaseOps { /** * Boolean and array reduction operation, optionally along specified dimensions
* - * @param x Input variable (BOOL type) + * @param x Input variable (NDARRAY type) * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) * @return output reduced array of rank (input rank - num dimensions) (BOOL type) */ public SDVariable all(SDVariable x, int... dimensions) { - SDValidation.validateBool("all", "x", x); Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); return new org.nd4j.linalg.api.ops.impl.reduce.bool.All(sd,x, dimensions).outputVariable(); } @@ -51,12 +50,11 @@ public class SDBaseOps { * Boolean and array reduction operation, optionally along specified dimensions
* * @param name name May be null. Name for the output variable - * @param x Input variable (BOOL type) + * @param x Input variable (NDARRAY type) * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) * @return output reduced array of rank (input rank - num dimensions) (BOOL type) */ public SDVariable all(String name, SDVariable x, int... dimensions) { - SDValidation.validateBool("all", "x", x); Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.bool.All(sd,x, dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); @@ -65,12 +63,11 @@ public class SDBaseOps { /** * Boolean or array reduction operation, optionally along specified dimensions
* - * @param x Input variable (BOOL type) + * @param x Input variable (NDARRAY type) * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) * @return output reduced array of rank (input rank - num dimensions) (BOOL type) */ public SDVariable any(SDVariable x, int... dimensions) { - SDValidation.validateBool("any", "x", x); Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); return new org.nd4j.linalg.api.ops.impl.reduce.bool.Any(sd,x, dimensions).outputVariable(); } @@ -79,12 +76,11 @@ public class SDBaseOps { * Boolean or array reduction operation, optionally along specified dimensions
* * @param name name May be null. Name for the output variable - * @param x Input variable (BOOL type) + * @param x Input variable (NDARRAY type) * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) * @return output reduced array of rank (input rank - num dimensions) (BOOL type) */ public SDVariable any(String name, SDVariable x, int... dimensions) { - SDValidation.validateBool("any", "x", x); Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); SDVariable out = new org.nd4j.linalg.api.ops.impl.reduce.bool.Any(sd,x, dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); @@ -196,6 +192,8 @@ public class SDBaseOps { * keepDims = false: [a,c]
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param in Input variable (NUMERIC type) * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions @@ -220,6 +218,8 @@ public class SDBaseOps { * keepDims = false: [a,c]
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param name name May be null. Name for the output variable * @param in Input variable (NUMERIC type) @@ -246,6 +246,8 @@ public class SDBaseOps { * keepDims = false: [a,c]
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param in Input variable (NUMERIC type) * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) @@ -269,6 +271,8 @@ public class SDBaseOps { * keepDims = false: [a,c]
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param name name May be null. Name for the output variable * @param in Input variable (NUMERIC type) @@ -744,6 +748,8 @@ public class SDBaseOps { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -762,6 +768,8 @@ public class SDBaseOps { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -964,6 +972,8 @@ public class SDBaseOps { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -982,6 +992,8 @@ public class SDBaseOps { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -1032,6 +1044,8 @@ public class SDBaseOps { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -1050,6 +1064,8 @@ public class SDBaseOps { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -1245,6 +1261,8 @@ public class SDBaseOps { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -1263,6 +1281,8 @@ public class SDBaseOps { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -1313,6 +1333,8 @@ public class SDBaseOps { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -1331,6 +1353,8 @@ public class SDBaseOps { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -1581,6 +1605,8 @@ public class SDBaseOps { * Element-wise maximum operation: out[i] = max(first[i], second[i])
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param first First input array (NUMERIC type) * @param second Second input array (NUMERIC type) @@ -1596,6 +1622,8 @@ public class SDBaseOps { * Element-wise maximum operation: out[i] = max(first[i], second[i])
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param name name May be null. Name for the output variable * @param first First input array (NUMERIC type) @@ -1695,6 +1723,38 @@ public class SDBaseOps { return sd.updateVariableNameAndReference(out, name); } + /** + * The merge operation is a control operation that forwards either of the inputs to the output, when
+ * the first of them becomes available. If both are available, the output is undefined (either input could
+ * be forwarded to the output)
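
For illustration, a minimal sketch of the merge op described above, assuming a SameDiff build where this generated method is available (class and variable names are hypothetical). In practice merge is fed by the two outputs of switchOp, so only one input is ever actually computed:

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.factory.Nd4j;

    public class MergeSketch {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            // Stand-ins for the outputs of two mutually exclusive branches
            SDVariable a = sd.constant("a", Nd4j.scalar(1.0f));
            SDVariable b = sd.constant("b", Nd4j.scalar(2.0f));
            // Forwards whichever input becomes available first at execution time
            SDVariable joined = sd.merge(a, b);
        }
    }
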
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output (NUMERIC type) + */ + public SDVariable merge(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("merge", "x", x); + SDValidation.validateNumerical("merge", "y", y); + return new org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge(sd,x, y).outputVariable(); + } + + /** + * The merge operation is a control operation that forwards either of the inputs to the output, when
+ * the first of them becomes available. If both are available, the output is undefined (either input could
+ * be forwarded to the output)
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output (NUMERIC type) + */ + public SDVariable merge(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("merge", "x", x); + SDValidation.validateNumerical("merge", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Minimum array reduction operation, optionally along specified dimensions. out = min(in)
* @@ -1785,6 +1845,8 @@ public class SDBaseOps { * Element-wise minimum operation: out[i] = min(first[i], second[i])
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param first First input array (NUMERIC type) * @param second Second input array (NUMERIC type) @@ -1800,6 +1862,8 @@ public class SDBaseOps { * Element-wise minimum operation: out[i] = min(first[i], second[i])
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param name name May be null. Name for the output variable * @param first First input array (NUMERIC type) @@ -1916,6 +1980,8 @@ public class SDBaseOps { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -1934,6 +2000,8 @@ public class SDBaseOps { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -4176,6 +4244,32 @@ public class SDBaseOps { return sd.updateVariableNameAndReference(out, name); } + /** + * Switch operation
+ * Predicate - if false, values are output to left (first) branch/output; if true, to right (second) branch/output
+ * + * @param x Input variable (NDARRAY type) + * @param predicate Predicate - if false, values are output to left (first) branch/output; if true, to right (second) branch/output (BOOL type) + */ + public SDVariable[] switchOp(SDVariable x, SDVariable predicate) { + SDValidation.validateBool("switchOp", "predicate", predicate); + return new org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch(sd,x, predicate).outputVariables(); + } + + /** + * Switch operation
+ * Predicate - if false, values are output to left (first) branch/output; if true, to right (second) branch/output
+ * + * @param names names May be null. Arrays of names for the output variables. + * @param x Input variable (NDARRAY type) + * @param predicate Predicate - if false, values are output to left (first) branch/output; if true, to right (second) branch/output (BOOL type) + */ + public SDVariable[] switchOp(String[] names, SDVariable x, SDVariable predicate) { + SDValidation.validateBool("switchOp", "predicate", predicate); + SDVariable[] out = new org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch(sd,x, predicate).outputVariables(); + return sd.updateVariableNamesAndReferences(out, names); + } + /** + * //TODO: Ops must be documented.
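
For illustration, how the switchOp/merge pair added above composes into an if-like construct, mirroring what the reworked ifCond in SameDiff.java builds internally. A minimal sketch, assuming these generated methods are available (names are hypothetical):

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.factory.Nd4j;

    public class SwitchSketch {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            SDVariable x = sd.var("x", Nd4j.scalar(3.0f));
            SDVariable pred = x.gt(0.0);                  // BOOL predicate
            SDVariable[] s = sd.switchOp(x, pred);        // s[0]: pred == false, s[1]: pred == true
            SDVariable onFalse = s[0].mul(-1.0);          // evaluated only when pred is false
            SDVariable onTrue = s[1].mul(2.0);            // evaluated only when pred is true
            SDVariable out = sd.merge(onTrue, onFalse);   // rejoin the two branches
        }
    }
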
* diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java index a58d4d180..ef030e952 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java @@ -258,7 +258,7 @@ public class SDImage extends SDOps { /** * Resize images to size using the specified method.
* - * @param input 4D image [NHWC] (NUMERIC type) + * @param input 4D image [NCHW] (NUMERIC type) * @param size new height and width (INT type) * @param preserveAspectRatio Whether to preserve the aspect ratio. If this is set, then images will be resized to a size that fits in size while preserving the aspect ratio of the original image. Scales up the image if size is bigger than the current size of the image. Defaults to False. * @param antialis Whether to use an anti-aliasing filter when downsampling an image @@ -282,7 +282,7 @@ public class SDImage extends SDOps { * Resize images to size using the specified method.
* * @param name name May be null. Name for the output variable - * @param input 4D image [NHWC] (NUMERIC type) + * @param input 4D image [NCHW] (NUMERIC type) * @param size new height and width (INT type) * @param preserveAspectRatio Whether to preserve the aspect ratio. If this is set, then images will be resized to a size that fits in size while preserving the aspect ratio of the original image. Scales up the image if size is bigger than the current size of the image. Defaults to False. * @param antialis Whether to use an anti-aliasing filter when downsampling an image @@ -306,7 +306,7 @@ public class SDImage extends SDOps { /** * Resize images to size using the specified method.
* - * @param input 4D image [NHWC] (NUMERIC type) + * @param input 4D image [NCHW] (NUMERIC type) * @param size new height and width (INT type) * @param ImageResizeMethod ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling. * ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing. @@ -328,7 +328,7 @@ public class SDImage extends SDOps { * Resize images to size using the specified method.
* * @param name name May be null. Name for the output variable - * @param input 4D image [NHWC] (NUMERIC type) + * @param input 4D image [NCHW] (NUMERIC type) * @param size new height and width (INT type) * @param ImageResizeMethod ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling. * ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing. diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java index 1f89ba1d1..66d47f905 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java @@ -67,13 +67,13 @@ public class SDMath extends SDOps { * Looks up ids in a list of embedding tensors.
* * @param x Input tensor (NUMERIC type) - * @param indices A Tensor containing the ids to be looked up. (INT type) + * @param indices A Tensor containing the ids to be looked up. (NUMERIC type) * @param PartitionMode partition_mode == 0 - i.e. 'mod' , 1 - 'div' * @return output Shifted output (NUMERIC type) */ public SDVariable embeddingLookup(SDVariable x, SDVariable indices, PartitionMode PartitionMode) { SDValidation.validateNumerical("EmbeddingLookup", "x", x); - SDValidation.validateInteger("EmbeddingLookup", "indices", indices); + SDValidation.validateNumerical("EmbeddingLookup", "indices", indices); return new org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup(sd,x, indices, PartitionMode).outputVariable(); } @@ -82,14 +82,14 @@ public class SDMath extends SDOps { * * @param name name May be null. Name for the output variable * @param x Input tensor (NUMERIC type) - * @param indices A Tensor containing the ids to be looked up. (INT type) + * @param indices A Tensor containing the ids to be looked up. (NUMERIC type) * @param PartitionMode partition_mode == 0 - i.e. 'mod' , 1 - 'div' * @return output Shifted output (NUMERIC type) */ public SDVariable embeddingLookup(String name, SDVariable x, SDVariable indices, PartitionMode PartitionMode) { SDValidation.validateNumerical("EmbeddingLookup", "x", x); - SDValidation.validateInteger("EmbeddingLookup", "indices", indices); + SDValidation.validateNumerical("EmbeddingLookup", "indices", indices); SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup(sd,x, indices, PartitionMode).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -166,6 +166,68 @@ public class SDMath extends SDOps { return sd.updateVariableNameAndReference(out, name); } + /** + * Pairwise addition operation, out = x + y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
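
For illustration, the broadcasting rule above in concrete terms, using the new SDMath add; a minimal sketch, assuming a build where the generated math namespace is available:

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.factory.Nd4j;

    public class BroadcastSketch {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            SDVariable row = sd.constant("row", Nd4j.ones(DataType.FLOAT, 1, 10)); // shape [1,10]
            SDVariable mat = sd.constant("mat", Nd4j.ones(DataType.FLOAT, 5, 10)); // shape [5,10]
            // The [1,10] row is broadcast along dimension 0, giving a [5,10] result
            SDVariable sum = sd.math.add(row, mat);
            System.out.println(java.util.Arrays.toString(sum.eval().shape())); // [5, 10]
        }
    }
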
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable add(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("add", "x", x); + SDValidation.validateNumerical("add", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp(sd,x, y).outputVariable(); + } + + /** + * Pairwise addition operation, out = x + y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable add(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("add", "x", x); + SDValidation.validateNumerical("add", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Scalar add operation, out = in + scalar
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable add(SDVariable x, double value) { + SDValidation.validateNumerical("add", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd(sd,x, value).outputVariable(); + } + + /** + * Scalar add operation, out = in + scalar
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable add(String name, SDVariable x, double value) { + SDValidation.validateNumerical("add", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd(sd,x, value).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Absolute max array reduction operation, optionally along specified dimensions: out = max(abs(x))
* @@ -1126,6 +1188,68 @@ public class SDMath extends SDOps { return sd.updateVariableNameAndReference(out, name); } + /** + * Pairwise division operation, out = x / y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable div(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("div", "x", x); + SDValidation.validateNumerical("div", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.DivOp(sd,x, y).outputVariable(); + } + + /** + * Pairwise division operation, out = x / y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable div(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("div", "x", x); + SDValidation.validateNumerical("div", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.DivOp(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Scalar division operation, out = in / scalar
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable div(SDVariable x, double value) { + SDValidation.validateNumerical("div", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.ScalarDivision(sd,x, value).outputVariable(); + } + + /** + * Scalar division operation, out = in / scalar
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable div(String name, SDVariable x, double value) { + SDValidation.validateNumerical("div", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarDivision(sd,x, value).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Entropy reduction: -sum(x * log(x))
* @@ -1552,6 +1676,104 @@ public class SDMath extends SDOps { return sd.updateVariableNameAndReference(out, name); } + /** + * Pairwise floor division operation, out = floor(x / y)
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable floorDiv(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("floorDiv", "x", x); + SDValidation.validateNumerical("floorDiv", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorDivOp(sd,x, y).outputVariable(); + } + + /** + * Pairwise floor division operation, out = floor(x / y)
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable floorDiv(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("floorDiv", "x", x); + SDValidation.validateNumerical("floorDiv", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorDivOp(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Pairwise floor modulus (remainder) operation
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
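
For illustration: floorDiv and floorMod are related by the identity floorMod(x, y) = x - floorDiv(x, y) * y, so under the usual floored-modulo convention (assumed here) the remainder takes the sign of the divisor. A minimal sketch:

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.factory.Nd4j;

    public class FloorModSketch {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            SDVariable x = sd.constant("x", Nd4j.scalar(-7.0f));
            SDVariable y = sd.constant("y", Nd4j.scalar(2.0f));
            SDVariable q = sd.math.floorDiv(x, y); // floor(-7 / 2) = -4
            SDVariable r = sd.math.floorMod(x, y); // -7 - (-4) * 2 = 1
        }
    }
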
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable floorMod(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("floorMod", "x", x); + SDValidation.validateNumerical("floorMod", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorModOp(sd,x, y).outputVariable(); + } + + /** + * Pairwise floor modulus (remainder) operation
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable floorMod(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("floorMod", "x", x); + SDValidation.validateNumerical("floorMod", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorModOp(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Scalar floor modulus operation
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable floorMod(SDVariable x, double value) { + SDValidation.validateNumerical("floorMod", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.ScalarFMod(sd,x, value).outputVariable(); + } + + /** + * Scalar floor modulus operation
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable floorMod(String name, SDVariable x, double value) { + SDValidation.validateNumerical("floorMod", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarFMod(sd,x, value).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Hamming distance reduction operation. The output contains the Hamming distance for each
* tensor/subset along the specified dimensions:
@@ -2260,6 +2482,42 @@ public class SDMath extends SDOps { return sd.updateVariableNameAndReference(out, name); } + /** + * Pairwise max operation, out = max(x, y)
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x First input variable, x (NUMERIC type) + * @param y Second input variable, y (NUMERIC type) + * @return out Output (NUMERIC type) + */ + public SDVariable max(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("max", "x", x); + SDValidation.validateNumerical("max", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.Max(sd,x, y).outputVariable(); + } + + /** + * Pairwise max operation, out = max(x, y)
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param name name May be null. Name for the output variable + * @param x First input variable, x (NUMERIC type) + * @param y Second input variable, y (NUMERIC type) + * @return out Output (NUMERIC type) + */ + public SDVariable max(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("max", "x", x); + SDValidation.validateNumerical("max", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Max(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Merge add function: merges an arbitrary number of equal shaped arrays using element-wise addition:
* out = sum_i in[i]
@@ -2370,6 +2628,78 @@ public class SDMath extends SDOps { return sd.updateVariableNamesAndReferences(out, names); } + /** + * Pairwise max operation, out = min(x, y)
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x First input variable, x (NUMERIC type) + * @param y Second input variable, y (NUMERIC type) + * @return out Output (NUMERIC type) + */ + public SDVariable min(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("min", "x", x); + SDValidation.validateNumerical("min", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.Min(sd,x, y).outputVariable(); + } + + /** + * Pairwise min operation, out = min(x, y)&#13;
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param name name May be null. Name for the output variable + * @param x First input variable, x (NUMERIC type) + * @param y Second input variable, y (NUMERIC type) + * @return out Output (NUMERIC type) + */ + public SDVariable min(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("min", "x", x); + SDValidation.validateNumerical("min", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Min(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Pairwise modulus (remainder) operation, out = x % y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable mod(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("mod", "x", x); + SDValidation.validateNumerical("mod", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp(sd,x, y).outputVariable(); + } + + /** + * Pairwise modulus (remainder) operation, out = x % y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable mod(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("mod", "x", x); + SDValidation.validateNumerical("mod", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Calculate the mean and (population) variance for the input variable, for the specified axis
* @@ -2396,6 +2726,68 @@ public class SDMath extends SDOps { return sd.updateVariableNamesAndReferences(out, names); } + /** + * Pairwise multiplication operation, out = x * y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable mul(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("mul", "x", x); + SDValidation.validateNumerical("mul", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp(sd,x, y).outputVariable(); + } + + /** + * Pairwise multiplication operation, out = x * y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable mul(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("mul", "x", x); + SDValidation.validateNumerical("mul", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Scalar multiplication operation, out = in * scalar
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable mul(SDVariable x, double value) { + SDValidation.validateNumerical("mul", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.ScalarMultiplication(sd,x, value).outputVariable(); + } + + /** + * Scalar multiplication operation, out = in * scalar
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable mul(String name, SDVariable x, double value) { + SDValidation.validateNumerical("mul", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarMultiplication(sd,x, value).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Elementwise negative operation: out = -x
* @@ -2542,6 +2934,96 @@ public class SDMath extends SDOps { return sd.updateVariableNameAndReference(out, name); } + /** + * Rational Tanh Approximation elementwise function, as described in the paper:
+ * Compact Convolutional Neural Network Cascade for Face Detection
+ * This is a faster Tanh approximation
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable rationalTanh(SDVariable x) { + SDValidation.validateNumerical("rationalTanh", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.RationalTanh(sd,x).outputVariable(); + } + + /** + * Rational Tanh Approximation elementwise function, as described in the paper:
+ * Compact Convolutional Neural Network Cascade for Face Detection
+ * This is a faster Tanh approximation
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable rationalTanh(String name, SDVariable x) { + SDValidation.validateNumerical("rationalTanh", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.RationalTanh(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Pairwise reverse division operation, out = y / x
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable rdiv(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("rdiv", "x", x); + SDValidation.validateNumerical("rdiv", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RDivOp(sd,x, y).outputVariable(); + } + + /** + * Pairwise reverse division operation, out = y / x
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable rdiv(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("rdiv", "x", x); + SDValidation.validateNumerical("rdiv", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RDivOp(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Scalar reverse division operation, out = scalar / in
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable rdiv(SDVariable x, double value) { + SDValidation.validateNumerical("rdiv", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseDivision(sd,x, value).outputVariable(); + } + + /** + * Scalar reverse division operation, out = scalar / in
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable rdiv(String name, SDVariable x, double value) { + SDValidation.validateNumerical("rdiv", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseDivision(sd,x, value).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Element-wise reciprocal (inverse) function: out[i] = 1 / in[i]
* @@ -2566,6 +3048,30 @@ public class SDMath extends SDOps { return sd.updateVariableNameAndReference(out, name); } + /** + * Rectified tanh operation: max(0, tanh(in))
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable rectifiedTanh(SDVariable x) { + SDValidation.validateNumerical("rectifiedTanh", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.RectifiedTanh(sd,x).outputVariable(); + } + + /** + * Rectified tanh operation: max(0, tanh(in))
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable rectifiedTanh(String name, SDVariable x) { + SDValidation.validateNumerical("rectifiedTanh", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.RectifiedTanh(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Element-wise round function: out = round(x).
* Rounds (up or down depending on value) to the nearest integer value.
@@ -2616,6 +3122,68 @@ public class SDMath extends SDOps { return sd.updateVariableNameAndReference(out, name); } + /** + * Pairwise reverse subtraction operation, out = y - x
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable rsub(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("rsub", "x", x); + SDValidation.validateNumerical("rsub", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RSubOp(sd,x, y).outputVariable(); + } + + /** + * Pairwise reverse subtraction operation, out = y - x
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable rsub(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("rsub", "x", x); + SDValidation.validateNumerical("rsub", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RSubOp(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Scalar reverse subtraction operation, out = scalar - in
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable rsub(SDVariable x, double value) { + SDValidation.validateNumerical("rsub", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseSubtraction(sd,x, value).outputVariable(); + } + + /** + * Scalar reverse subtraction operation, out = scalar - in
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable rsub(String name, SDVariable x, double value) { + SDValidation.validateNumerical("rsub", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseSubtraction(sd,x, value).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Set the diagonal value to the specified values
* If input is
@@ -2814,6 +3382,42 @@ public class SDMath extends SDOps { return sd.updateVariableNameAndReference(out, name); } + /** + * Pairwise squared difference operation.
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable squaredDifference(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("squaredDifference", "x", x); + SDValidation.validateNumerical("squaredDifference", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SquaredDifferenceOp(sd,x, y).outputVariable(); + } + + /** + * Pairwise squared difference operation.
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable squaredDifference(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("squaredDifference", "x", x); + SDValidation.validateNumerical("squaredDifference", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SquaredDifferenceOp(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Standardize input variable along given axis
*


@@ -2894,6 +3498,68 @@ public class SDMath extends SDOps { return sd.updateVariableNameAndReference(out, name); } + /** + * Pairwise subtraction operation, out = x - y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable sub(SDVariable x, SDVariable y) { + SDValidation.validateNumerical("sub", "x", x); + SDValidation.validateNumerical("sub", "y", y); + return new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SubOp(sd,x, y).outputVariable(); + } + + /** + * Pairwise subtraction operation, out = x - y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable sub(String name, SDVariable x, SDVariable y) { + SDValidation.validateNumerical("sub", "x", x); + SDValidation.validateNumerical("sub", "y", y); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SubOp(sd,x, y).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Scalar subtraction operation, out = in - scalar
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable sub(SDVariable x, double value) { + SDValidation.validateNumerical("sub", "x", x); + return new org.nd4j.linalg.api.ops.impl.scalar.ScalarSubtraction(sd,x, value).outputVariable(); + } + + /** + * Scalar subtraction operation, out = in - scalar
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public SDVariable sub(String name, SDVariable x, double value) { + SDValidation.validateNumerical("sub", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.scalar.ScalarSubtraction(sd,x, value).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Elementwise tangent operation: out = tan(x)
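The sub(...) overloads above follow the same pattern; a minimal sketch reusing x and y from the squaredDifference example:

    SDVariable pairwise = sd.math().sub("diff", x, y);      // broadcast x - y, shape [5,10]
    SDVariable shifted = sd.math().sub("shifted", x, 1.0);  // scalar variant: every element of x minus 1.0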
* diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java index 9633a0186..15d70aac5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java @@ -24,6 +24,7 @@ import java.lang.String; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; +import org.nd4j.enums.PadMode; public class SDNN extends SDOps { public SDNN(SameDiff sameDiff) { @@ -722,6 +723,39 @@ public class SDNN extends SDOps { return sd.updateVariableNameAndReference(out, name); } + /** + * Padding operation
+ * + * @param input Input tensor (NUMERIC type) + * @param padding Padding value (NUMERIC type) + * @param PadMode Padding format + * @param constant Padding constant + * @return output Padded input (NUMERIC type) + */ + public SDVariable pad(SDVariable input, SDVariable padding, PadMode PadMode, double constant) { + SDValidation.validateNumerical("pad", "input", input); + SDValidation.validateNumerical("pad", "padding", padding); + return new org.nd4j.linalg.api.ops.impl.transforms.Pad(sd,input, padding, PadMode, constant).outputVariable(); + } + + /** + * Padding operation
+ * + * @param name name May be null. Name for the output variable + * @param input Input tensor (NUMERIC type) + * @param padding Padding value (NUMERIC type) + * @param PadMode Padding format + * @param constant Padding constant + * @return output Padded input (NUMERIC type) + */ + public SDVariable pad(String name, SDVariable input, SDVariable padding, PadMode PadMode, + double constant) { + SDValidation.validateNumerical("pad", "input", input); + SDValidation.validateNumerical("pad", "padding", padding); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.Pad(sd,input, padding, PadMode, constant).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Padding operation
* @@ -733,7 +767,7 @@ public class SDNN extends SDOps { public SDVariable pad(SDVariable input, SDVariable padding, double constant) { SDValidation.validateNumerical("pad", "input", input); SDValidation.validateNumerical("pad", "padding", padding); - return new org.nd4j.linalg.api.ops.impl.transforms.Pad(sd,input, padding, constant).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.transforms.Pad(sd,input, padding, PadMode.CONSTANT, constant).outputVariable(); } /** @@ -748,7 +782,35 @@ public class SDNN extends SDOps { public SDVariable pad(String name, SDVariable input, SDVariable padding, double constant) { SDValidation.validateNumerical("pad", "input", input); SDValidation.validateNumerical("pad", "padding", padding); - SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.Pad(sd,input, padding, constant).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.Pad(sd,input, padding, PadMode.CONSTANT, constant).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * GELU activation function - Gaussian Error Linear Units
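Note on the pad(...) changes above: the new overloads take an explicit PadMode, while the pre-existing overloads now delegate with PadMode.CONSTANT, so existing callers keep constant-padding behaviour. A sketch of the new entry point (the shapes are assumptions based on the documented semantics, and the constant argument is presumably only consulted in CONSTANT mode):

    import org.nd4j.enums.PadMode;

    SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, 2, 3));
    // one [padBefore, padAfter] pair per input dimension
    SDVariable padding = sd.constant(Nd4j.createFromArray(new int[][]{{1, 1}, {2, 2}}));
    SDVariable padded = sd.nn().pad("padded", in, padding, PadMode.REFLECT, 0.0);  // expected shape [4,7]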
+ * For more details, see Gaussian Error Linear Units (GELUs) - https://arxiv.org/abs/1606.08415
+ * This method uses the precise formulation of GELU, rather than an approximation
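For reference, the precise formulation is the exact definition from the linked paper: GELU(x) = x * Phi(x) = 0.5 * x * (1 + erf(x / sqrt(2))), where Phi is the standard normal CDF; the non-precise variant is typically the faster tanh approximation 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).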
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable preciseGelu(SDVariable x) { + SDValidation.validateNumerical("preciseGelu", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELU(sd,x).outputVariable(); + } + + /** + * GELU activation function - Gaussian Error Linear Units
+ * For more details, see Gaussian Error Linear Units (GELUs) - https://arxiv.org/abs/1606.08415
+ * This method uses the precise formulation of GELU, rather than an approximation
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable preciseGelu(String name, SDVariable x) { + SDValidation.validateNumerical("preciseGelu", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELU(sd,x).outputVariable(); return sd.updateVariableNameAndReference(out, name); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDOps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDOps.java index 88792bddb..fc406caa6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDOps.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDOps.java @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2019-2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -14,20 +14,21 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + package org.nd4j.autodiff.samediff.ops; -import org.nd4j.autodiff.functions.DifferentialFunctionFactory; +import static org.nd4j.autodiff.samediff.ops.SDValidation.isSameType; + +import java.lang.String; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.indexing.conditions.Condition; -/** - * Abstract class for defining categories of operations - such as {@link SDMath} that is available via {@code SameDiff.math()} - * - * @author Alex Black - */ -public abstract class SDOps { - - protected final SameDiff sd; +public class SDOps { + protected SameDiff sd; public SDOps() { sd = null; @@ -37,11 +38,5 @@ public abstract class SDOps { this.sd = sameDiff; } - protected DifferentialFunctionFactory f() { - return sd.f(); - } - protected SDVariable updateVariableNameAndReference(SDVariable varToUpdate, String newVarName) { - return sd.updateVariableNameAndReference(varToUpdate, newVarName); - } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/OpPredicate.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/OpPredicate.java index 97a47d257..2b91300eb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/OpPredicate.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/transform/OpPredicate.java @@ -21,7 +21,6 @@ import org.nd4j.autodiff.samediff.SameDiff; /** * An OpPredicate defines whether an operation ({@link DifferentialFunction}) matches or not.
- * Used mainly in {@link org.nd4j.autodiff.functions.DifferentialFunctionFactory} * * @author Alex Black */ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/util/SameDiffUtils.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/util/SameDiffUtils.java new file mode 100644 index 000000000..a3f9ddea2 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/util/SameDiffUtils.java @@ -0,0 +1,139 @@ +/* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.nd4j.autodiff.util; + +import java.util.*; + +import lombok.AccessLevel; +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.functions.DifferentialFunction; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction; +import org.nd4j.linalg.api.ops.impl.shape.ReductionShape; +import org.nd4j.linalg.api.shape.Shape; +import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; +import org.nd4j.linalg.exception.ND4JException; +import org.nd4j.linalg.factory.Nd4j; + +/** + * Utilities for SameDiff training and inference + */ +@NoArgsConstructor(access = AccessLevel.PRIVATE) +public class SameDiffUtils { + + /** + * Stack batch outputs, like an output from {@link org.nd4j.autodiff.samediff.SameDiff#output(MultiDataSetIterator, String...)} + */ + public static Map stackOutputs(List> outputs){ + Map> outs = new HashMap<>(); + for(Map batch : outputs){ + for(String k : batch.keySet()){ + if(!outs.containsKey(k)) + outs.put(k, new ArrayList()); + outs.get(k).add(batch.get(k)); + } + } + + Map ret = new HashMap<>(); + for(String k : outs.keySet()){ + try { + ret.put(k, Nd4j.concat(0, outs.get(k).toArray(new INDArray[0]))); + } catch(Exception e){ + throw new ND4JException("Error concatenating batch outputs", e); + } + } + return ret; + } + + /** + * Get a list of batch outputs for a single variable from a list of batch outputs for all variables + */ + public static List getSingleOutput(List> outputs, String output){ + List batches = new ArrayList<>(); + for(Map batch : outputs) + batches.add(batch.get(output)); + + return batches; + } + + public static ExternalErrorsFunction externalErrors(SameDiff sameDiff, Map externalGradients, SDVariable... 
inputs) { + Preconditions.checkArgument(inputs != null && inputs.length > 0, "Require at least one SDVariable to" + + " be specified when using external errors: got %s", inputs); + ExternalErrorsFunction fn = new ExternalErrorsFunction(sameDiff, Arrays.asList(inputs), externalGradients); + fn.outputVariable(); + return fn; + } + + public static ExternalErrorsFunction externalErrors(SameDiff sameDiff, SDVariable[] inputs) { + return externalErrors(sameDiff, null, inputs); + } + + + + /** + * Add 1s to the shape as required to make the array broadcastable with the original (pre-reduce) array. + *

+ * Example: if doing [a,b,c].sum(1), result is [a,c]. To 'undo' this in a way that can be auto-broadcast, + * we want to expand as required - i.e., [a,c] -> [a,1,c] which can be auto-broadcast with the original [a,b,c]. + * This is typically only used for reduction-operation backprop. + * + * @param origRank Rank of the original array, before the reduction was executed + * @param reduceDims Dimensions that the original array was reduced from + * @param toExpand Array to add 1s to the shape of (such that it can be broadcast with the original pre-reduce array) + * @return Reshaped array. + */ + public static SDVariable reductionBroadcastableWithOrigShape(int origRank, int[] reduceDims, SDVariable toExpand) { + if (Shape.isWholeArray(origRank, reduceDims)) { + //Output is [1,1] which is already broadcastable + return toExpand; + } else if (origRank == 2 && reduceDims.length == 1) { + //In this case: [a,b] -> [1,b] or [a,b] -> [a,1] + //both are already broadcastable + return toExpand; + } else { + //Example: [a,b,c].sum(1) -> [a,c]... want [a,1,c] + for (int d : reduceDims) { + toExpand = toExpand.getSameDiff().expandDims(toExpand, d); + } + return toExpand; + } + } + + public static SDVariable reductionBroadcastableWithOrigShape(SDVariable origInput, SDVariable axis, SDVariable toExpand) { + SDVariable shape = origInput.shape(); + SDVariable reduceShape = reductionShape(shape, axis, true); + SDVariable reshaped = toExpand.reshape(reduceShape); + return reshaped; + } + + public static SDVariable reductionShape(SDVariable shape, SDVariable axis, boolean keepDim){ + return new ReductionShape(shape.getSameDiff(), shape, axis, keepDim).outputVariable(); + } + + public static void validateDifferentialFunctionSameDiff(SameDiff sameDiff, SDVariable function, DifferentialFunction op) { + + Preconditions.checkState(function != null, "Passed in function was null."); + + Preconditions.checkState(function.getSameDiff() == sameDiff, + "Function applications must be contained " + + "in same sameDiff. The left %s must match this function %s", function, op); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/util/TrainingUtils.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/util/TrainingUtils.java deleted file mode 100644 index 289bd15be..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/util/TrainingUtils.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License.
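A worked sketch of the expansion that reductionBroadcastableWithOrigShape performs, following the [a,b,c] example in its javadoc (shapes chosen for illustration):

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.autodiff.util.SameDiffUtils;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.factory.Nd4j;

    SameDiff sd = SameDiff.create();
    SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, 2, 3, 4));  // [a,b,c] = [2,3,4]
    SDVariable reduced = in.sum(1);                                    // reduced to [2,4]
    // Re-insert the reduced dimension with size 1: [2,4] -> [2,1,4],
    // which then broadcasts against the original [2,3,4]
    SDVariable expanded = SameDiffUtils.reductionBroadcastableWithOrigShape(3, new int[]{1}, reduced);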
- * - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.nd4j.autodiff.util; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import lombok.AccessLevel; -import lombok.NoArgsConstructor; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator; -import org.nd4j.linalg.exception.ND4JException; -import org.nd4j.linalg.factory.Nd4j; - -/** - * Utilities for SameDiff training and inference - */ -@NoArgsConstructor(access = AccessLevel.PRIVATE) -public class TrainingUtils { - - /** - * Stack batch outputs, like an output from {@link org.nd4j.autodiff.samediff.SameDiff#output(MultiDataSetIterator, String...)} - */ - public static Map stackOutputs(List> outputs){ - Map> outs = new HashMap<>(); - for(Map batch : outputs){ - for(String k : batch.keySet()){ - if(!outs.containsKey(k)) - outs.put(k, new ArrayList()); - outs.get(k).add(batch.get(k)); - } - } - - Map ret = new HashMap<>(); - for(String k : outs.keySet()){ - try { - ret.put(k, Nd4j.concat(0, outs.get(k).toArray(new INDArray[0]))); - } catch(Exception e){ - throw new ND4JException("Error concatenating batch outputs", e); - } - } - return ret; - } - - /** - * Get a list of batch outputs for a single variable from a list of batch outputs for all variables - */ - public static List getSingleOutput(List> outputs, String output){ - List batches = new ArrayList<>(); - for(Map batch : outputs) - batches.add(batch.get(output)); - - return batches; - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/PadMode.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/PadMode.java new file mode 100644 index 000000000..4802ebdaf --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/PadMode.java @@ -0,0 +1,29 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + +package org.nd4j.enums; + +/** + * Padding format */ +public enum PadMode { + CONSTANT, + + REFLECT, + + SYMMETRIC +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/WeightsFormat.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/WeightsFormat.java new file mode 100644 index 000000000..865d23282 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/WeightsFormat.java @@ -0,0 +1,29 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. 
+ * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + +package org.nd4j.enums; + +/** + * Weights format: [kH, kW, iC, oC] or [oC, iC, kH, kW], or [oC, kH, kW, iC] */ +public enum WeightsFormat { + YXIO, + + OIYX, + + OYXI +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastBoolOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastBoolOp.java index afdc11aa4..e2ca9329e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastBoolOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastBoolOp.java @@ -22,6 +22,7 @@ import lombok.val; import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.util.SameDiffUtils; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -56,8 +57,6 @@ public abstract class BaseBroadcastBoolOp extends BaseOp implements BroadcastOp int[] dimension) { super(sameDiff, inPlace, new Object[]{i_v2}); if (i_v1 != null && i_v2 != null) { - f().validateDifferentialFunctionsameDiff(i_v1); - f().validateDifferentialFunctionsameDiff(i_v2); this.sameDiff = sameDiff; this.inPlace = inPlace; this.dimension = dimension; @@ -80,9 +79,6 @@ public abstract class BaseBroadcastBoolOp extends BaseOp implements BroadcastOp super(sameDiff, extraArgs); this.dimension = dimension; if (i_v1 != null && i_v2 != null) { - f().validateDifferentialFunctionsameDiff(i_v1); - f().validateDifferentialFunctionsameDiff(i_v2); - this.sameDiff = sameDiff; sameDiff.addArgsFor(new SDVariable[]{i_v1,i_v2},this); @@ -107,7 +103,7 @@ public abstract class BaseBroadcastBoolOp extends BaseOp implements BroadcastOp super(sameDiff, inPlace, extraArgs); this.dimension = dimension; if (i_v != null) { - f().validateDifferentialFunctionsameDiff(i_v); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v, this); sameDiff.addArgsFor(new SDVariable[]{i_v},this); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastOp.java index 291b66d2b..56201560a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseBroadcastOp.java @@ -22,6 +22,7 @@ import lombok.val; import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.util.SameDiffUtils; import org.nd4j.base.Preconditions; import 
org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -57,8 +58,6 @@ public abstract class BaseBroadcastOp extends BaseOp implements BroadcastOp { int[] dimension) { super(sameDiff, inPlace, new Object[]{i_v2}); if (i_v1 != null && i_v2 != null) { - f().validateDifferentialFunctionsameDiff(i_v1); - f().validateDifferentialFunctionsameDiff(i_v2); this.sameDiff = sameDiff; this.inPlace = inPlace; this.dimension = dimension; @@ -80,8 +79,8 @@ public abstract class BaseBroadcastOp extends BaseOp implements BroadcastOp { super(sameDiff, extraArgs); this.dimension = dimension; if (i_v1 != null && i_v2 != null) { - f().validateDifferentialFunctionsameDiff(i_v1); - f().validateDifferentialFunctionsameDiff(i_v2); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v1, this); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v2, this); this.sameDiff = sameDiff; sameDiff.addArgsFor(new SDVariable[]{i_v1,i_v2},this); @@ -107,7 +106,7 @@ public abstract class BaseBroadcastOp extends BaseOp implements BroadcastOp { super(sameDiff, inPlace, extraArgs); this.dimension = dimension; if (i_v != null) { - f().validateDifferentialFunctionsameDiff(i_v); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v, this); sameDiff.addArgsFor(new SDVariable[]{i_v},this); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseIndexAccumulation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseIndexAccumulation.java index 8b598242c..502874cc1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseIndexAccumulation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseIndexAccumulation.java @@ -20,6 +20,7 @@ import lombok.Data; import lombok.extern.slf4j.Slf4j; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.util.SameDiffUtils; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -46,7 +47,6 @@ public abstract class BaseIndexAccumulation extends BaseOp implements IndexAccum super(sameDiff,null); if (i_v != null) { this.dimensions = dimensions; - f().validateDifferentialFunctionsameDiff(i_v); sameDiff.addArgsFor(new SDVariable[]{i_v},this); this.xVertexId = i_v.name(); @@ -65,8 +65,8 @@ public abstract class BaseIndexAccumulation extends BaseOp implements IndexAccum super(sameDiff,null); if (i_v != null) { this.dimensions = dimensions; - f().validateDifferentialFunctionsameDiff(i_v); - f().validateDifferentialFunctionsameDiff(i_v2); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v, this); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v2, this); this.xVertexId = i_v.name(); this.yVertexId = i_v2.name(); sameDiff.addArgsFor(new SDVariable[]{i_v,i_v2},this); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java index 9e5b8f67b..66c3e95d3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseReduceOp.java @@ -23,6 +23,7 @@ import lombok.val; import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; 
import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.util.SameDiffUtils; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.api.shape.Shape; @@ -59,7 +60,7 @@ public abstract class BaseReduceOp extends BaseOp implements ReduceOp { dimensions = new int[] {Integer.MAX_VALUE}; this.dimensions = dimensions; - f().validateDifferentialFunctionsameDiff(i_v); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v, this); this.keepDims = keepDims; this.xVertexId = i_v.name(); sameDiff.addArgsFor(new String[]{xVertexId},this); @@ -83,8 +84,8 @@ public abstract class BaseReduceOp extends BaseOp implements ReduceOp { this.xVertexId = i_v.name(); this.yVertexId = i_v2.name(); - f().validateDifferentialFunctionsameDiff(i_v); - f().validateDifferentialFunctionsameDiff(i_v2); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v, this); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v2, this); this.keepDims = keepDims; sameDiff.addArgsFor(new String[]{xVertexId,yVertexId},this); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarBoolOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarBoolOp.java index 8cb7e50b4..858b6a81c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarBoolOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarBoolOp.java @@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops; import lombok.extern.slf4j.Slf4j; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.util.SameDiffUtils; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -73,7 +74,7 @@ public abstract class BaseScalarBoolOp extends BaseOp implements ScalarOp { if (i_v != null) { this.xVertexId = i_v.name(); sameDiff.addArgsFor(new String[]{xVertexId},this); - f().validateDifferentialFunctionsameDiff(i_v); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v, this); } else { throw new IllegalArgumentException("Input not null variable."); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarOp.java index 254069929..66a204602 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseScalarOp.java @@ -21,6 +21,7 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.util.SameDiffUtils; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.memory.MemoryWorkspace; @@ -94,7 +95,7 @@ public abstract class BaseScalarOp extends BaseOp implements ScalarOp { this.scalarValue = Nd4j.scalar(i_v.dataType(), scalar); this.xVertexId = i_v.name(); sameDiff.addArgsFor(new String[]{xVertexId},this); - f().validateDifferentialFunctionsameDiff(i_v); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v, this); } diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformOp.java index 4e498edeb..7f8e0487e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseTransformOp.java @@ -20,6 +20,7 @@ import lombok.extern.slf4j.Slf4j; import lombok.val; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.util.SameDiffUtils; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.api.shape.Shape; @@ -52,8 +53,8 @@ public abstract class BaseTransformOp extends BaseOp implements TransformOp { boolean inPlace) { super(sameDiff,inPlace,new Object[] {i_v2}); if (i_v1 != null && i_v2 != null) { - f().validateDifferentialFunctionsameDiff(i_v1); - f().validateDifferentialFunctionsameDiff(i_v2); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v1, this); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v2, this); this.sameDiff = sameDiff; this.inPlace = inPlace; this.xVertexId = i_v1.name(); @@ -77,8 +78,8 @@ public abstract class BaseTransformOp extends BaseOp implements TransformOp { super(sameDiff,extraArgs); if (i_v1 != null && i_v2 != null) { - f().validateDifferentialFunctionsameDiff(i_v1); - f().validateDifferentialFunctionsameDiff(i_v2); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v1, this); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v2, this); this.sameDiff = sameDiff; this.xVertexId = i_v1.name(); this.yVertexId = i_v2.name(); @@ -104,7 +105,7 @@ public abstract class BaseTransformOp extends BaseOp implements TransformOp { super(sameDiff,inPlace,extraArgs); if (i_v != null) { - f().validateDifferentialFunctionsameDiff(i_v); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, i_v, this); this.xVertexId = i_v.name(); sameDiff.addArgsFor(new SDVariable[]{i_v},this); } else { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/NoOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/NoOp.java index b4cf2d05a..692571df9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/NoOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/NoOp.java @@ -20,6 +20,7 @@ import onnx.Onnx; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.factory.Nd4j; import org.tensorflow.framework.AttrValue; @@ -38,6 +39,10 @@ public class NoOp extends DynamicCustomOp { super("noop", sd, new SDVariable[]{in}); } + public NoOp(INDArray in) { + addInputArgument(in); + } + @Override public List doDiff(List f1) { return Collections.singletonList(f1.get(0)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAdd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAdd.java index f1b7f7398..cea3b388a 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAdd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BiasAdd.java @@ -85,7 +85,7 @@ public class BiasAdd extends DynamicCustomOp { @Override public List doDiff(List gradient){ - return Arrays.asList(f().biasAddBp(arg(0), arg(1), gradient.get(0), nchw)); + return new BiasAddGrad(sameDiff, arg(0), arg(1), gradient.get(0), nchw).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/BaseCompatOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/BaseCompatOp.java index 1bb451bf1..c1aff757d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/BaseCompatOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/BaseCompatOp.java @@ -25,6 +25,7 @@ import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.descriptors.properties.AttributeAdapter; import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.tensorflow.framework.AttrValue; @@ -41,6 +42,10 @@ public abstract class BaseCompatOp extends DynamicCustomOp { super(null, sameDiff, inputs); } + public BaseCompatOp(INDArray... inputs) { + addInputArgument(inputs); + } + public BaseCompatOp(){ } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Merge.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Merge.java index 993a5b11e..9adbd78df 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Merge.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Merge.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.Op; import org.nd4j.linalg.api.ops.Op.Type; import org.tensorflow.framework.AttrValue; @@ -36,6 +37,10 @@ public class Merge extends BaseCompatOp { super(sd, inputs); } + public Merge(INDArray... 
inputs) { + super(inputs); + } + public Merge(){ } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/StopGradient.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/StopGradient.java index c7d90e4c8..f302c752a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/StopGradient.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/StopGradient.java @@ -50,6 +50,6 @@ public class StopGradient extends BaseDynamicTransformOp { @Override public List doDiff(List gradients){ - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Switch.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Switch.java index 1b6c2f5e2..a7804f39f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Switch.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/controlflow/compat/Switch.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.Op; import org.nd4j.linalg.api.ops.Op.Type; import org.tensorflow.framework.AttrValue; @@ -44,6 +45,10 @@ public class Switch extends BaseCompatOp { this.predicate = predicate; } + public Switch(INDArray input, INDArray predicate) { + addInputArgument(input, predicate); + } + public Switch(){ } /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMax.java index b9c3962aa..181321d4f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMax.java @@ -74,6 +74,6 @@ public class IAMax extends BaseIndexAccumulation { @Override public List doDiff(List grad){ - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMin.java index 63f40ee6c..760fca314 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMin.java @@ -76,6 +76,6 @@ public class IAMin extends BaseIndexAccumulation { @Override public List doDiff(List grad){ - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMax.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMax.java index 8b7872b49..c01be78f9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMax.java @@ -83,6 +83,6 @@ public class IMax extends BaseIndexAccumulation { @Override public List doDiff(List f1) { //Not differentiable, but (assuming no ties) output does not change for a given infinitesimal change in the input - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMin.java index 06b3deb1c..e668f1ee0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMin.java @@ -77,6 +77,6 @@ public class IMin extends BaseIndexAccumulation { @Override public List doDiff(List f1) { //Not differentiable, but (assuming no ties) output does not change for a given infinitesimal change in the input - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java index 9635c6f36..5417b14cd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Conv2D.java @@ -60,7 +60,6 @@ public class Conv2D extends DynamicCustomOp { SDVariable bias, @NonNull Conv2DConfig conv2DConfig) { this(sameDiff, wrapFilterNull(input, weights, bias), conv2DConfig); } - @Builder(builderMethodName = "sameDiffBuilder") public Conv2D(SameDiff sameDiff, SDVariable[] inputFunctions, @@ -71,7 +70,7 @@ public class Conv2D extends DynamicCustomOp { } public Conv2D(INDArray[] inputs, INDArray[] outputs, Conv2DConfig config){ - super(inputs, outputs); + super(inputs, outputs); initConfig(config); } @@ -103,7 +102,8 @@ public class Conv2D extends DynamicCustomOp { config.getDH(), config.getDW(), ArrayUtil.fromBoolean(config.isSameMode()), - config.getDataFormat().equalsIgnoreCase(Conv2DConfig.NCHW) ? 0 : 1); + config.getDataFormat().equalsIgnoreCase("NCHW") ? 
0 : 1, + config.getWeightsFormat().ordinal()); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java index 436659443..d0b04b36a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DeConv3D.java @@ -161,8 +161,7 @@ public class DeConv3D extends DynamicCustomOp { @Override public List doDiff(List f1) { SDVariable bias = args().length > 2 ? arg(2) : null; - SDVariable[] outVars = f().deconv3dDerivative(arg(0), arg(1), bias, f1.get(0), config); - return Arrays.asList(outVars); + return new DeConv3DDerivative(sameDiff, arg(0), arg(1), bias, f1.get(0), config).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Im2col.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Im2col.java index 46f5f4e79..86bfbacc9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Im2col.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Im2col.java @@ -90,7 +90,7 @@ public class Im2col extends DynamicCustomOp { @Override public List doDiff(List grad) { - return Collections.singletonList(f().im2ColBp(arg(), grad.get(0), conv2DConfig)); + return new Im2colBp(sameDiff, arg(), grad.get(0), conv2DConfig).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Upsampling2d.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Upsampling2d.java index df345a2f3..3370b6f30 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Upsampling2d.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/Upsampling2d.java @@ -99,7 +99,7 @@ public class Upsampling2d extends DynamicCustomOp { @Override public List doDiff(List f1) { - return Collections.singletonList(f().upsampling2dBp(arg(), f1.get(0), nchw, scaleH, scaleW)); + return new Upsampling2dDerivative(sameDiff, arg(), f1.get(0), nchw, scaleH, scaleW).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Conv2DConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Conv2DConfig.java index 40a2a3908..92701a696 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Conv2DConfig.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/config/Conv2DConfig.java @@ -22,6 +22,7 @@ import lombok.Builder; import lombok.Data; import lombok.NoArgsConstructor; import org.nd4j.base.Preconditions; +import org.nd4j.enums.WeightsFormat; import org.nd4j.linalg.util.ConvConfigUtil; @Data @@ -50,9 +51,11 @@ public class Conv2DConfig extends BaseConvolutionConfig { 
private boolean isSameMode; @Builder.Default private String dataFormat = NCHW; + @Builder.Default + private WeightsFormat weightsFormat = WeightsFormat.YXIO; public Conv2DConfig(long kH, long kW, long sH, long sW, long pH, long pW, long dH, long dW, boolean isSameMode, - String dataFormat) { + String dataFormat, WeightsFormat weightsFormat) { this.kH = kH; this.kW = kW; @@ -64,6 +67,7 @@ public class Conv2DConfig extends BaseConvolutionConfig { this.dW = dW; this.isSameMode = isSameMode; this.dataFormat = dataFormat; + this.weightsFormat = weightsFormat; validate(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/AbsoluteDifferenceLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/AbsoluteDifferenceLoss.java index adc59e4e0..4f6539eee 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/AbsoluteDifferenceLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/AbsoluteDifferenceLoss.java @@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.loss.bp.AbsoluteDifferenceLossBp; import java.util.Arrays; import java.util.List; @@ -58,7 +59,6 @@ public class AbsoluteDifferenceLoss extends BaseLoss { public List doDiff(List grad){ //No external gradient //Args are: predictions, weights, label - SDVariable[] grads = f().lossAbsoluteDifferenceBP(arg(2), arg(0), arg(1), lossReduce); - return Arrays.asList(grads); + return new AbsoluteDifferenceLossBp(sameDiff, lossReduce, arg(0), arg(1), arg(2)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/CosineDistanceLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/CosineDistanceLoss.java index 7faa5f6b0..432910391 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/CosineDistanceLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/CosineDistanceLoss.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.loss.bp.CosineDistanceLossBp; import java.util.Arrays; import java.util.List; @@ -61,8 +62,7 @@ public class CosineDistanceLoss extends BaseLoss { public List doDiff(List grad){ //No external gradient. 
//Args are: predictions, weights, label - SDVariable[] grads = f().lossCosineDistanceBp(arg(2), arg(0), arg(1), lossReduce, dimension); - return Arrays.asList(grads); + return new CosineDistanceLossBp(sameDiff, lossReduce, arg(0), arg(1), arg(2), dimension).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HingeLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HingeLoss.java index 5d85e4933..d021623d5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HingeLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HingeLoss.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.loss.bp.HingeLossBp; import java.util.Arrays; import java.util.List; @@ -56,8 +57,7 @@ public class HingeLoss extends BaseLoss { public List doDiff(List grad){ //No external gradient //Args are: predictions, weights, label - SDVariable[] grads = f().lossHingeBp(arg(2), arg(0), arg(1), lossReduce); - return Arrays.asList(grads); + return new HingeLossBp(sameDiff, lossReduce, arg(0), arg(1), arg(2)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HuberLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HuberLoss.java index f08d90566..acb74c04c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HuberLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/HuberLoss.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.loss.bp.HuberLossBp; import java.util.Arrays; import java.util.List; @@ -63,8 +64,7 @@ public class HuberLoss extends BaseLoss { public List doDiff(List grad){ //No external gradient //Args are: predictions, weights, label - SDVariable[] grads = f().lossHuberBp(arg(2), arg(0), arg(1), lossReduce, delta); - return Arrays.asList(grads); + return new HuberLossBp(sameDiff, lossReduce, arg(0), arg(1), arg(2), delta).outputs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/L2Loss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/L2Loss.java index e1fe56e5f..d36d36c2f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/L2Loss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/L2Loss.java @@ -62,6 +62,6 @@ public class L2Loss extends DynamicCustomOp { public List doDiff(List grad){ //L2 loss: L = 1/2 * sum(x_i^2) //dL/dxi = xi - return Collections.singletonList(f().identity(arg())); + return Collections.singletonList(sameDiff.identity(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogLoss.java index a7a15f1b5..c13634ee1 100644 
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogLoss.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.loss.bp.LogLossBp; import java.util.Arrays; import java.util.List; @@ -64,8 +65,7 @@ public class LogLoss extends BaseLoss { public List doDiff(List grad){ //No external gradient //Args are: predictions, weights, label - SDVariable[] grads = f().lossLogBp(arg(2), arg(0), arg(1), lossReduce, epsilon); - return Arrays.asList(grads); + return new LogLossBp(sameDiff, lossReduce, arg(0), arg(1), arg(2), epsilon).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogPoissonLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogPoissonLoss.java index a893e3f4a..2ec6e54b5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogPoissonLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/LogPoissonLoss.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.loss.bp.LogPoissonLossBp; import java.util.Arrays; import java.util.List; @@ -73,14 +74,7 @@ public class LogPoissonLoss extends BaseLoss { public List doDiff(List grad){ //No external gradient //Args are: predictions, weights, label - - SDVariable[] grads; - if(full) { - grads = f().lossLogPoissonFullBp(arg(2), arg(0), arg(1), lossReduce); - }else{ - grads = f().lossLogPoissonBp(arg(2), arg(0), arg(1), lossReduce); - } - return Arrays.asList(grads); + return new LogPoissonLossBp(sameDiff, lossReduce, arg(0), arg(1), arg(2), full).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanPairwiseSquaredErrorLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanPairwiseSquaredErrorLoss.java index 6c3c5d01b..676eec5e0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanPairwiseSquaredErrorLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanPairwiseSquaredErrorLoss.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.loss.bp.MeanPairwiseSquaredErrorLossBp; import java.util.Arrays; import java.util.List; @@ -54,7 +55,6 @@ public class MeanPairwiseSquaredErrorLoss extends BaseLoss { public List doDiff(List grad){ //No external gradient //Args are: predictions, weights, label - SDVariable[] grads = f().lossMeanPairwiseSquaredErrorBp(arg(2), arg(0), arg(1), lossReduce); - return Arrays.asList(grads); + return new MeanPairwiseSquaredErrorLossBp(sameDiff, lossReduce, arg(0), arg(1), arg(2)).outputs(); } } diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanSquaredErrorLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanSquaredErrorLoss.java index a9cf27584..c40d9e432 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanSquaredErrorLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/MeanSquaredErrorLoss.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.loss.LossReduce; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.loss.bp.MeanSquaredErrorLossBp; import java.util.Arrays; import java.util.List; @@ -56,8 +57,7 @@ public class MeanSquaredErrorLoss extends BaseLoss { public List doDiff(List grad){ //No external gradient //Args are: predictions, weights, label - SDVariable[] grads = f().lossMeanSquaredErrorBp(arg(2), arg(0), arg(1), lossReduce); - return Arrays.asList(grads); + return new MeanSquaredErrorLossBp(sameDiff, lossReduce, arg(0), arg(1), arg(2)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SigmoidCrossEntropyLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SigmoidCrossEntropyLoss.java index 214380a8c..862b405d4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SigmoidCrossEntropyLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SigmoidCrossEntropyLoss.java @@ -27,6 +27,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.Op; +import org.nd4j.linalg.api.ops.impl.loss.bp.SigmoidCrossEntropyLossBp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -80,7 +81,6 @@ public class SigmoidCrossEntropyLoss extends BaseLoss { public List doDiff(List grad){ //No external gradient //Args are: predictions, weights, label - SDVariable[] grads = f().lossSigmoidCrossEntropyBp(arg(2), arg(0), arg(1), lossReduce, labelSmoothing); - return Arrays.asList(grads); + return new SigmoidCrossEntropyLossBp(sameDiff, lossReduce, arg(0), arg(1), arg(2), labelSmoothing).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java index 57576b78f..e97427e92 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyLoss.java @@ -25,6 +25,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.Op; +import org.nd4j.linalg.api.ops.impl.loss.bp.SoftmaxCrossEntropyLossBp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -99,7 +100,6 @@ public class 
SoftmaxCrossEntropyLoss extends BaseLoss { public List doDiff(List grad){ //No external gradient //Args are: predictions, weights, label - SDVariable[] grads = f().lossSoftmaxCrossEntropyBp(arg(2), arg(0), arg(1), lossReduce, labelSmoothing); - return Arrays.asList(grads); + return new SoftmaxCrossEntropyLossBp(sameDiff, lossReduce, arg(0), arg(1), arg(2), labelSmoothing).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyWithLogitsLoss.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyWithLogitsLoss.java index 3ef7de264..defb8292b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyWithLogitsLoss.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SoftmaxCrossEntropyWithLogitsLoss.java @@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.loss.bp.SoftmaxCrossEntropyWithLogitsLossBp; import org.nd4j.linalg.factory.Nd4j; import java.util.Arrays; @@ -73,8 +74,6 @@ public class SoftmaxCrossEntropyWithLogitsLoss extends DynamicCustomOp { public List doDiff(List grad){ //No external gradient //Args: logits, weigths, label - SDVariable[] args = args(); - SDVariable[] grads = f().lossSoftmaxCrossEntropyWithLogitsBp(arg(0), arg(1), classesDim); - return Arrays.asList(grads); + return new SoftmaxCrossEntropyWithLogitsLossBp(sameDiff, arg(0), arg(1), classesDim).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SparseSoftmaxCrossEntropyLossWithLogits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SparseSoftmaxCrossEntropyLossWithLogits.java index a0f3288a9..c58933134 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SparseSoftmaxCrossEntropyLossWithLogits.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/loss/SparseSoftmaxCrossEntropyLossWithLogits.java @@ -28,6 +28,7 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.Op; +import org.nd4j.linalg.api.ops.impl.loss.bp.SparseSoftmaxCrossEntropyLossWithLogitsBp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -96,7 +97,8 @@ public class SparseSoftmaxCrossEntropyLossWithLogits extends DynamicCustomOp { @Override public List doDiff(List grad){ //args: label, logits - SDVariable[] ret = f().lossSparseSoftmaxCrossEntropyBp(arg(1), arg(0)); - return Arrays.asList(f().zerosLike(arg(0)), ret[0]); + SDVariable labelsGrad = sameDiff.zerosLike(arg(0)); + SDVariable logitsGrad = new SparseSoftmaxCrossEntropyLossWithLogitsBp(sameDiff, arg(1), arg(0)).outputVariable(); + return Arrays.asList(labelsGrad, logitsGrad); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java index 46310893d..479c794c2 100644 --- 
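Two conventions in the sparse-softmax hunk just above are worth noting: integer class indices admit no meaningful gradient, so arg(0) receives a zerosLike placeholder rather than a computed gradient, and because that bp op has exactly one output it is consumed with outputVariable(), where the multi-gradient loss ops in this series use outputs().
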
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Mmul.java @@ -266,7 +266,7 @@ public class Mmul extends DynamicCustomOp { @Override public List doDiff(List gradients) { - return sameDiff.f().mmulBp(larg(),rarg(), gradients.get(0), mt); + return Arrays.asList(new MmulBp(sameDiff, larg(), rarg(), gradients.get(0), mt).outputVariables()); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Moments.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Moments.java index bf3cb4af1..89ba1549b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Moments.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/Moments.java @@ -25,6 +25,8 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.MeanBp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.VarianceBp; import org.nd4j.linalg.factory.Nd4j; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -83,8 +85,8 @@ public class Moments extends DynamicCustomOp { public List doDiff(List grad){ SDVariable dLdMean = grad.get(0); SDVariable dLdVar = grad.get(1); //Note: non-bias-corrected variance - SDVariable meanBp = f().meanBp(arg(), dLdMean, false, axes); - SDVariable varBp = f().varianceBp(arg(), dLdVar, false, false, axes); + SDVariable meanBp = new MeanBp(sameDiff, arg(), dLdMean, false, axes).outputVariable(); + SDVariable varBp = new VarianceBp(sameDiff, arg(), dLdVar, false, false, axes).outputVariable(); return Collections.singletonList(meanBp.add(varBp)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java index c613f107f..f58347492 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java @@ -138,13 +138,13 @@ public class TensorMmul extends DynamicCustomOp { //tensor matrix multiply gradient wrt second variable int[] firstPerm = argsort(combine(deletedAxes[0],keep(argsort(sumAxes[1]),sumAxes[0]))); SDVariable firstResult = doTensorMmul(i_v1.get(0), rarg(), firstAxes); - SDVariable permuted = f().permute(firstResult,firstPerm); + SDVariable permuted = sameDiff.permute(firstResult,firstPerm); ret.add(permuted); //tensor matrix multiply gradient wrt first variable int[] secondPerm = argsort(combine(keep(argsort(sumAxes[0]),sumAxes[1]),deletedAxes[1])); SDVariable secondResult = doTensorMmul(i_v1.get(0), larg(), secondAxes); - SDVariable secondPermuted = f().permute(secondResult,secondPerm); + SDVariable secondPermuted = sameDiff.permute(secondResult,secondPerm); ret.add(secondPermuted); return ret; } @@ -210,7 +210,7 @@ public class TensorMmul extends DynamicCustomOp { } - int[] newShapeB = {n3, -1}; + long[] newShapeB = {n3, -1}; long[] oldShapeB; if (listB.size() == 0) { oldShapeB = new long[] {1}; @@ -221,16 
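Moments is the one multi-output reduction touched here: both outputs can feed the loss, so the input gradient is the sum of the two chain-rule paths, dL/dx = meanBp(x, dL/dmean) + varianceBp(x, dL/dvar), which is exactly the add() of the two bp output variables in the hunk above.
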
+221,12 @@ public class TensorMmul extends DynamicCustomOp { } - SDVariable at = f() - .reshape(f().permute - (a,newAxesA),newShapeA); - SDVariable bt = f() - .reshape(f() - .permute(b,newAxesB),newShapeB); + SDVariable at = sameDiff.reshape(sameDiff.permute(a,newAxesA),newShapeA); + SDVariable bt = sameDiff.reshape(sameDiff.permute(b,newAxesB),newShapeB); - SDVariable ret = f().mmul(at,bt); + SDVariable ret = sameDiff.mmul(at,bt); long[] aPlusB = Longs.concat(oldShapeA, oldShapeB); - return f().reshape(ret, aPlusB); + return sameDiff.reshape(ret, aPlusB); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/All.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/All.java index a465728d1..8aa12d4d3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/All.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/All.java @@ -57,7 +57,7 @@ public class All extends BaseReduceBoolOp { @Override public List doDiff(List f1) { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/Any.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/Any.java index d4522ca69..4d26e5b70 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/Any.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/Any.java @@ -57,7 +57,7 @@ public class Any extends BaseReduceBoolOp { @Override public List doDiff(List f1) { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsInf.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsInf.java index cb93a832e..5dfc23f8e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsInf.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsInf.java @@ -68,7 +68,7 @@ public class IsInf extends BaseReduceBoolOp { @Override public List doDiff(List i_v) { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsNaN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsNaN.java index c8cd72f2c..a78ae8bd5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsNaN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/bool/IsNaN.java @@ -68,7 +68,7 @@ public class IsNaN extends BaseReduceBoolOp { @Override public List doDiff(List i_v) { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } @Override diff --git 
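The doTensorMmul body above is the classic tensordot-as-matmul reduction: permute the contraction axes into place, flatten both operands to 2-D, multiply, then restore the kept dimensions. A standalone sketch of the idea for one concrete case, contracting axis 2 of a [2,3,4] array with axis 0 of a [4,5] array; the shape literals and the use of -1 to infer a dimension are this sketch's assumptions, and note the shapes are long[], the type this diff also fixes for newShapeB:

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;

    class TensordotSketch {
        static SDVariable tensordotExample() {
            SameDiff sd = SameDiff.create();
            SDVariable a = sd.var("a", 2, 3, 4);
            SDVariable b = sd.var("b", 4, 5);
            // a's contraction axis is already last, so this permute is an identity;
            // in general it moves the summed axes to the back of a / front of b.
            SDVariable at = sd.reshape(sd.permute(a, 0, 1, 2), new long[]{-1, 4}); // [6, 4]
            SDVariable bt = sd.reshape(b, new long[]{4, -1});                      // [4, 5]
            SDVariable mm = sd.mmul(at, bt);                                       // [6, 5]
            return sd.reshape(mm, new long[]{2, 3, 5}); // kept dims [2,3] and [5] restored
        }
    }
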
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/custom/LogSumExp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/custom/LogSumExp.java index 26eabf0ff..edc3298b5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/custom/LogSumExp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/custom/LogSumExp.java @@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.SumBp; import java.util.Collections; import java.util.List; @@ -82,10 +83,10 @@ public class LogSumExp extends DynamicCustomOp { //z = log(sum_i exp(x_i)) = log(s) //dL/dx = dL/dz * dz/ds * ds/dx //dz/ds = 1/s - SDVariable exp = f().exp(arg()); + SDVariable exp = sameDiff.math.exp(arg()); SDVariable sumExp = exp.sum(dimensions); SDVariable gradProd = f1.get(0).div(sumExp); - SDVariable dSumExpdx = f().sumBp(arg(), gradProd, keepDims, dimensions).mul(exp); + SDVariable dSumExpdx = new SumBp(sameDiff, arg(), gradProd, keepDims, dimensions).outputVariable().mul(exp); return Collections.singletonList(dSumExpdx); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/AMean.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/AMean.java index 47cc728ab..e9481fa81 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/AMean.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/AMean.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceFloatOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.MeanBp; import java.util.Collections; import java.util.List; @@ -73,7 +74,7 @@ public class AMean extends BaseReduceFloatOp { @Override public List doDiff(List f1) { SDVariable sgn = sameDiff.math().sign(arg()); - SDVariable meanBp = f().meanBp(sameDiff.math().abs(arg()), f1.get(0), false, dimensions); + SDVariable meanBp = new MeanBp(sameDiff, sameDiff.math().abs(arg()), f1.get(0), false, dimensions).outputVariable(); return Collections.singletonList(sgn.mul(meanBp)); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Entropy.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Entropy.java index bb0dd4997..913a573db 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Entropy.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Entropy.java @@ -16,12 +16,11 @@ package org.nd4j.linalg.api.ops.impl.reduce.floating; -import org.nd4j.autodiff.functions.DifferentialFunctionFactory; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceFloatOp; +import 
org.nd4j.linalg.api.ops.impl.reduce.bp.SumBp; import java.util.Collections; import java.util.List; @@ -70,13 +69,13 @@ public class Entropy extends BaseReduceFloatOp { //Then we can do sumBp(z, -dL/dOut) //Note d/dx(x*log(x)) = log(x)+1 - return grad(f(), arg(), f1.get(0), dimensions); + return grad(sameDiff, arg(), f1.get(0), dimensions); } - public static List grad(DifferentialFunctionFactory f, SDVariable arg, SDVariable grad, int[] dimensions){ - SDVariable logx = f.log(arg); + public static List grad(SameDiff sd, SDVariable arg, SDVariable grad, int[] dimensions){ + SDVariable logx = sd.math.log(arg); SDVariable xLogX = arg.mul(logx); - SDVariable sumBp = f.sumBp(xLogX, grad.neg(), false, dimensions); + SDVariable sumBp = new SumBp(sd, xLogX, grad.neg(), false, dimensions).outputVariable(); return Collections.singletonList(sumBp.mul(logx.add(1.0))); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/LogEntropy.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/LogEntropy.java index 52970cc33..837d89c3a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/LogEntropy.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/LogEntropy.java @@ -70,7 +70,7 @@ public class LogEntropy extends BaseReduceFloatOp { @Override public List doDiff(List f1) { //If y=log(x), and x=entropy(in) then dL/dx = dL/dy * dy/dx; d(log(x))/dx = 1/x - List entropyGrad = Entropy.grad(f(), arg(), f1.get(0), dimensions); - return Collections.singletonList(entropyGrad.get(0).div(f().exp(outputVariable()))); + List entropyGrad = Entropy.grad(sameDiff, arg(), f1.get(0), dimensions); + return Collections.singletonList(entropyGrad.get(0).div(sameDiff.math.exp(outputVariable()))); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Mean.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Mean.java index bf15f94d4..6309ccf28 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Mean.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Mean.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceFloatOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.MeanBp; import java.util.Collections; import java.util.List; @@ -67,7 +68,7 @@ public class Mean extends BaseReduceFloatOp { public List doDiff(List i_v1) { //If out = mean(in), then dL/dIn = 1/N * dL/dOut (broadcast to appropriate shape) //Note that N differs for "along dimension" vs. 
"whole array" reduce cases - return Collections.singletonList(f().meanBp(arg(), i_v1.get(0), keepDims, dimensions)); + return new MeanBp(sameDiff, arg(), i_v1.get(0), keepDims, dimensions).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Norm1.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Norm1.java index a2ba88927..96222d7c8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Norm1.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Norm1.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceFloatOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.Norm1Bp; import org.nd4j.linalg.ops.transforms.Transforms; import java.util.Collections; @@ -80,6 +81,6 @@ public class Norm1 extends BaseReduceFloatOp { @Override public List doDiff(List grad) { - return Collections.singletonList(f().norm1Bp(arg(), grad.get(0), keepDims, dimensions)); + return new Norm1Bp(sameDiff, arg(), grad.get(0), keepDims, dimensions).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Norm2.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Norm2.java index f61c0dc43..be517f5e8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Norm2.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/Norm2.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceFloatOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.Norm2Bp; import org.nd4j.linalg.ops.transforms.Transforms; import java.util.Collections; @@ -72,7 +73,7 @@ public class Norm2 extends BaseReduceFloatOp { @Override public List doDiff(List grad) { //d norm2(in)/dx = x / norm2(in) - return Collections.singletonList(f().norm2Bp(arg(), grad.get(0), keepDims, dimensions)); + return new Norm2Bp(sameDiff, arg(), grad.get(0), keepDims, dimensions).outputs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/NormMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/NormMax.java index ece542857..a7cf398f5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/NormMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/NormMax.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceFloatOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.NormMaxBp; import org.nd4j.linalg.ops.transforms.Transforms; import java.util.Collections; @@ -77,7 +78,7 @@ public class NormMax extends BaseReduceFloatOp { public List doDiff(List grad) { //maxnorm(in) = max_i |x_i| //d maxnorm(in)/dx = 0 if 
x_i is not the max, or d|x|/dx otherwise - return Collections.singletonList(f().normmaxBp(arg(), grad.get(0), keepDims, dimensions)); + return new NormMaxBp(sameDiff, arg(), grad.get(0), keepDims, dimensions).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/ShannonEntropy.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/ShannonEntropy.java index 44504f855..963224ed8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/ShannonEntropy.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/ShannonEntropy.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceFloatOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.SumBp; import java.util.Collections; import java.util.List; @@ -68,10 +69,10 @@ public class ShannonEntropy extends BaseReduceFloatOp { //Then we can do sumBp(z, -dL/dOut) //Note d/dx(x*log2(x)) = (log(x)+1)/log(2) - SDVariable log2x = f().log(arg(),2); - SDVariable logx = f().log(arg()); + SDVariable log2x = sameDiff.math.log(arg(),2); + SDVariable logx = sameDiff.math.log(arg()); SDVariable xLog2X = arg().mul(log2x); - SDVariable sumBp = f().sumBp(xLog2X, f1.get(0).neg(), false, dimensions); + SDVariable sumBp = new SumBp(sameDiff, xLog2X, f1.get(0).neg(), false, dimensions).outputVariable(); return Collections.singletonList(sumBp.mul(logx.add(1.0)).div(Math.log(2.0))); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/SquaredNorm.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/SquaredNorm.java index b11fe5b1f..2af86c181 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/SquaredNorm.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/floating/SquaredNorm.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceFloatOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.SquaredNormBp; import java.util.Collections; import java.util.List; @@ -69,6 +70,6 @@ public class SquaredNorm extends BaseReduceFloatOp { @Override public List doDiff(List grad){ - return Collections.singletonList(f().squaredNormBp(arg(), grad.get(0), keepDims, dimensions)); + return new SquaredNormBp(sameDiff, arg(), grad.get(0), keepDims, dimensions).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/CountNonZero.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/CountNonZero.java index d27215a80..7376b0708 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/CountNonZero.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/CountNonZero.java @@ -56,7 +56,7 @@ public class CountNonZero extends BaseReduceLongOp { @Override public List doDiff(List f1) { - return 
Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/CountZero.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/CountZero.java index db13dfc85..27476cabc 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/CountZero.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/longer/CountZero.java @@ -67,7 +67,7 @@ public class CountZero extends BaseReduceLongOp { @Override public List doDiff(List f1) { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/AMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/AMax.java index a6533441a..cadc77d4f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/AMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/AMax.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceSameOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.MaxBp; import java.util.Collections; import java.util.List; @@ -65,7 +66,7 @@ public class AMax extends BaseReduceSameOp { @Override public List doDiff(List f1) { SDVariable sgn = sameDiff.math().sign(arg()); - SDVariable maxBp = f().maxBp(sameDiff.math().abs(arg()), f1.get(0), false, dimensions); + SDVariable maxBp = new MaxBp(sameDiff, sameDiff.math().abs(arg()), f1.get(0), false, dimensions).outputVariable(); return Collections.singletonList(sgn.mul(maxBp)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/AMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/AMin.java index 20a8be906..a01c9c1f5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/AMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/AMin.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceSameOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.MinBp; import java.util.Collections; import java.util.List; @@ -69,7 +70,7 @@ public class AMin extends BaseReduceSameOp { @Override public List doDiff(List f1) { SDVariable sgn = sameDiff.math().sign(arg()); - SDVariable minBp = f().minBp(sameDiff.math().abs(arg()), f1.get(0), false, dimensions); + SDVariable minBp = new MinBp(sameDiff, sameDiff.math().abs(arg()), f1.get(0), false, dimensions).outputVariable(); return Collections.singletonList(sgn.mul(minBp)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/ASum.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/ASum.java index 
17d8a0bde..1a15c32ac 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/ASum.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/ASum.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceSameOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.SumBp; import java.util.Collections; import java.util.List; @@ -72,7 +73,7 @@ public class ASum extends BaseReduceSameOp { @Override public List doDiff(List f1) { SDVariable sgn = sameDiff.math().sign(arg()); - SDVariable meanBp = f().sumBp(sameDiff.math().abs(arg()), f1.get(0), false, dimensions); + SDVariable meanBp = new SumBp(sameDiff, sameDiff.math().abs(arg()), f1.get(0), false, dimensions).outputVariable(); return Collections.singletonList(sgn.mul(meanBp)); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Max.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Max.java index a29384a42..8c4563c95 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Max.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Max.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceSameOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.MaxBp; import java.util.Collections; import java.util.List; @@ -79,7 +80,7 @@ public class Max extends BaseReduceSameOp { @Override public List doDiff(List grad) { - return Collections.singletonList(f().maxBp(arg(), grad.get(0), keepDims, dimensions)); + return new MaxBp(sameDiff, arg(), grad.get(0), keepDims, dimensions).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Min.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Min.java index 99c1e038b..1d644b671 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Min.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Min.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceSameOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.MinBp; import java.util.Collections; import java.util.List; @@ -77,6 +78,6 @@ public class Min extends BaseReduceSameOp { @Override public List doDiff(List grad) { - return Collections.singletonList(f().minBp(arg(), grad.get(0), keepDims, dimensions)); + return new MinBp(sameDiff, arg(), grad.get(0), keepDims, dimensions).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Prod.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Prod.java index b5073d0f9..0247e3169 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Prod.java +++ 
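AMax, AMin and ASum above all share one chain rule: for r = reduce(|x|), dL/dx = sign(x) * d reduce/d|x|, evaluated at |x|. A sketch of that shared shape, using SumBp as the example (MaxBp and MinBp slot in identically); dLdOut stands for the incoming f1.get(0):

    SDVariable absGrad = new SumBp(sameDiff, sameDiff.math().abs(arg()), dLdOut, false, dimensions)
            .outputVariable();
    SDVariable dLdIn = sameDiff.math().sign(arg()).mul(absGrad); // zero gradient where x == 0
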
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Prod.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.BaseNDArray; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceSameOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.ProdBp; import java.util.Collections; import java.util.List; @@ -82,6 +83,6 @@ public class Prod extends BaseReduceSameOp { @Override public List doDiff(List grad) { - return Collections.singletonList(f().prodBp(arg(), grad.get(0), keepDims, dimensions)); + return new ProdBp(sameDiff, arg(), grad.get(0), keepDims, dimensions).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Sum.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Sum.java index 859b89dac..e6fa79bb0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Sum.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/same/Sum.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceSameOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.SumBp; import java.util.Collections; import java.util.List; @@ -76,7 +77,7 @@ public class Sum extends BaseReduceSameOp { // dL/dIn = dL/dOut * dOut/dIn // = dL/dOut * 1 // But broadcast to shape of the input - return Collections.singletonList(f().sumBp(arg(), i_v1.get(0), keepDims, dimensions)); + return new SumBp(sameDiff, arg(), i_v1.get(0), keepDims, dimensions).outputs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/CosineDistance.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/CosineDistance.java index 44f1c49fc..a5aab468b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/CosineDistance.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/CosineDistance.java @@ -84,7 +84,7 @@ public class CosineDistance extends BaseReduce3Op { //Cosine distance = 1 - cosine similarity //Therefore: just need to negate gradients from cosine similarity... 
- List diff = CosineSimilarity.doDiff(sameDiff, f(), larg(), rarg(), i_v1.get(0), keepDims, dimensions); - return Arrays.asList(f().neg(diff.get(0)), f().neg(diff.get(1))); + List diff = CosineSimilarity.doDiff(sameDiff, larg(), rarg(), i_v1.get(0), keepDims, dimensions); + return Arrays.asList(sameDiff.math.neg(diff.get(0)), sameDiff.math.neg(diff.get(1))); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/CosineSimilarity.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/CosineSimilarity.java index 27c14473d..b6edbe6fa 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/CosineSimilarity.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/CosineSimilarity.java @@ -16,9 +16,9 @@ package org.nd4j.linalg.api.ops.impl.reduce3; -import org.nd4j.autodiff.functions.DifferentialFunctionFactory; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.util.SameDiffUtils; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -93,14 +93,14 @@ public class CosineSimilarity extends BaseReduce3Op { //Then: // dc(x,y)/dx_i = 1/b * (y - x * a / (l2(x))^2) - return doDiff(sameDiff, f(), larg(), rarg(), i_v1.get(0), keepDims, dimensions); + return doDiff(sameDiff, larg(), rarg(), i_v1.get(0), keepDims, dimensions); } - public static List doDiff(SameDiff sameDiff, DifferentialFunctionFactory f, SDVariable x, SDVariable y, + public static List doDiff(SameDiff sameDiff, SDVariable x, SDVariable y, SDVariable gradOut, boolean keepDims, int... dimensions){ SDVariable a = sameDiff.sum(x.mul(y),true, dimensions); - SDVariable l2x = f.norm2(x, true, dimensions); - SDVariable l2y = f.norm2(y, true, dimensions); + SDVariable l2x = sameDiff.norm2(x, true, dimensions); + SDVariable l2y = sameDiff.norm2(y, true, dimensions); SDVariable b = l2x.mul(l2y); SDVariable l2xSq = sameDiff.math().square(l2x); @@ -110,7 +110,7 @@ public class CosineSimilarity extends BaseReduce3Op { //keepDims or full array reduction broadcastableGrad = gradOut; } else { - broadcastableGrad = sameDiff.f().reductionBroadcastableWithOrigShape(x, sameDiff.constant(Nd4j.createFromArray(dimensions)), gradOut); + broadcastableGrad = SameDiffUtils.reductionBroadcastableWithOrigShape(x, sameDiff.constant(Nd4j.createFromArray(dimensions)), gradOut); } SDVariable dcdx = y.sub(x.mul(a).div(l2xSq)).div(b); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/Dot.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/Dot.java index 85f0b3e15..bdb172924 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/Dot.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/Dot.java @@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.reduce3; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.reduce.bp.DotBp; import java.util.Arrays; import java.util.List; @@ -86,6 +87,6 @@ public class Dot extends BaseReduce3Op { @Override public List doDiff(List f1) { //TODO KEEP DIMS - return Arrays.asList(f().dotBp(arg(0), arg(1), f1.get(0), false, 
dimensions)); + return new DotBp(sameDiff, arg(0), arg(1), f1.get(0), false, dimensions).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/EuclideanDistance.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/EuclideanDistance.java index a25ba6d52..97ccd81e6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/EuclideanDistance.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/EuclideanDistance.java @@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.reduce3; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.util.SameDiffUtils; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -89,11 +90,11 @@ public class EuclideanDistance extends BaseReduce3Op { SDVariable divBroadcastable = i_v1.get(0).div(euc); if(!keepDims && !(dimensions == null || dimensions.length == 0 || (dimensions.length == 1 && dimensions[0] == Integer.MAX_VALUE))){ //Not keep dims, and not full array reduction -> need to make broadcastable - divBroadcastable = f().reductionBroadcastableWithOrigShape(arg(), sameDiff.constant(Nd4j.createFromArray(dimensions)), divBroadcastable); + divBroadcastable = SameDiffUtils.reductionBroadcastableWithOrigShape(arg(), sameDiff.constant(Nd4j.createFromArray(dimensions)), divBroadcastable); } SDVariable gradX = difference.mul(divBroadcastable); - SDVariable gradY = f().neg(gradX); + SDVariable gradY = sameDiff.math.neg(gradX); return Arrays.asList(gradX, gradY); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/JaccardDistance.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/JaccardDistance.java index 994003e78..c520a7c18 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/JaccardDistance.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/JaccardDistance.java @@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.reduce3; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.util.SameDiffUtils; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; @@ -90,18 +91,18 @@ public class JaccardDistance extends BaseReduce3Op { //Jaccard distance: https://en.wikipedia.org/wiki/Jaccard_index#Generalized_Jaccard_similarity_and_distance //J(x,y) = 1 - sum_i min(x_i, y_i) / sum_i max(x_i, y_i) - SDVariable min = f().min(larg(), rarg()); - SDVariable max = f().max(larg(), rarg()); + SDVariable min = sameDiff.math.min(larg(), rarg()); + SDVariable max = sameDiff.math.max(larg(), rarg()); SDVariable sumMax = max.sum(true, dimensions); SDVariable sumMin = min.sum(true, dimensions); DataType d = arg().dataType(); - SDVariable xIsMin = f().eq(min, larg()).castTo(d); - SDVariable xIsMax = f().eq(max, larg()).castTo(d); - SDVariable yIsMin = f().eq(min, rarg()).castTo(d); - SDVariable yIsMax = f().eq(max, rarg()).castTo(d); + SDVariable xIsMin = sameDiff.eq(min, larg()).castTo(d); + SDVariable xIsMax = sameDiff.eq(max, larg()).castTo(d); + SDVariable yIsMin = sameDiff.eq(min, rarg()).castTo(d); + SDVariable yIsMax = 
sameDiff.eq(max, rarg()).castTo(d); - SDVariable sqSumMax = f().square(sumMax); + SDVariable sqSumMax = sameDiff.math.square(sumMax); SDVariable dldx = xIsMax.mul(sumMin).sub(xIsMin.mul(sumMax)).div(sqSumMax); SDVariable dldy = yIsMax.mul(sumMin).sub(yIsMin.mul(sumMax)).div(sqSumMax); @@ -110,7 +111,7 @@ public class JaccardDistance extends BaseReduce3Op { //KeepDims or full array reduction - already broadcastable bcGradOut = f1.get(0); } else { - bcGradOut = sameDiff.f().reductionBroadcastableWithOrigShape(arg(), sameDiff.constant(Nd4j.createFromArray(dimensions)), f1.get(0)); + bcGradOut = SameDiffUtils.reductionBroadcastableWithOrigShape(arg(), sameDiff.constant(Nd4j.createFromArray(dimensions)), f1.get(0)); } return Arrays.asList(dldx.mul(bcGradOut), dldy.mul(bcGradOut)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/ManhattanDistance.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/ManhattanDistance.java index 0c007a261..9fdea3afb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/ManhattanDistance.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce3/ManhattanDistance.java @@ -18,6 +18,7 @@ package org.nd4j.linalg.api.ops.impl.reduce3; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.util.SameDiffUtils; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -86,11 +87,11 @@ public class ManhattanDistance extends BaseReduce3Op { //keepDims or full array reduction gradBroadcastable = i_v1.get(0); } else { - gradBroadcastable = sameDiff.f().reductionBroadcastableWithOrigShape(arg(), sameDiff.constant(Nd4j.createFromArray(dimensions)), i_v1.get(0)); + gradBroadcastable = SameDiffUtils.reductionBroadcastableWithOrigShape(arg(), sameDiff.constant(Nd4j.createFromArray(dimensions)), i_v1.get(0)); } SDVariable gradX = sameDiff.math().sign(difference).mul(gradBroadcastable); - SDVariable gradY = f().neg(gradX); + SDVariable gradY = sameDiff.math().neg(gradX); return Arrays.asList(gradX, gradY); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/LeakyReLU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/LeakyReLU.java index 000b0414c..44514ee1a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/LeakyReLU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/LeakyReLU.java @@ -23,6 +23,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp; import org.nd4j.linalg.factory.Nd4j; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -108,7 +109,7 @@ public class LeakyReLU extends BaseScalarOp { @Override public List doDiff(List i_v) { - return Collections.singletonList(f().leakyReluBp(arg(), i_v.get(0), alpha)); + return new LeakyReLUBp(sameDiff, arg(), i_v.get(0), alpha).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/PRelu.java 
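CosineSimilarity, JaccardDistance and ManhattanDistance above all repeat the same guard before using the incoming gradient. The point: a keepDims or whole-array reduction leaves dL/dOut directly broadcastable against the pre-reduction input, but a dimension-dropping reduction does not, so the squeezed axes must be re-expanded first. A sketch of the shared logic, with input, gradOut, keepDims and dimensions named as in the diffs:

    SDVariable broadcastableGrad;
    boolean fullArrayReduction = dimensions == null || dimensions.length == 0
            || (dimensions.length == 1 && dimensions[0] == Integer.MAX_VALUE);
    if (keepDims || fullArrayReduction) {
        broadcastableGrad = gradOut; // rank preserved (or scalar): broadcasts as-is
    } else {
        // Re-insert the reduced axes as size-1 dimensions so the gradient
        // broadcasts against the original input shape.
        broadcastableGrad = SameDiffUtils.reductionBroadcastableWithOrigShape(
                input, sameDiff.constant(Nd4j.createFromArray(dimensions)), gradOut);
    }
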
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/PRelu.java index f9e30be9c..4b9b37026 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/PRelu.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/PRelu.java @@ -29,6 +29,7 @@ import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.PReluBp; /** * Parameterized ReLU op @@ -80,6 +81,6 @@ public class PRelu extends DynamicCustomOp { @Override public List doDiff(List i_v) { - return Arrays.asList(f().preluBp(arg(0), arg(1), i_v.get(0), sharedAxes)); + return new PReluBp(sameDiff, arg(0), arg(1), i_v.get(0), sharedAxes).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Pow.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Pow.java index 5cfab3768..ec15ea537 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Pow.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Pow.java @@ -23,6 +23,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -87,7 +88,7 @@ public class Pow extends BaseScalarOp { @Override public List doDiff(List i_v1) { - SDVariable g = f().powDerivative(arg(), this.pow).mul(i_v1.get(0)); - return Arrays.asList(g); + SDVariable g = new PowDerivative(sameDiff, arg(), false, this.pow).outputVariable().mul(i_v1.get(0)); + return Collections.singletonList(g); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinear.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinear.java index 944d4d095..98df920bc 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinear.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinear.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.ThresholdReluBp; import java.util.Arrays; import java.util.Collections; @@ -81,6 +82,6 @@ public class RectifiedLinear extends BaseScalarOp { @Override public List doDiff(List i_v) { - return Collections.singletonList(f().thresholdReluBp(arg(), i_v.get(0), scalarValue.getDouble(0))); + return new ThresholdReluBp(sameDiff, arg(), i_v.get(0), scalarValue.getDouble(0)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Relu6.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Relu6.java index c80d3c8f9..9b11925c3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Relu6.java +++ 
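Pow's rule, for reference: d(x^p)/dx = p * x^(p-1), which is what PowDerivative computes before the multiply by dL/dOut. The same quantity spelled out with the generic math namespace, as a sketch rather than the patch's code:

    SDVariable dOutDx = sameDiff.math.pow(arg(), this.pow - 1.0).mul(this.pow); // p * x^(p-1)
    SDVariable dLdx   = dOutDx.mul(i_v1.get(0));
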
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Relu6.java @@ -23,6 +23,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.Relu6Derivative; import org.nd4j.linalg.factory.Nd4j; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -99,6 +100,6 @@ public class Relu6 extends BaseScalarOp { @Override public List doDiff(List i_v) { SDVariable dLdOut = i_v.get(0); - return Collections.singletonList(f().relu6Derivative(arg(), dLdOut, scalarValue.getDouble(0))); + return new Relu6Derivative(sameDiff, arg(), dLdOut, scalarValue.getDouble(0)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarAdd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarAdd.java index 3aa31771a..f8831c68e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarAdd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarAdd.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.scalar; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -49,9 +50,12 @@ public class ScalarAdd extends BaseScalarOp { this(arr, 0); } + public ScalarAdd(@NonNull SameDiff sameDiff, @NonNull SDVariable i_v, Number scalar) { + this(sameDiff, i_v, scalar, false); + } + public ScalarAdd(SameDiff sameDiff, SDVariable i_v, Number scalar, boolean inPlace) { super(sameDiff, i_v, scalar, inPlace); - } public ScalarAdd(SameDiff sameDiff, SDVariable i_v, Number scalar, boolean inPlace, Object[] extraArgs) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarReverseDivision.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarReverseDivision.java index 1fec8f808..463012875 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarReverseDivision.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarReverseDivision.java @@ -22,6 +22,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -73,8 +74,8 @@ public class ScalarReverseDivision extends BaseScalarOp { @Override public List doDiff(List i_v1) { - SDVariable g = f().rdiv(f().pow(arg(), 2), -scalarValue.getDouble(0)).mul(i_v1.get(0)); - return Arrays.asList(g); + SDVariable g = sameDiff.math.rdiv(sameDiff.math.pow(arg(), 2), -scalarValue.getDouble(0)).mul(i_v1.get(0)); + return Collections.singletonList(g); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarReverseSubtraction.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarReverseSubtraction.java index d362620e4..972f4ec10 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarReverseSubtraction.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarReverseSubtraction.java @@ -23,6 +23,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseScalarOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -79,8 +80,8 @@ public class ScalarReverseSubtraction extends BaseScalarOp { @Override public List doDiff(List i_v1) { - SDVariable g = f().neg(i_v1.get(0)); - return Arrays.asList(g); + SDVariable g = sameDiff.math.neg(i_v1.get(0)); + return Collections.singletonList(g); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarSet.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarSet.java index a8ec8c7f3..d3a8c7f67 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarSet.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/ScalarSet.java @@ -76,7 +76,7 @@ public class ScalarSet extends BaseScalarOp { @Override public List doDiff(List i_v1) { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Step.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Step.java index 65f653d64..04bd39622 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Step.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/Step.java @@ -96,6 +96,6 @@ public class Step extends BaseScalarOp { @Override public List doDiff(List f1) { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterAdd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterAdd.java index 160556867..1846ab8f8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterAdd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterAdd.java @@ -88,9 +88,9 @@ public class ScatterAdd extends DynamicCustomOp { List ret = new ArrayList<>(3); ret.add(gradOut.get(0)); //Reference array - ret.add(f().zerosLike(arg(1))); //Indices + ret.add(sameDiff.zerosLike(arg(1))); //Indices - SDVariable gather = f().gather(gradOut.get(0), arg(1), 0); //Updates + SDVariable gather = sameDiff.gather(gradOut.get(0), arg(1), 0); //Updates ret.add(gather); return ret; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterDiv.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterDiv.java index 5d6b60c88..75badc9c2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterDiv.java +++ 
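The scatter ops above all route gradients the same way, and ScatterAdd is the cleanest statement of it: the reference array receives the output gradient (adjusted per op; addition passes it straight through), integer indices get a zeros placeholder, and the updates' gradient is the output gradient gathered back at the scattered rows:

    List<SDVariable> ret = new ArrayList<>(3);
    ret.add(gradOut.get(0));                             // dL/dRef: add is identity in grad
    ret.add(sameDiff.zerosLike(arg(1)));                 // indices are not differentiable
    ret.add(sameDiff.gather(gradOut.get(0), arg(1), 0)); // dL/dUpdates: rows selected by indices
    return ret;
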
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterDiv.java @@ -77,13 +77,13 @@ public class ScatterDiv extends DynamicCustomOp { SDVariable updates = arg(2); List ret = new ArrayList<>(3); - SDVariable gradRef = f().scatterDiv(gradOut.get(0), indices, updates); + SDVariable gradRef = sameDiff.scatterDiv(gradOut.get(0), indices, updates); ret.add(gradRef); //Reference array - ret.add(f().zerosLike(arg(1))); //Indices + ret.add(sameDiff.zerosLike(arg(1))); //Indices - SDVariable gatherOutGrad = f().gather(gradOut.get(0), indices, 0); //Updates - SDVariable gatherRef = f().gather(ref, indices, 0); - SDVariable updateGrad = gatherOutGrad.mul(gatherRef).div(f().square(updates)).neg(); + SDVariable gatherOutGrad = sameDiff.gather(gradOut.get(0), indices, 0); //Updates + SDVariable gatherRef = sameDiff.gather(ref, indices, 0); + SDVariable updateGrad = gatherOutGrad.mul(gatherRef).div(sameDiff.math.square(updates)).neg(); ret.add(updateGrad); return ret; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMax.java index 7f814d928..2dead9742 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMax.java @@ -87,12 +87,12 @@ public class ScatterMax extends DynamicCustomOp { SDVariable notModified = arg(0).eq(outputVariable()).castTo(arg(0).dataType()); //0 if modified, 1 otherwise SDVariable refGrad = gradOut.get(0).mul(notModified); - SDVariable gatherOut = f().gather(outputVariable(), arg(1), 0); - SDVariable gatherGrad = f().gather(gradOut.get(0), arg(1), 0); + SDVariable gatherOut = sameDiff.gather(outputVariable(), arg(1), 0); + SDVariable gatherGrad = sameDiff.gather(gradOut.get(0), arg(1), 0); SDVariable outIsUpdate = gatherOut.eq(arg(2)).castTo(arg(2).dataType()); SDVariable updateGrad = gatherGrad.mul(outIsUpdate); - return Arrays.asList(refGrad, f().zerosLike(arg(1)), updateGrad); + return Arrays.asList(refGrad, sameDiff.zerosLike(arg(1)), updateGrad); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMin.java index 2539a3d56..4af8a2cd2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMin.java @@ -88,12 +88,12 @@ public class ScatterMin extends DynamicCustomOp { SDVariable notModified = arg(0).eq(outputVariable()).castTo(arg(0).dataType()); //0 if modified, 1 otherwise SDVariable refGrad = gradOut.get(0).mul(notModified); - SDVariable gatherOut = f().gather(outputVariable(), arg(1), 0); - SDVariable gatherGrad = f().gather(gradOut.get(0), arg(1), 0); + SDVariable gatherOut = sameDiff.gather(outputVariable(), arg(1), 0); + SDVariable gatherGrad = sameDiff.gather(gradOut.get(0), arg(1), 0); SDVariable outIsUpdate = gatherOut.eq(arg(2)).castTo(arg(2).dataType()); SDVariable updateGrad = gatherGrad.mul(outIsUpdate); - return Arrays.asList(refGrad, f().zerosLike(arg(1)), updateGrad); + return Arrays.asList(refGrad, 
sameDiff.zerosLike(arg(1)), updateGrad); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMul.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMul.java index 411c59188..48e1a00bd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMul.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterMul.java @@ -91,12 +91,12 @@ public class ScatterMul extends DynamicCustomOp { SDVariable updates = arg(2); List ret = new ArrayList<>(3); - SDVariable gradRef = f().scatterMul(gradOut.get(0), indices, updates); + SDVariable gradRef = sameDiff.scatterMul(gradOut.get(0), indices, updates); ret.add(gradRef); //Reference array - ret.add(f().zerosLike(arg(1))); //Indices + ret.add(sameDiff.zerosLike(arg(1))); //Indices - SDVariable gatherOutGrad = f().gather(gradOut.get(0), indices, 0); //Updates - SDVariable gatherRef = f().gather(ref, indices, 0); + SDVariable gatherOutGrad = sameDiff.gather(gradOut.get(0), indices, 0); //Updates + SDVariable gatherRef = sameDiff.gather(ref, indices, 0); SDVariable updateGrad = gatherOutGrad.mul(gatherRef); ret.add(updateGrad); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterSub.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterSub.java index 83c4cc222..f66f7d689 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterSub.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterSub.java @@ -74,9 +74,9 @@ public class ScatterSub extends DynamicCustomOp { List ret = new ArrayList<>(3); ret.add(gradOut.get(0)); //Reference array - ret.add(f().zerosLike(arg(1))); //Indices + ret.add(sameDiff.zerosLike(arg(1))); //Indices - SDVariable gather = f().gather(gradOut.get(0), arg(1), 0); //Updates + SDVariable gather = sameDiff.gather(gradOut.get(0), arg(1), 0); //Updates ret.add(gather.neg()); return ret; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterUpdate.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterUpdate.java index 93e1e5995..c5644faa5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterUpdate.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scatter/ScatterUpdate.java @@ -98,12 +98,12 @@ public class ScatterUpdate extends DynamicCustomOp { SDVariable updates = arg(2); List ret = new ArrayList<>(3); - SDVariable zerosUpdate = f().zerosLike(updates); - SDVariable gradRef = f().scatterMul(gradOut.get(0), indices, zerosUpdate); //TODO optimize + SDVariable zerosUpdate = sameDiff.zerosLike(updates); + SDVariable gradRef = sameDiff.scatterMul(gradOut.get(0), indices, zerosUpdate); //TODO optimize ret.add(gradRef); //Reference array gradient - ret.add(f().zerosLike(arg(1))); //Indices + ret.add(sameDiff.zerosLike(arg(1))); //Indices - SDVariable gather = f().gather(gradOut.get(0), indices, 0); //Updates + SDVariable gather = sameDiff.gather(gradOut.get(0), indices, 0); //Updates ret.add(gather); return ret; diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java index a9f964844..ede375163 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Linspace.java @@ -117,6 +117,6 @@ public class Linspace extends DynamicCustomOp { @Override public List doDiff(List gradients){ - return Arrays.asList(f().zerosLike(arg(0)), f().zerosLike(arg(1)), f().zerosLike(arg(2))); + return Arrays.asList(sameDiff.zerosLike(arg(0)), sameDiff.zerosLike(arg(1)), sameDiff.zerosLike(arg(2))); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Permute.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Permute.java index cfd0bd7ed..870a72a2b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Permute.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Permute.java @@ -77,10 +77,10 @@ public class Permute extends Transpose { SDVariable ret; if(args().length == 1) { //Static dimensions - ret = f().permute(i_v.get(0), reverseDims); + ret = sameDiff.permute(i_v.get(0), reverseDims); } else { //Dynamic dimensions - ret = f().permute(i_v.get(0), sameDiff.invertPermutation(arg(1))); + ret = sameDiff.permute(i_v.get(0), sameDiff.invertPermutation(arg(1))); } return Collections.singletonList(ret); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java index 2126dfe27..6c5c0f9d6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Reshape.java @@ -152,8 +152,8 @@ public class Reshape extends DynamicCustomOp { @Override public List doDiff(List i_v) { - SDVariable origShape = f().shape(arg()); - SDVariable ret = f().reshape(i_v.get(0), origShape); + SDVariable origShape = sameDiff.shape(arg()); + SDVariable ret = sameDiff.reshape(i_v.get(0), origShape); return Collections.singletonList(ret); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java index 3c3baf1f6..94a6e6c2b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/SequenceMask.java @@ -120,7 +120,7 @@ public class SequenceMask extends DynamicCustomOp { @Override public List doDiff(List grad){ //Input is integer indices - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ShapeN.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ShapeN.java index 593830327..d9c4c4578 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ShapeN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/ShapeN.java @@ -65,7 +65,7 @@ public class ShapeN extends DynamicCustomOp { public List doDiff(List i_v) { List out = new ArrayList<>(); for(SDVariable in : args()){ - out.add(f().zerosLike(in)); + out.add(sameDiff.zerosLike(in)); } return out; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Slice.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Slice.java index 46b8f6286..b9bcff540 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Slice.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Slice.java @@ -25,6 +25,7 @@ import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.shape.bp.SliceBp; import java.util.*; @@ -82,10 +83,10 @@ public class Slice extends DynamicCustomOp { @Override public List doDiff(List grad) { if(args().length == 1) { - return Collections.singletonList(f().sliceBp(arg(), grad.get(0), begin, size)); + return new SliceBp(sameDiff, arg(), grad.get(0), begin, size).outputs(); } else { //Dynamic begin/size - return Collections.singletonList(f().sliceBp(arg(0), grad.get(0), arg(1), arg(2))); + return new SliceBp(sameDiff, arg(0), grad.get(0), arg(1), arg(2)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Stack.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Stack.java index 17a8beb3c..28e92930c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Stack.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Stack.java @@ -129,7 +129,7 @@ public class Stack extends DynamicCustomOp { @Override public List doDiff(List f1) { - return Arrays.asList(f().unstack(f1.get(0), jaxis, args().length)); + return Arrays.asList(sameDiff.unstack(f1.get(0), jaxis, args().length)); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/StridedSlice.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/StridedSlice.java index 456edfe1c..33c79e217 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/StridedSlice.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/StridedSlice.java @@ -27,6 +27,7 @@ import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.shape.bp.StridedSliceBp; import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.util.ArrayUtil; import org.tensorflow.framework.AttrValue; @@ -259,12 +260,12 @@ public class 
StridedSlice extends DynamicCustomOp { public List doDiff(List i_v) { if(args().length == 1) { //Array inputs for begin/end/strides - return Collections.singletonList(f().stridedSliceBp(arg(), i_v.get(0), begin, end, strides, beginMask, endMask, - ellipsisMask, newAxisMask, shrinkAxisMask)); + return new StridedSliceBp(sameDiff, arg(), i_v.get(0), begin, end, strides, beginMask, endMask, + ellipsisMask, newAxisMask, shrinkAxisMask).outputs(); } else { //SDVariable inputs for begin/end/strides - return Collections.singletonList(f().stridedSliceBp(arg(), i_v.get(0), arg(1), arg(2), arg(3), beginMask, endMask, - ellipsisMask, newAxisMask, shrinkAxisMask)); + return new StridedSliceBp(sameDiff, arg(), i_v.get(0), arg(1), arg(2), arg(3), beginMask, endMask, + ellipsisMask, newAxisMask, shrinkAxisMask).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Tile.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Tile.java index c2e476f60..e90e31427 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Tile.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Tile.java @@ -24,6 +24,7 @@ import org.nd4j.imports.descriptors.properties.PropertyMapping; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.shape.bp.TileBp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -126,9 +127,9 @@ public class Tile extends DynamicCustomOp { @Override public List doDiff(List i_v) { if(jaxis != null){ - return Collections.singletonList(f().tileBp(arg(), i_v.get(0), jaxis)); + return new TileBp(sameDiff, arg(), i_v.get(0), jaxis).outputs(); }else{ - return Collections.singletonList(f().tileBp(arg(0), arg(1), i_v.get(0))); + return new TileBp(sameDiff, arg(0), arg(1), i_v.get(0)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/StandardDeviation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/StandardDeviation.java index 211fec834..2b6359985 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/StandardDeviation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/StandardDeviation.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.reduce.bp.StandardDeviationBp; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.exception.ND4JIllegalStateException; @@ -101,7 +102,7 @@ public class StandardDeviation extends Variance { //If out = stdev(in) then: //dL/dIn = dL/dOut * dOut/dIn //dOut/dIn_i = (in_i-mean)/(stdev * (n-1)) - return Collections.singletonList(f().stdBp(arg(), grad.get(0), biasCorrected, keepDims, dimensions)); + return new StandardDeviationBp(sameDiff, arg(), grad.get(0), biasCorrected, keepDims, dimensions).outputs(); } @Override diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/Variance.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/Variance.java index adc92549b..64948880c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/Variance.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/summarystats/Variance.java @@ -25,6 +25,7 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceOp; import org.nd4j.linalg.api.ops.OpContext; +import org.nd4j.linalg.api.ops.impl.reduce.bp.VarianceBp; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.exception.ND4JIllegalStateException; @@ -115,7 +116,7 @@ public class Variance extends BaseReduceOp { //If out = var(in) then: //dL/dIn = dL/dOut * dOut/dIn // with dOut/dIn = (in-mean) * 2/(n-1) - return Collections.singletonList(f().varianceBp(arg(), grad.get(0), biasCorrected, keepDims, dimensions)); + return new VarianceBp(sameDiff, arg(), grad.get(0), biasCorrected, keepDims, dimensions).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Angle.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Angle.java index 66eeb9b99..7fed19e02 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Angle.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Angle.java @@ -44,6 +44,6 @@ public class Angle extends DynamicCustomOp { @Override public List doDiff(List i_v) { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java index b7bd0e0f6..ef8283be8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/Pad.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SDIndex; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; +import org.nd4j.enums.PadMode; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; @@ -46,6 +47,24 @@ public class Pad extends DynamicCustomOp { public Pad(){ } + private static Mode adaptMode(PadMode mode) { + Mode legacyMode = Mode.CONSTANT; + + if (mode == PadMode.CONSTANT) { + legacyMode = Mode.CONSTANT; + } + else if (mode == PadMode.REFLECT) { + legacyMode = Mode.REFLECT; + } + else if (mode == PadMode.SYMMETRIC) { + legacyMode = Mode.SYMMETRIC; + } + return legacyMode; + } + + public Pad(SameDiff sd, SDVariable in, SDVariable padding, PadMode mode, double padValue) { + this(sd, in, padding, adaptMode(mode), padValue); + } public Pad(SameDiff sd, SDVariable in, SDVariable padding, Mode mode, double padValue) { super(sd, new SDVariable[]{in, padding}, 
false); Preconditions.checkState(padding.dataType().isIntType(), "Padding array must be an integer datatype, got %s", padding.dataType()); @@ -62,6 +81,10 @@ public class Pad extends DynamicCustomOp { this(in, padding, null, Mode.CONSTANT, padValue); } + public Pad(@NonNull INDArray in, @NonNull INDArray padding, @NonNull PadMode mode, double padValue) { + this(in, padding, null, adaptMode(mode), padValue); + } + public Pad(@NonNull INDArray in, @NonNull INDArray padding, INDArray out, @NonNull Mode mode, double padValue){ super(null, new INDArray[]{in, padding}, out == null ? null : new INDArray[]{out}); Preconditions.checkState(padding.dataType().isIntType(), "Padding array must be an integer datatype, got %s", padding.dataType()); @@ -70,6 +93,10 @@ public class Pad extends DynamicCustomOp { addTArgument(padValue); } + public Pad(@NonNull INDArray in, @NonNull INDArray padding, INDArray out, @NonNull PadMode mode, double padValue) { + this(in, padding, out, adaptMode(mode), padValue); + } + @Override public String opName(){ return "pad"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/any/IsMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/any/IsMax.java index 218ec66db..54f91f6d1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/any/IsMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/any/IsMax.java @@ -73,7 +73,7 @@ public class IsMax extends DynamicCustomOp { @Override public List doDiff(List f1) { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/BooleanNot.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/BooleanNot.java index 4dd948b4f..053199731 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/BooleanNot.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/BooleanNot.java @@ -75,6 +75,6 @@ public class BooleanNot extends BaseTransformBoolOp { @Override public List doDiff(List f1) { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsFinite.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsFinite.java index 8df844943..8cf81febf 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsFinite.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsFinite.java @@ -73,7 +73,7 @@ public class IsFinite extends BaseTransformBoolOp { @Override public List doDiff(List i_v) { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsInf.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsInf.java index 44cb362a4..95d75e9be 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsInf.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsInf.java @@ -73,7 +73,7 @@ public class IsInf extends BaseTransformBoolOp { @Override public List doDiff(List i_v) { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsNaN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsNaN.java index daf9b0ea3..9f8e9ea74 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsNaN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/bool/IsNaN.java @@ -74,7 +74,7 @@ public class IsNaN extends BaseTransformBoolOp { @Override public List doDiff(List i_v) { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByNorm.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByNorm.java index fadd8720e..ef1ebb38a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByNorm.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByNorm.java @@ -73,7 +73,7 @@ public class ClipByNorm extends DynamicCustomOp { @Override public List doDiff(List grad) { - return Collections.singletonList(new ClipByNormBp(f().sameDiff(), arg(), grad.get(0), clipValue, dimensions).outputVariable()); + return new ClipByNormBp(sameDiff, arg(), grad.get(0), clipValue, dimensions).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByValue.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByValue.java index fa465b251..44cde0abb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByValue.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByValue.java @@ -83,8 +83,8 @@ public class ClipByValue extends DynamicCustomOp { @Override public List doDiff(List grad) { //dOut/dIn is 0 if clipped, 1 otherwise - SDVariable notClippedLower = f().gt(arg(), clipValueMin).castTo(arg().dataType()); - SDVariable notClippedUpper = f().lt(arg(), clipValueMax).castTo(arg().dataType()); + SDVariable notClippedLower = sameDiff.gt(arg(), clipValueMin).castTo(arg().dataType()); + SDVariable notClippedUpper = sameDiff.lt(arg(), clipValueMax).castTo(arg().dataType()); SDVariable ret = notClippedLower.mul(notClippedUpper).mul(grad.get(0)); return Collections.singletonList(ret); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ATan2.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ATan2.java index d6230e153..e847142a6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ATan2.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ATan2.java @@ -85,17 +85,7 @@ public class ATan2 extends BaseDynamicTransformOp { SDVariable y = larg(); SDVariable x = rarg(); -/* SDVariable r = y.div(x); - - SDVariable dOutdr = f().square(r).add(1.0).rdiv(1.0); - SDVariable drdy = x.rdiv(1.0); - SDVariable drdx = f().neg(y).div(f().square(x)); - - SDVariable xGrad = dOutdr.mul(drdx).mul(i_v.get(0)); - SDVariable yGrad = dOutdr.mul(drdy).mul(i_v.get(0)); -*/ - - val xGrad = f().neg(y.div(x.pow(2).add(y.pow(2)))).mul(i_v.get(0)); + val xGrad = sameDiff.math.neg(y.div(x.pow(2).add(y.pow(2)))).mul(i_v.get(0)); val yGrad = x.div(x.pow(2).add(y.pow(2))).mul(i_v.get(0)); return Arrays.asList(yGrad, xGrad); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Assign.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Assign.java index 35c209870..ca466ae34 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Assign.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Assign.java @@ -89,7 +89,7 @@ public class Assign extends DynamicCustomOp { @Override public List doDiff(List f1){ //TODO replace with assign backprop op from libnd4j (that handles the broadcast case properly) - return Arrays.asList(f().zerosLike(larg()), f1.get(0)); + return Arrays.asList(sameDiff.zerosLike(larg()), f1.get(0)); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumProd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumProd.java index 0be0b08ad..34e3e5f1d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumProd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumProd.java @@ -28,6 +28,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.CumProdBp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -142,7 +143,7 @@ public class CumProd extends DynamicCustomOp { @Override public List doDiff(List grad) { - return Collections.singletonList(f().cumprodBp(arg(0), grad.get(0), exclusive, reverse, jaxis)); + return new CumProdBp(sameDiff, arg(0), grad.get(0), exclusive, reverse, jaxis).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumSum.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumSum.java index c24693b01..97c53f4e2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumSum.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CumSum.java @@ -29,6 +29,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.CumSumBp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -142,7 +143,7 @@ public class CumSum extends DynamicCustomOp { @Override public List doDiff(List grad) { - return Collections.singletonList(f().cumsumBp(arg(0), grad.get(0), exclusive, reverse, jaxis)); + return new CumSumBp(sameDiff, arg(0), grad.get(0), exclusive, reverse, jaxis).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DotProductAttention.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DotProductAttention.java index d3a5c9676..300f8277a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DotProductAttention.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DotProductAttention.java @@ -70,7 +70,8 @@ public class DotProductAttention extends DynamicCustomOp { @Override public List doDiff(List gradient) { - return sameDiff.f().dotProductAttentionBp(arg(0), arg(1), arg(2), gradient.get(0), args().length > 3 ? arg(3) : null, scaled); + SDVariable mask = args().length == 4 ? arg(3) : null; + return Arrays.asList(new DotProductAttentionBp(sameDiff, arg(0), arg(1), arg(2), gradient.get(0), mask, scaled).outputVariables()); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicPartition.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicPartition.java index 3efc13af0..718120bf7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicPartition.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicPartition.java @@ -24,6 +24,7 @@ import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.DynamicPartitionBp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -74,7 +75,7 @@ public class DynamicPartition extends DynamicCustomOp { @Override public List doDiff(List i_v) { - return Arrays.asList(f().dynamicPartitionBp(arg(0), arg(1), i_v.toArray(new SDVariable[i_v.size()]), numPartitions)); + return new DynamicPartitionBp(sameDiff, arg(0), arg(1), i_v.toArray(new SDVariable[i_v.size()]), numPartitions).outputs(); } protected void addArgs() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicStitch.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicStitch.java index 94c34d108..60b2bf942 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicStitch.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/DynamicStitch.java @@ -83,7 +83,7 @@ public class DynamicStitch extends DynamicCustomOp { SDVariable[] partition = sameDiff.dynamicPartition(gradient, partitions, numPartitions); List ret = new ArrayList<>(); for (SDVariable i : indices) - ret.add(f().zerosLike(i)); + ret.add(sameDiff.zerosLike(i)); Collections.addAll(ret, partition); return ret; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/InvertPermutation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/InvertPermutation.java index 6048c9dff..9612c4dea 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/InvertPermutation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/InvertPermutation.java @@ -26,6 +26,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -67,8 +68,8 @@ public class InvertPermutation extends BaseDynamicTransformOp { @Override public List doDiff(List grad) { SDVariable gradient = grad.get(0); - SDVariable invertedGradient = f().invertPermutation(gradient, false); - return Arrays.asList(invertedGradient); + SDVariable invertedGradient = sameDiff.invertPermutation(gradient); + return Collections.singletonList(invertedGradient); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNorm.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNorm.java index 0c4990bb2..f16a92318 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNorm.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNorm.java @@ -100,13 +100,11 @@ public class LayerNorm extends DynamicCustomOp { @Override public List doDiff(List gradient) { - SDVariable[] ret; - if(noBias){ - ret = f().layerNormBp(arg(0), arg(1), gradient.get(0), channelsFirst, dimensions); - }else{ - ret = f().layerNormBp(arg(0), arg(1), arg(2), gradient.get(0), channelsFirst, dimensions); + if (noBias) { + return new LayerNormBp(sameDiff, arg(0), arg(1), gradient.get(0), channelsFirst, dimensions).outputs(); + } else { + return new LayerNormBp(sameDiff, arg(0), arg(1), arg(2), gradient.get(0), channelsFirst, dimensions).outputs(); } - return Arrays.asList(ret); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LogSoftMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LogSoftMax.java index 86c9d9c0a..2de57451f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LogSoftMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LogSoftMax.java @@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions; 
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.LogSoftMaxDerivative; import org.nd4j.linalg.factory.Nd4j; import java.util.Collections; @@ -76,11 +77,9 @@ public class LogSoftMax extends DynamicCustomOp { @Override public List doDiff(List i_v) { if(dimension == null) { - SDVariable ret = f().logSoftmaxDerivative(arg(), i_v.get(0)); - return Collections.singletonList(ret); + return new LogSoftMaxDerivative(sameDiff, arg(), i_v.get(0)).outputs(); } else { - SDVariable ret = f().logSoftmaxDerivative(arg(), i_v.get(0), dimension); - return Collections.singletonList(ret); + return new LogSoftMaxDerivative(sameDiff, arg(), i_v.get(0), dimension).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixSetDiag.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixSetDiag.java index 19d139cbb..9f4b97576 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixSetDiag.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MatrixSetDiag.java @@ -57,8 +57,8 @@ public class MatrixSetDiag extends DynamicCustomOp { @Override public List doDiff(List i_v) { SDVariable grad = i_v.get(0); - SDVariable in1Grad = f().setDiag(grad, sameDiff.zerosLike(arg(1))); - SDVariable in2Grad = f().diagPart(grad); + SDVariable in1Grad = sameDiff.math.setDiag(grad, sameDiff.zerosLike(arg(1))); + SDVariable in2Grad = sameDiff.math.diagPart(grad); return Arrays.asList(in1Grad, in2Grad); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MultiHeadDotProductAttention.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MultiHeadDotProductAttention.java index 54167bd8b..98765ed96 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MultiHeadDotProductAttention.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MultiHeadDotProductAttention.java @@ -79,7 +79,7 @@ public class MultiHeadDotProductAttention extends DynamicCustomOp { @Override public List doDiff(List gradient) { - return sameDiff.f().multiHeadDotProductAttentionBp(arg(0), arg(1), arg(2), arg(3), arg(4), arg(5), arg(6), gradient.get(0), args().length > 7 ? arg(7) : null, scaled); + return Arrays.asList(new MultiHeadDotProductAttentionBp(sameDiff, arg(0), arg(1), arg(2), arg(3), arg(4), arg(5), arg(6), gradient.get(0), args().length > 7 ? 
arg(7) : null, scaled).outputVariables()); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Pow.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Pow.java index e155a4f2a..0f8286769 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Pow.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Pow.java @@ -23,6 +23,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.PowBp; import java.util.Arrays; import java.util.Collections; @@ -68,8 +69,7 @@ public class Pow extends DynamicCustomOp { SDVariable dldb = outputVariable().mul(sameDiff.math().log(a)).mul(f1.get(0)); return Arrays.asList(dlda, dldb);*/ - SDVariable[] g = f().powBp(arg(0), arg(1), f1.get(0)); - return Arrays.asList(g); + return new PowBp(sameDiff, arg(0), arg(1), f1.get(0)).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Reverse.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Reverse.java index d1648abab..372f96c18 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Reverse.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Reverse.java @@ -100,8 +100,8 @@ public class Reverse extends DynamicCustomOp { @Override public List doDiff(List f1) { - SDVariable ret = f().reverse(f1.get(0), dimensions); - return Arrays.asList(ret); + SDVariable ret = sameDiff.reverse(f1.get(0), dimensions); + return Collections.singletonList(ret); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ReverseSequence.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ReverseSequence.java index 11897fef8..50332daf6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ReverseSequence.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ReverseSequence.java @@ -115,8 +115,8 @@ public class ReverseSequence extends DynamicCustomOp { @Override public List doDiff(List f1) { - SDVariable ret = f().reverseSequence(f1.get(0), arg(1), seqDim, batchDim); - return Arrays.asList(ret, f().zerosLike(arg(1))); + SDVariable ret = sameDiff.reverseSequence(f1.get(0), arg(1), seqDim, batchDim); + return Arrays.asList(ret, sameDiff.zerosLike(arg(1))); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SoftMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SoftMax.java index 24c2353c1..737e76f3b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SoftMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SoftMax.java @@ -23,6 
+23,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp; import java.util.Collections; import java.util.List; @@ -106,8 +107,7 @@ public class SoftMax extends BaseDynamicTransformOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().softmaxDerivative(arg(), i_v.get(0), this.dimension); - return Collections.singletonList(ret); + return new SoftmaxBp(sameDiff, arg(), i_v.get(0), this.dimension).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Standardize.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Standardize.java index 467b36a4e..8acef4029 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Standardize.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Standardize.java @@ -63,8 +63,7 @@ public class Standardize extends DynamicCustomOp { @Override public List doDiff(List grad) { - SDVariable ret = f().standardizeBp(arg(0), grad.get(0), dimensions); - return Arrays.asList(ret); + return new StandardizeBp(sameDiff, arg(0), grad.get(0), dimensions).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ThresholdRelu.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ThresholdRelu.java index 82e2ae6e3..1688c03c4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ThresholdRelu.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ThresholdRelu.java @@ -27,6 +27,7 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.ThresholdReluBp; /** * Threshold ReLU op. The general case of {@link RectifiedLinear}. 
@@ -72,6 +73,6 @@ public class ThresholdRelu extends DynamicCustomOp { @Override public List doDiff(List f1) { - return Collections.singletonList(f().thresholdReluBp(arg(), f1.get(0), cutoff)); + return new ThresholdReluBp(sameDiff, arg(), f1.get(0), cutoff).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Trace.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Trace.java index 24d79f234..9faca3403 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Trace.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Trace.java @@ -51,12 +51,12 @@ public class Trace extends DynamicCustomOp { @Override public List doDiff(List gradAtOutput){ - SDVariable rows = f().reshape(f().sizeAt(arg(), -2), new long[]{1}); - SDVariable cols = f().reshape(f().sizeAt(arg(), -1), new long[]{1}); - SDVariable eye = sameDiff.math().eye(/*f().shape(gradAtOutput.get(0)),*/ rows, cols); + SDVariable rows = sameDiff.reshape(sameDiff.sizeAt(arg(), -2), 1); + SDVariable cols = sameDiff.reshape(sameDiff.sizeAt(arg(), -1), 1); + SDVariable eye = sameDiff.math().eye(/*sameDiff.shape(gradAtOutput.get(0)),*/ rows, cols); //Reshape gradient from [x,y,z] to [x,y,z,1,1] - SDVariable reshapedGrad = f().expandDims(gradAtOutput.get(0), -1); - reshapedGrad = f().expandDims(reshapedGrad, -1); + SDVariable reshapedGrad = sameDiff.expandDims(gradAtOutput.get(0), -1); + reshapedGrad = sameDiff.expandDims(reshapedGrad, -1); return Collections.singletonList(reshapedGrad.mul(eye)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMax.java index 5b6cd2517..c98cd7d5b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMax.java @@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentMaxBp; import java.util.Arrays; import java.util.Collections; @@ -56,7 +57,7 @@ public class SegmentMax extends DynamicCustomOp { @Override public List doDiff(List gradients){ - return Arrays.asList(f().segmentMaxBp(arg(0), arg(1), gradients.get(0))); + return new SegmentMaxBp(sameDiff, arg(0), arg(1), gradients.get(0)).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMean.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMean.java index d0a9a6784..eca108b2a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMean.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMean.java @@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions; import 
org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentMeanBp; import java.util.Arrays; import java.util.Collections; @@ -56,7 +57,7 @@ public class SegmentMean extends DynamicCustomOp { @Override public List doDiff(List gradients){ - return Arrays.asList(f().segmentMeanBp(arg(0), arg(1), gradients.get(0))); + return new SegmentMeanBp(sameDiff, arg(0), arg(1), gradients.get(0)).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMin.java index 2bc369f2a..f070dc8d1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentMin.java @@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentMinBp; import java.util.Arrays; import java.util.Collections; @@ -56,7 +57,7 @@ public class SegmentMin extends DynamicCustomOp { @Override public List doDiff(List gradients){ - return Arrays.asList(f().segmentMinBp(arg(0), arg(1), gradients.get(0))); + return new SegmentMinBp(sameDiff, arg(0), arg(1), gradients.get(0)).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentProd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentProd.java index 3be3625e7..71d0dd2c3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentProd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentProd.java @@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentProdBp; import java.util.Arrays; import java.util.Collections; @@ -56,7 +57,7 @@ public class SegmentProd extends DynamicCustomOp { @Override public List doDiff(List gradients){ - return Arrays.asList(f().segmentProdBp(arg(0), arg(1), gradients.get(0))); + return new SegmentProdBp(sameDiff, arg(0), arg(1), gradients.get(0)).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentSum.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentSum.java index 5de847162..a74aded65 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentSum.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/segment/SegmentSum.java @@ -22,6 +22,7 @@ import org.nd4j.base.Preconditions; import 
org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.SegmentSumBp; import java.util.Arrays; import java.util.Collections; @@ -56,7 +57,7 @@ public class SegmentSum extends DynamicCustomOp { @Override public List doDiff(List gradients){ - return Arrays.asList(f().segmentSumBp(arg(0), arg(1), gradients.get(0))); + return new SegmentSumBp(sameDiff, arg(0), arg(1), gradients.get(0)).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/dtype/Cast.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/dtype/Cast.java index d588ef4a8..b168adb43 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/dtype/Cast.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/dtype/Cast.java @@ -129,7 +129,7 @@ public class Cast extends BaseDynamicTransformOp { if(arg().dataType().isFPType()){ return Collections.singletonList(i_v.get(0).castTo(arg().dataType())); } else { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/floating/RSqrt.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/floating/RSqrt.java index df5cdbcc7..9471c8bca 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/floating/RSqrt.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/floating/RSqrt.java @@ -75,7 +75,7 @@ public class RSqrt extends BaseTransformFloatOp { @Override public List doDiff(List i_v) { - SDVariable xPowNeg32 = f().pow(arg(), -1.5).mul(-0.5); + SDVariable xPowNeg32 = sameDiff.math.pow(arg(), -1.5).mul(-0.5); return Collections.singletonList(i_v.get(0).mul(xPowNeg32)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SELUDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SELUDerivative.java index b00b29b75..fdbbafa99 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SELUDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SELUDerivative.java @@ -24,6 +24,7 @@ import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -41,6 +42,10 @@ public class SELUDerivative extends BaseTransformStrictOp { private static final double SELU_ALPHA = 1.6732632423543772848170429916717; private static final double SELU_LAMBDA = 1.0507009873554804934193349852946; + public SELUDerivative(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } + public SELUDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); } @@ -79,9 +84,8 @@ public class SELUDerivative extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - SDVariable ret = 
f().div(arg(),f().seluDerivative(arg())); - - return Arrays.asList(ret); + SDVariable ret = sameDiff.math.div(arg(), new SELUDerivative(sameDiff, arg()).outputVariable()); + return Collections.singletonList(ret); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/TanhDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/TanhDerivative.java index 16afb4316..2a0d6021a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/TanhDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/TanhDerivative.java @@ -84,7 +84,7 @@ public class TanhDerivative extends DynamicCustomOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().div(sameDiff.onesLike(outputVariables()[0]), f().pow(f().cosh(arg()), 2)); + SDVariable ret = sameDiff.math.div(sameDiff.onesLike(outputVariables()[0]), sameDiff.math.pow(sameDiff.math.cosh(arg()), 2)); return Collections.singletonList(ret); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/AddOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/AddOp.java index 672159a3e..4069967c6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/AddOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/AddOp.java @@ -16,10 +16,12 @@ package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.AddBpOp; import java.util.List; @@ -34,14 +36,18 @@ public class AddOp extends BaseDynamicTransformOp { public AddOp() { } - public AddOp(SameDiff sameDiff, SDVariable[] args, boolean inPlace) { - super(sameDiff, args, inPlace); + public AddOp(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable y) { + super(sameDiff, new SDVariable[]{x, y}, false); } public AddOp(INDArray first, INDArray second, INDArray result){ this(new INDArray[]{first, second}, result == null ? 
null : new INDArray[]{result}); } + public AddOp(@NonNull INDArray x, @NonNull INDArray y) { + this(new INDArray[]{x,y}, null); + } + public AddOp(INDArray[] inputs, INDArray[] outputs) { super(inputs, outputs); } @@ -63,7 +69,7 @@ public class AddOp extends BaseDynamicTransformOp { @Override public List doDiff(List i_v) { - return f().addBp(larg(), rarg(), i_v.get(0)); + return new AddBpOp(sameDiff, larg(), rarg(), i_v.get(0)).outputs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/DivOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/DivOp.java index b76942e95..2ce0101cf 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/DivOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/DivOp.java @@ -16,11 +16,14 @@ package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.DivBpOp; +import java.util.Arrays; import java.util.List; /** @@ -33,14 +36,18 @@ public class DivOp extends BaseDynamicTransformOp { public DivOp() {} - public DivOp( SameDiff sameDiff, SDVariable[] args, boolean inPlace) { - super(sameDiff, args, inPlace); + public DivOp( @NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable y) { + super(sameDiff, new SDVariable[]{x, y}, false); } public DivOp(INDArray first, INDArray second, INDArray result){ this(new INDArray[]{first, second}, result == null ? 
null : new INDArray[]{result}); } + public DivOp( @NonNull INDArray x, INDArray y) { + this(new INDArray[]{x,y}, null); + } + public DivOp( INDArray[] inputs, INDArray[] outputs) { super(inputs, outputs); } @@ -65,7 +72,7 @@ @Override public List<SDVariable> doDiff(List<SDVariable> i_v) { - return f().divBp(larg(), rarg(), i_v.get(0)); + return Arrays.asList(new DivBpOp(sameDiff, larg(), rarg(), i_v.get(0)).outputVariables()); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/FModOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/FModOp.java index c2314cdc4..408d86a75 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/FModOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/FModOp.java @@ -22,6 +22,7 @@ import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformSameOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorModBpOp; import java.util.List; @@ -83,6 +84,6 @@ public class FModOp extends BaseTransformSameOp { @Override public List<SDVariable> doDiff(List<SDVariable> f1) { - return f().floorModBp(larg(), rarg(), f1.get(0)); + return new FloorModBpOp(sameDiff, larg(), rarg(), f1.get(0)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/FloorDivOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/FloorDivOp.java index debfc5a5d..7ed2c6c1c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/FloorDivOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/FloorDivOp.java @@ -16,10 +16,12 @@ package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorDivBpOp; import java.util.List; @@ -39,6 +41,10 @@ super(sameDiff, args, inPlace); } + public FloorDivOp(@NonNull INDArray x, @NonNull INDArray y) { + this(new INDArray[]{x, y}, null); + } + public FloorDivOp( INDArray[] inputs, INDArray[] outputs) { super(inputs, outputs); } @@ -63,6 +69,6 @@ @Override public List<SDVariable> doDiff(List<SDVariable> i_v) { - return f().floorDivBp(larg(), rarg(), i_v.get(0)); + return new FloorDivBpOp(sameDiff, larg(), rarg(), i_v.get(0)).outputs(); } }
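The FModOp, FloorDivOp and FloorModOp hunks around this point all reroute doDiff through directly constructed floor-based backprop ops. As a sanity reference, a standalone sketch (plain Java; the class name FloorModGradCheck is illustrative, not part of the patch) of the textbook identities these ops are presumably implementing, dr/dx = 1 and dr/dy = -floor(x/y) for r = x - y*floor(x/y), valid away from the discontinuities of floor:

// Standalone numeric check of the floor-mod gradients, by central
// finite differences at a point where floor(x/y) is locally constant.
public class FloorModGradCheck {
    static double floorMod(double x, double y) { return x - y * Math.floor(x / y); }

    public static void main(String[] args) {
        double x = 5.3, y = 2.0, eps = 1e-6;
        double dx = (floorMod(x + eps, y) - floorMod(x - eps, y)) / (2 * eps);
        double dy = (floorMod(x, y + eps) - floorMod(x, y - eps)) / (2 * eps);
        System.out.printf("dr/dx: numeric=%.6f expected=1.0%n", dx);
        System.out.printf("dr/dy: numeric=%.6f expected=%.1f%n", dy, -Math.floor(x / y));
    }
}

At the sample point the dr/dy check lands on -floor(2.65) = -2; where floor jumps, the derivative is undefined and gradient checks should avoid those points.

diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/FloorModOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/FloorModOp.java index e7286816f..29799a221 100644 ---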
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/FloorModOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/FloorModOp.java @@ -16,12 +16,14 @@ package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorModBpOp; import org.nd4j.linalg.api.shape.Shape; import java.util.Collections; @@ -39,6 +41,10 @@ super(sameDiff, new SDVariable[]{x, y}, false); } + public FloorModOp(@NonNull INDArray x, @NonNull INDArray y) { + this(new INDArray[]{x, y}, null); + } + public FloorModOp(INDArray[] inputs, INDArray[] outputs) { super(inputs, outputs); } @@ -60,7 +66,7 @@ @Override public List<SDVariable> doDiff(List<SDVariable> f1) { - return f().floorModBp(larg(), rarg(), f1.get(0)); + return new FloorModBpOp(sameDiff, larg(), rarg(), f1.get(0)).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MergeAddOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MergeAddOp.java index 51f2e449d..0d634766e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MergeAddOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MergeAddOp.java @@ -84,7 +84,7 @@ public class MergeAddOp extends BaseDynamicTransformOp { public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){ DataType first = dataTypes.get(0); for( int i=1; i<dataTypes.size(); i++ ){ [...] } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/ModOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/ModOp.java [...] @Override public List<SDVariable> doDiff(List<SDVariable> i_v) { - return f().modBp(larg(), rarg(), i_v.get(0)); + return new ModBpOp(sameDiff, larg(), rarg(), i_v.get(0)).outputs(); } - - } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MulOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MulOp.java index 4636f9bc8..307a46557 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MulOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MulOp.java @@ -16,10 +16,12 @@ package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.MulBpOp; import java.util.List; @@ -33,12 +35,16 @@ public class MulOp extends BaseDynamicTransformOp { public MulOp() {} - public MulOp( SameDiff sameDiff, SDVariable[] args, boolean inPlace) { - super(sameDiff, args, inPlace); + public MulOp( @NonNull SameDiff 
sameDiff, @NonNull SDVariable x, @NonNull SDVariable y) { + super(sameDiff, new SDVariable[]{x, y}, false); + } + + public MulOp(INDArray first, INDArray second){ + this(first, second, null); } public MulOp(INDArray first, INDArray second, INDArray result){ - this(new INDArray[]{first, second}, result == null ? null : new INDArray[]{result}); + this(new INDArray[]{first, second}, wrapOrNull(result)); } public MulOp( INDArray[] inputs, INDArray[] outputs) { @@ -66,7 +72,7 @@ public class MulOp extends BaseDynamicTransformOp { @Override public List doDiff(List i_v) { - return f().mulBp(larg(), rarg(), i_v.get(0)); + return new MulBpOp(sameDiff, larg(), rarg(), i_v.get(0)).outputs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RDivOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RDivOp.java index d54d91dbc..9891464fd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RDivOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RDivOp.java @@ -16,12 +16,15 @@ package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.RDivBpOp; +import java.util.Arrays; import java.util.List; /** @@ -34,14 +37,18 @@ public class RDivOp extends BaseDynamicTransformOp { public RDivOp() {} - public RDivOp( SameDiff sameDiff, SDVariable[] args, boolean inPlace) { - super(sameDiff, args, inPlace); + public RDivOp(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable y) { + super(sameDiff, new SDVariable[]{x, y}, false); } public RDivOp(INDArray first, INDArray second, INDArray result){ this(new INDArray[]{first, second}, result == null ? 
null : new INDArray[]{result}); } + public RDivOp(@NonNull INDArray x, @NonNull INDArray y){ + this(new INDArray[]{x, y}, null); + } + public RDivOp( INDArray[] inputs, INDArray[] outputs) { super(inputs, outputs); } @@ -64,6 +71,6 @@ @Override public List<SDVariable> doDiff(List<SDVariable> i_v) { - return f().rdivBp(larg(), rarg(), i_v.get(0)); + return Arrays.asList(new RDivBpOp(sameDiff, larg(), rarg(), i_v.get(0)).outputVariables()); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RSubOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RSubOp.java index 12c852949..5b233eb17 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RSubOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RSubOp.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.RSubBpOp; import java.util.List; @@ -45,8 +46,16 @@ this(sameDiff, new SDVariable[]{i_v1, i_v2}, inPlace); } + public RSubOp(INDArray first, INDArray second){ + this(first, second, null); + } + public RSubOp(INDArray first, INDArray second, INDArray result){ - this(new INDArray[]{first, second}, result == null ? null : new INDArray[]{result}); + this(new INDArray[]{first, second}, wrapOrNull(result)); + } + + public RSubOp( INDArray[] inputs, INDArray[] outputs) { + super(inputs, outputs); } public RSubOp() {} @@ -61,13 +70,9 @@ throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName()); } - public RSubOp( INDArray[] inputs, INDArray[] outputs) { - super(inputs, outputs); - } - @Override public List<SDVariable> doDiff(List<SDVariable> i_v) { - return f().rsubBp(larg(), rarg(), i_v.get(0)); + return new RSubBpOp(sameDiff, larg(), rarg(), i_v.get(0)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RealDivOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RealDivOp.java index 04e72b5db..2fe1f150f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RealDivOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RealDivOp.java @@ -20,6 +20,7 @@ import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.DivBpOp; import java.util.List; @@ -60,7 +61,7 @@ public class RealDivOp extends BaseDynamicTransformOp { @Override public List<SDVariable> doDiff(List<SDVariable> i_v) { - return f().divBp(larg(), rarg(), i_v.get(0)); + return new DivBpOp(sameDiff, larg(), rarg(), i_v.get(0)).outputs(); }
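The DivBpOp and RDivBpOp constructions above are expected to realize the ordinary quotient rule: for f(x, y) = x / y and upstream gradient g, dL/dx = g / y and dL/dy = -g * x / y^2. A standalone finite-difference spot-check of those closed forms (plain Java; DivGradCheck is an illustrative name, not an nd4j class):

// Finite-difference sanity check for the division gradients that the
// DivBpOp/RDivBpOp calls above are assumed to produce. The closed forms
// are standard calculus, not read from the nd4j implementation.
public class DivGradCheck {
    public static void main(String[] args) {
        double x = 1.7, y = -0.8, g = 1.0, eps = 1e-6;
        // Closed-form reverse-mode gradients of f(x, y) = x / y
        double gx = g / y;
        double gy = -g * x / (y * y);
        // Central finite differences
        double gxNum = ((x + eps) / y - (x - eps) / y) / (2 * eps);
        double gyNum = (x / (y + eps) - x / (y - eps)) / (2 * eps);
        System.out.printf("d/dx: analytic=%.8f numeric=%.8f%n", gx, gxNum);
        System.out.printf("d/dy: analytic=%.8f numeric=%.8f%n", gy, gyNum);
    }
}

diff --git 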
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/SquaredDifferenceOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/SquaredDifferenceOp.java index acbf840f1..e1ad183e7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/SquaredDifferenceOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/SquaredDifferenceOp.java @@ -36,14 +36,21 @@ public class SquaredDifferenceOp extends BaseDynamicTransformOp { public SquaredDifferenceOp() {} - public SquaredDifferenceOp(SameDiff sameDiff, SDVariable[] args, boolean inPlace) { - super(sameDiff, args, inPlace); + public SquaredDifferenceOp(SameDiff sameDiff, SDVariable x, SDVariable y, boolean inPlace) { + super(sameDiff, new SDVariable[]{x,y}, inPlace); } - public SquaredDifferenceOp(INDArray[] inputs, INDArray[] outputs) { - super(inputs, outputs); + public SquaredDifferenceOp(SameDiff sameDiff, SDVariable x, SDVariable y) { + this(sameDiff, x, y, false); } + public SquaredDifferenceOp(INDArray x, INDArray y, INDArray output) { + super(new INDArray[]{x,y}, new INDArray[]{output}); + } + + public SquaredDifferenceOp(INDArray x, INDArray y) { + addInputArgument(new INDArray[]{x,y}); + } @Override public String opName() { @@ -63,8 +70,7 @@ public class SquaredDifferenceOp extends BaseDynamicTransformOp { @Override public List doDiff(List i_v1) { - SDVariable[] outputs = new SquaredDifferenceBpOp(f().sameDiff(), new SDVariable[]{larg(), rarg(), i_v1.get(0)}).outputVariables(); - return Arrays.asList(outputs); + return new SquaredDifferenceBpOp(sameDiff, new SDVariable[]{larg(), rarg(), i_v1.get(0)}).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/SubOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/SubOp.java index 0d222329e..da6e77a42 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/SubOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/SubOp.java @@ -16,10 +16,12 @@ package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.SubBpOp; import java.util.List; @@ -33,12 +35,16 @@ public class SubOp extends BaseDynamicTransformOp { public SubOp() {} - public SubOp( SameDiff sameDiff, SDVariable[] args, boolean inPlace) { - super(sameDiff, args, inPlace); + public SubOp( @NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable y) { + super(sameDiff, new SDVariable[]{x, y}, false); + } + + public SubOp(INDArray first, INDArray second){ + this(first, second, null); } public SubOp(INDArray first, INDArray second, INDArray result){ - this(new INDArray[]{first, second}, result == null ? 
null : new INDArray[]{result}); + this(new INDArray[]{first, second}, wrapOrNull(result)); } public SubOp( INDArray[] inputs, INDArray[] outputs) { @@ -65,7 +71,7 @@ @Override public List<SDVariable> doDiff(List<SDVariable> i_v) { - return f().subBp(larg(), rarg(), i_v.get(0)); + return new SubBpOp(sameDiff, larg(), rarg(), i_v.get(0)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/TruncateDivOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/TruncateDivOp.java index 973ecd7ba..4c99b479b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/TruncateDivOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/TruncateDivOp.java @@ -64,8 +64,8 @@ @Override public List<SDVariable> doDiff(List<SDVariable> i_v) { - SDVariable gradWrtX = f().div(i_v.get(0),rarg()); - SDVariable gradWrtY = f().mul(f().neg(gradWrtX),f().div(larg(),rarg())); + SDVariable gradWrtX = sameDiff.math.div(i_v.get(0),rarg()); + SDVariable gradWrtY = sameDiff.math.mul(sameDiff.math.neg(gradWrtX),sameDiff.math.div(larg(),rarg())); List<SDVariable> ret = new ArrayList<>(2); ret.add(gradWrtX); ret.add(gradWrtY); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/bool/Not.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/bool/Not.java index 95bd0bf41..d70de3b3e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/bool/Not.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/bool/Not.java @@ -69,6 +69,6 @@ public class Not extends BaseTransformBoolOp { @Override public List<SDVariable> doDiff(List<SDVariable> f1) { - return Collections.singletonList(f().zerosLike(arg())); + return Collections.singletonList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/AMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/AMax.java index e5ea01a60..d1e5917c8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/AMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/AMax.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformSameOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.MinBp; import java.util.Collections; import java.util.List; @@ -57,7 +58,7 @@ @Override public List<SDVariable> doDiff(List<SDVariable> f1) { SDVariable sgn = sameDiff.math().sign(arg()); - SDVariable minBp = f().minBp(sameDiff.math().abs(arg()), f1.get(0), false, dimensions); + SDVariable minBp = new MinBp(sameDiff, sameDiff.math().abs(arg()), f1.get(0), false, dimensions).outputVariable(); return Collections.singletonList(sgn.mul(minBp)); }
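The AMax change above (and the matching AMin change below) keeps the existing pattern of differentiating through the absolute value by multiplying the reduction gradient with sign(x). A standalone numeric check of the underlying identity d|x|/dx = sign(x) for x != 0 (plain Java; the class name is illustrative, not part of the patch):

// Central finite differences vs. sign(x), at a few points away from 0
// where |x| is not differentiable.
public class AbsGradCheck {
    public static void main(String[] args) {
        double eps = 1e-6;
        for (double x : new double[]{-2.5, -0.3, 0.4, 3.1}) {
            double numeric = (Math.abs(x + eps) - Math.abs(x - eps)) / (2 * eps);
            double analytic = Math.signum(x);
            System.out.printf("x=%5.2f  sign(x)=%5.2f  finite-diff=%5.2f%n",
                    x, analytic, numeric);
        }
    }
}

diff --git 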
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/AMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/AMin.java index bf6a37a55..b1ded6b55 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/AMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/AMin.java @@ -22,6 +22,7 @@ import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseReduceSameOp; import org.nd4j.linalg.api.ops.BaseTransformSameOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.MinBp; import java.util.Collections; import java.util.List; @@ -57,7 +58,7 @@ public class AMin extends BaseTransformSameOp { @Override public List doDiff(List f1) { SDVariable sgn = sameDiff.math().sign(arg()); - SDVariable minBp = f().minBp(sameDiff.math().abs(arg()), f1.get(0), false, dimensions); + SDVariable minBp = new MinBp(sameDiff, sameDiff.math().abs(arg()), f1.get(0), false, dimensions).outputVariable(); return Collections.singletonList(sgn.mul(minBp)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Abs.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Abs.java index 4c6bf0ad9..ef21623c2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Abs.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Abs.java @@ -77,7 +77,7 @@ public class Abs extends BaseTransformSameOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().sign(arg()).mul(i_v.get(0)); + SDVariable ret = sameDiff.math.sign(arg()).mul(i_v.get(0)); return Arrays.asList(ret); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Ceil.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Ceil.java index f2d163f3f..bc86ae999 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Ceil.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Ceil.java @@ -75,6 +75,6 @@ public class Ceil extends BaseTransformSameOp { public List doDiff(List f1) { //not continuously differentiable, but dOut/dIn = 0 in most places - return Arrays.asList(f().zerosLike(arg())); + return Arrays.asList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Cube.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Cube.java index 6422e8df8..b4550cb4e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Cube.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Cube.java @@ -25,6 +25,7 @@ import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformSameOp; +import 
org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp; import java.util.Arrays; import java.util.List; @@ -77,6 +78,6 @@ public class Cube extends BaseTransformSameOp { @Override public List doDiff(List f1) { - return Collections.singletonList(f().cubeBp(arg(), f1.get(0))); + return new CubeBp(sameDiff, arg(), f1.get(0)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Max.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Max.java index db682174c..f5aec6b48 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Max.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Max.java @@ -21,6 +21,8 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformSameOp; +import org.nd4j.linalg.api.ops.impl.reduce.bp.MaxBp; +import org.nd4j.linalg.api.ops.impl.transforms.custom.MaximumBp; import java.util.Collections; import java.util.List; @@ -56,9 +58,7 @@ public class Max extends BaseTransformSameOp { @Override public List doDiff(List f1) { - SDVariable sgn = sameDiff.math().sign(arg()); - SDVariable minBp = f().minBp(sameDiff.math().abs(arg()), f1.get(0), false, dimensions); - return Collections.singletonList(sgn.mul(minBp)); + return new MaximumBp(sameDiff, larg(), rarg(), f1.get(0)).outputs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Min.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Min.java index 6585ace19..1560a0e80 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Min.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Min.java @@ -21,7 +21,9 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformSameOp; +import org.nd4j.linalg.api.ops.impl.transforms.custom.MaximumBp; +import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -56,9 +58,10 @@ public class Min extends BaseTransformSameOp { @Override public List doDiff(List f1) { - SDVariable sgn = sameDiff.math().sign(arg()); - SDVariable minBp = f().minBp(sameDiff.math().abs(arg()), f1.get(0), false, dimensions); - return Collections.singletonList(sgn.mul(minBp)); + //TODO optimize + SDVariable gt = arg(0).gt(arg(1)).castTo(arg(0).dataType()); + SDVariable lt = arg(0).lt(arg(1)).castTo(arg(1).dataType()); + return Arrays.asList(lt.mul(f1.get(0)), gt.mul(f1.get(0))); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Negative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Negative.java index 37b370fe9..f03805eb6 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Negative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Negative.java @@ -73,7 +73,7 @@ public class Negative extends 
BaseTransformSameOp { @Override public List doDiff(List i_v) { - return Arrays.asList(f().neg(i_v.get(0))); + return Arrays.asList(sameDiff.math.neg(i_v.get(0))); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Reciprocal.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Reciprocal.java index 1e11fa34d..8d2049f25 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Reciprocal.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Reciprocal.java @@ -74,7 +74,7 @@ public class Reciprocal extends BaseTransformSameOp { @Override public List doDiff(List i_v1) { // -1/(x^2) - SDVariable g = f().pow(arg(), 2).rdiv(-1).mul(i_v1.get(0)); + SDVariable g = sameDiff.math.pow(arg(), 2).rdiv(-1).mul(i_v1.get(0)); return Collections.singletonList(g); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Round.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Round.java index 375a8acb5..de9e0b685 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Round.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Round.java @@ -75,6 +75,6 @@ public class Round extends BaseTransformSameOp { @Override public List doDiff(List f1) { - return Arrays.asList(f().zerosLike(arg())); + return Arrays.asList(sameDiff.zerosLike(arg())); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Square.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Square.java index c63e00114..010955baf 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Square.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Square.java @@ -21,8 +21,10 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformSameOp; +import org.nd4j.linalg.api.ops.impl.scalar.PowDerivative; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -72,7 +74,7 @@ public class Square extends BaseTransformSameOp { @Override public List doDiff(List i_v) { - SDVariable g = f().powDerivative(arg(), 2).mul(i_v.get(0)); - return Arrays.asList(g); + SDVariable g = new PowDerivative(sameDiff, arg(), false, 2).outputVariable().mul(i_v.get(0)); + return Collections.singletonList(g); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java index 1506ac5f3..aeb543ea8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java @@ -22,6 +22,7 @@ import 
org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentMaxBp; import java.util.*; @@ -59,7 +60,7 @@ @Override public List<SDVariable> doDiff(List<SDVariable> gradients){ - return Arrays.asList(f().unsortedSegmentMaxBp(arg(0), arg(1), gradients.get(0), numSegments)); + return new UnsortedSegmentMaxBp(sameDiff, arg(0), arg(1), gradients.get(0), numSegments).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java index 4338cf33d..e869a84eb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java @@ -23,6 +23,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentMeanBp; import java.util.Arrays; import java.util.Collections; @@ -61,7 +62,7 @@ @Override public List<SDVariable> doDiff(List<SDVariable> gradients){ - return Arrays.asList(f().unsortedSegmentMeanBp(arg(0), arg(1), gradients.get(0), numSegments)); + return new UnsortedSegmentMeanBp(sameDiff, arg(0), arg(1), gradients.get(0), numSegments).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java index 2f8aab0b1..fd0f5fd05 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java @@ -23,6 +23,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentMinBp; import java.util.Arrays; import java.util.Collections; @@ -61,7 +62,7 @@ @Override public List<SDVariable> doDiff(List<SDVariable> gradients){ - return Arrays.asList(f().unsortedSegmentMinBp(arg(0), arg(1), gradients.get(0), numSegments)); + return new UnsortedSegmentMinBp(sameDiff, arg(0), arg(1), gradients.get(0), numSegments).outputs(); } @Override
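The UnsortedSegment*Bp ops constructed above scatter the per-segment upstream gradient back to the elements that produced each output; for the max (and analogously min) variant, only the element matching the segment extremum should receive it. A standalone sketch of that routing under that assumption (plain Java; illustrative class name, and ties are resolved naively here by crediting every element equal to the extremum):

// Gradient routing for an unsorted segment max: each element gets the
// upstream gradient of its segment iff it equals that segment's max.
public class SegmentMaxGrad {
    public static void main(String[] args) {
        double[] x   = {1.0, 5.0, 3.0, 4.0};
        int[]    seg = {0,   0,   1,   1};   // segment id per element
        double[] max = {Double.NEGATIVE_INFINITY, Double.NEGATIVE_INFINITY};
        for (int i = 0; i < x.length; i++) max[seg[i]] = Math.max(max[seg[i]], x[i]);
        double[] gOut = {10.0, 20.0};        // upstream gradient, one per segment
        double[] gIn  = new double[x.length];
        for (int i = 0; i < x.length; i++)
            gIn[i] = (x[i] == max[seg[i]]) ? gOut[seg[i]] : 0.0;
        System.out.println(java.util.Arrays.toString(gIn)); // [0.0, 10.0, 0.0, 20.0]
    }
}

For the sum variant the mask disappears entirely and the routing degenerates to a plain gather of gOut[seg[i]].

diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java index 7afd75fac..12ec63222 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java +++ 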
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java @@ -23,6 +23,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentProdBp; import java.util.Arrays; import java.util.Collections; @@ -61,7 +62,7 @@ public class UnsortedSegmentProd extends DynamicCustomOp { @Override public List doDiff(List gradients){ - return Arrays.asList(f().unsortedSegmentProdBp(arg(0), arg(1), gradients.get(0), numSegments)); + return new UnsortedSegmentProdBp(sameDiff, arg(0), arg(1), gradients.get(0), numSegments).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSqrtN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSqrtN.java index 77474855c..9d7aceb96 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSqrtN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSqrtN.java @@ -23,6 +23,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentSqrtNBp; import java.util.ArrayList; import java.util.Arrays; @@ -62,7 +63,7 @@ public class UnsortedSegmentSqrtN extends DynamicCustomOp { @Override public List doDiff(List gradients){ - return Arrays.asList(f().unsortedSegmentSqrtNBp(arg(0), arg(1), gradients.get(0), numSegments)); + return new UnsortedSegmentSqrtNBp(sameDiff, arg(0), arg(1), gradients.get(0), numSegments).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSum.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSum.java index 336c756ac..5e5cfd12e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSum.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSum.java @@ -23,6 +23,7 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentSumBp; import org.nd4j.linalg.factory.Nd4j; import java.util.Arrays; @@ -62,7 +63,7 @@ public class UnsortedSegmentSum extends DynamicCustomOp { @Override public List doDiff(List gradients){ - return Arrays.asList(f().unsortedSegmentSumBp(arg(0), arg(1), gradients.get(0), numSegments)); + return new UnsortedSegmentSumBp(sameDiff, arg(0), arg(1), gradients.get(0), numSegments).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ACos.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ACos.java index 
3e0c60bb0..44dca1ed1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ACos.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ACos.java @@ -26,6 +26,7 @@ import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -75,9 +76,9 @@ public class ACos extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { //dacos(x)/dx = -1 / sqrt(1-x^2) - SDVariable oneSubSq = f().square(arg()).rsub(1.0); - SDVariable sqrt = f().sqrt(oneSubSq); + SDVariable oneSubSq = sameDiff.math.square(arg()).rsub(1.0); + SDVariable sqrt = sameDiff.math.sqrt(oneSubSq); SDVariable ret = sqrt.rdiv(-1.0).mul(i_v.get(0)); - return Arrays.asList(ret); + return Collections.singletonList(ret); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ASinh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ASinh.java index 49ef2fb09..25f20e011 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ASinh.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ASinh.java @@ -75,8 +75,8 @@ public class ASinh extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { //dasinh(x)/dx = 1 / sqrt(x^2+1) - SDVariable xSqPlus1 = f().square(arg()).add(1.0); - SDVariable ret = i_v.get(0).div(f().sqrt(xSqPlus1)); + SDVariable xSqPlus1 = sameDiff.math.square(arg()).add(1.0); + SDVariable ret = i_v.get(0).div(sameDiff.math.sqrt(xSqPlus1)); return Arrays.asList(ret); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ATan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ATan.java index 483896dfd..a7a741759 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ATan.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ATan.java @@ -26,6 +26,7 @@ import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -76,8 +77,8 @@ public class ATan extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { //d(atan(x))/dx = 1/(x^2+1) - SDVariable xSqPlus1 = f().square(arg()).add(1.0); + SDVariable xSqPlus1 = sameDiff.math.square(arg()).add(1.0); SDVariable ret = xSqPlus1.rdiv(1.0).mul(i_v.get(0)); - return Arrays.asList(ret); + return Collections.singletonList(ret); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cos.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cos.java index 21076ad6e..35ed040c2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cos.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cos.java @@ -64,7 +64,7 @@ public class Cos extends 
BaseTransformStrictOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().neg(f().sin(arg())).mul(i_v.get(0)); + SDVariable ret = sameDiff.math.neg(sameDiff.math.sin(arg())).mul(i_v.get(0)); return Arrays.asList(ret); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cosh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cosh.java index dc08ead5f..5144315da 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cosh.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Cosh.java @@ -74,7 +74,7 @@ public class Cosh extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().sinh(arg()).mul(i_v.get(0)); + SDVariable ret = sameDiff.math.sinh(arg()).mul(i_v.get(0)); return Arrays.asList(ret); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java index c4fc245b7..cc3a5e116 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java @@ -23,6 +23,7 @@ import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp; import java.util.Collections; import java.util.List; @@ -83,7 +84,7 @@ public class ELU extends DynamicCustomOp { public List doDiff(List i_v) { //ELU: e^x-1 if x<0, x otherwise //dL/dIn = dL/Out * dOut/dIn - return Collections.singletonList(f().eluBp(arg(), i_v.get(0), alpha)); + return new EluBp(sameDiff, arg(), i_v.get(0), alpha).outputs(); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Exp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Exp.java index 21aa49522..f9288d1d4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Exp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Exp.java @@ -73,7 +73,7 @@ public class Exp extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().mul(f().exp(arg()), i_v.get(0)); + SDVariable ret = sameDiff.math.mul(sameDiff.math.exp(arg()), i_v.get(0)); return Arrays.asList(ret); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Expm1.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Expm1.java index 538f6a003..5b093a3ec 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Expm1.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Expm1.java @@ -75,7 +75,7 @@ public class Expm1 extends BaseTransformStrictOp { @Override public List 
doDiff(List i_v) { - SDVariable ret = f().mul(f().exp(arg()), i_v.get(0)); + SDVariable ret = sameDiff.math.mul(sameDiff.math.exp(arg()), i_v.get(0)); return Arrays.asList(ret); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/GELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/GELU.java index b784ddde0..009492924 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/GELU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/GELU.java @@ -68,7 +68,7 @@ public class GELU extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().geluDerivative(arg(), false).mul(i_v.get(0)); + SDVariable ret = new GELUDerivative(sameDiff, arg(), false).outputVariable().mul(i_v.get(0)); return Collections.singletonList(ret); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardSigmoid.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardSigmoid.java index ddaa8631f..91c4eb8ae 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardSigmoid.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardSigmoid.java @@ -23,6 +23,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformFloatOp; import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp; import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidDerivative; import java.util.Collections; @@ -74,7 +75,7 @@ public class HardSigmoid extends BaseTransformStrictOp { @Override public List doDiff(List f1) { - return Collections.singletonList(f().hardSigmoidBp(arg(), f1.get(0))); + return new HardSigmoidBp(sameDiff, arg(), f1.get(0)).outputs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardTanh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardTanh.java index fa80bf880..fc44f3c22 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardTanh.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardTanh.java @@ -25,6 +25,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformFloatOp; import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhBp; import java.util.Arrays; import java.util.List; @@ -75,6 +76,6 @@ public class HardTanh extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - return Collections.singletonList(f().hardTanhBp(arg(), i_v.get(0))); + return new HardTanhBp(sameDiff, arg(), i_v.get(0)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Log.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Log.java index a937e1d63..1fd8ac430 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Log.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Log.java @@ -24,6 +24,7 @@ import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -75,8 +76,7 @@ public class Log extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - f().validateDifferentialFunctionsameDiff(arg()); - SDVariable toInverse = sameDiff.setupFunction(f().div(i_v.get(0), arg())); - return Arrays.asList(toInverse); + SDVariable toInverse = sameDiff.math.div(i_v.get(0), arg()); + return Collections.singletonList(toInverse); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Log1p.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Log1p.java index 131986d15..d61504e39 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Log1p.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Log1p.java @@ -19,6 +19,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; import lombok.NoArgsConstructor; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.autodiff.util.SameDiffUtils; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformFloatOp; @@ -73,7 +74,7 @@ public class Log1p extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - f().validateDifferentialFunctionsameDiff(arg()); + SameDiffUtils.validateDifferentialFunctionSameDiff(sameDiff, arg(), this); return Collections.singletonList(i_v.get(0).div(arg().add(1.0))); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/LogSigmoid.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/LogSigmoid.java index 353ced004..6a118a062 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/LogSigmoid.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/LogSigmoid.java @@ -23,6 +23,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformFloatOp; import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative; import java.util.Arrays; import java.util.Collections; @@ -74,10 +75,8 @@ public class LogSigmoid extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { -// SDVariable ret = f().logSigmoidDerivative(arg(), i_v.get(0)); -// return Arrays.asList(ret); - SDVariable sigmDeriv = f().sigmoidDerivative(arg(), i_v.get(0)).div(f().sigmoid(arg())); - return Collections.singletonList(sigmDeriv); + SDVariable v = new SigmoidDerivative(sameDiff, arg(), 
i_v.get(0)).outputVariable().div(sameDiff.nn.sigmoid(arg())); + return Collections.singletonList(v); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Mish.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Mish.java index 416f74133..05ca68aa0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Mish.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Mish.java @@ -23,6 +23,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -69,8 +70,8 @@ public class Mish extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().mishDerivative(arg()).mul(i_v.get(0)); - return Arrays.asList(ret); + SDVariable ret = new MishDerivative(sameDiff, arg(), false).outputVariable().mul(i_v.get(0)); + return Collections.singletonList(ret); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/PreciseGELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/PreciseGELU.java index ab565b30f..e2d8c8b5a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/PreciseGELU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/PreciseGELU.java @@ -37,6 +37,10 @@ public class PreciseGELU extends BaseTransformStrictOp { super(sameDiff, i_v, inPlace); } + public PreciseGELU(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false, true); + } + public PreciseGELU() { } @@ -72,7 +76,7 @@ public class PreciseGELU extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().geluDerivative(arg(), true).mul(i_v.get(0)); + SDVariable ret = new PreciseGELUDerivative(sameDiff, arg(), false, true).outputVariable().mul(i_v.get(0)); return Collections.singletonList(ret); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/RationalTanh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/RationalTanh.java index a05e34637..ecf85c9a9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/RationalTanh.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/RationalTanh.java @@ -21,6 +21,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhBp; import java.util.Collections; import java.util.List; @@ -35,6 +36,10 @@ public class RationalTanh extends BaseTransformStrictOp { super(sameDiff, i_v, inPlace); } + public RationalTanh(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } + public RationalTanh() {} public RationalTanh(INDArray x, INDArray z) { @@ -68,6 +73,6 @@ public class RationalTanh extends BaseTransformStrictOp { @Override public List doDiff(List 
f1) { - return Collections.singletonList(f().tanhRationalBp(arg(), f1.get(0))); + return new RationalTanhBp(sameDiff, arg(), f1.get(0)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/RectifiedTanh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/RectifiedTanh.java index d5fbf1294..8956f1b66 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/RectifiedTanh.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/RectifiedTanh.java @@ -22,6 +22,7 @@ import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhBp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; @@ -42,6 +43,10 @@ public class RectifiedTanh extends BaseTransformStrictOp { super(sameDiff, i_v, inPlace); } + public RectifiedTanh(SameDiff sameDiff, SDVariable i_v) { + this(sameDiff, i_v, false); + } + public RectifiedTanh() {} public RectifiedTanh(INDArray x, INDArray z) { @@ -85,6 +90,6 @@ public class RectifiedTanh extends BaseTransformStrictOp { @Override public List doDiff(List f1) { - return Collections.singletonList(f().tanhRectifiedBp(arg(), f1.get(0))); + return new RectifiedTanhBp(sameDiff, arg(), f1.get(0)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SELU.java index 00592f0e2..472c4eece 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SELU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SELU.java @@ -23,6 +23,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformFloatOp; import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp; import java.util.Arrays; import java.util.List; @@ -81,7 +82,7 @@ public class SELU extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - return Collections.singletonList(f().seluBp(arg(), i_v.get(0))); + return new SeluBp(sameDiff, arg(), i_v.get(0)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sigmoid.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sigmoid.java index 37ef4b743..2452d5906 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sigmoid.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sigmoid.java @@ -22,6 +22,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformFloatOp; import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; +import 
org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative; import java.util.Arrays; import java.util.List; @@ -74,8 +75,7 @@ public class Sigmoid extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().sigmoidDerivative(arg(), i_v.get(0)); - return Arrays.asList(ret); + return new SigmoidDerivative(sameDiff, arg(), i_v.get(0)).outputs(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sin.java index 0fa918c11..bfdde52d8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sin.java @@ -25,6 +25,7 @@ import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -75,8 +76,8 @@ public class Sin extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().cos(arg()).mul(i_v.get(0)); - return Arrays.asList(ret); + SDVariable ret = sameDiff.math.cos(arg()).mul(i_v.get(0)); + return Collections.singletonList(ret); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sinh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sinh.java index d5e3be988..2f5b981bd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sinh.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Sinh.java @@ -25,6 +25,7 @@ import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -75,8 +76,8 @@ public class Sinh extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().cosh(arg()).mul(i_v.get(0)); - return Arrays.asList(ret); + SDVariable ret = sameDiff.math.cosh(arg()).mul(i_v.get(0)); + return Collections.singletonList(ret); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftPlus.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftPlus.java index 11ffb2ef8..f3eeda670 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftPlus.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftPlus.java @@ -24,6 +24,7 @@ import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -73,8 +74,8 @@ public class SoftPlus extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { //dL/dIn = dL/Out * dOut/dIn - SDVariable ret = f().sigmoid(arg()).mul(i_v.get(0)); - return Arrays.asList(ret); + SDVariable ret = sameDiff.nn.sigmoid(arg()).mul(i_v.get(0)); + return Collections.singletonList(ret); } } diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftSign.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftSign.java index 8be5ea2d4..057fda972 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftSign.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftSign.java @@ -16,15 +16,12 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; -import java.util.Collections; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseTransformFloatOp; -import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp; -import java.util.Arrays; import java.util.List; /** @@ -78,7 +75,7 @@ public class SoftSign extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - return Collections.singletonList(f().softsignBp(arg(), i_v.get(0))); + return new SoftSignBp(sameDiff, arg(), i_v.get(0)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Swish.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Swish.java index 0794e0b57..7f694f481 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Swish.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Swish.java @@ -74,7 +74,7 @@ public class Swish extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().swishDerivative(arg()).mul(i_v.get(0)); + SDVariable ret = new SwishDerivative(sameDiff, arg()).outputVariable().mul(i_v.get(0)); return Arrays.asList(ret); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SwishDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SwishDerivative.java index 350fb194e..552308859 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SwishDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SwishDerivative.java @@ -39,8 +39,8 @@ public class SwishDerivative extends BaseTransformStrictOp { super(sameDiff, i_v1, i_v2, inPlace); } - public SwishDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { - super(sameDiff, i_v, inPlace); + public SwishDerivative(SameDiff sameDiff, SDVariable i_v) { + super(sameDiff, i_v, false); } public SwishDerivative() {} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Tan.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Tan.java index 3244925b1..2c9a603ee 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Tan.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Tan.java @@ 
-76,7 +76,7 @@ public class Tan extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { //d(tan(x))/dx = (sec(x))^2 = 1 / (cos(x))^2 - SDVariable cosx = f().cos(arg()); + SDVariable cosx = sameDiff.math.cos(arg()); SDVariable cosSqx = sameDiff.math().square(cosx); return Collections.singletonList(i_v.get(0).div(cosSqx)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Tanh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Tanh.java index 136d0bbea..ca45a549a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Tanh.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/Tanh.java @@ -22,6 +22,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformFloatOp; import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative; import java.util.Arrays; import java.util.List; @@ -74,7 +75,6 @@ public class Tanh extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().tanhDerivative(arg(), i_v.get(0)); - return Arrays.asList(ret); + return new TanhDerivative(sameDiff, arg(), i_v.get(0)).outputs(); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBase.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBase.java index 64e6a96b1..cd2a2b540 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBase.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBase.java @@ -34,12 +34,11 @@ public class NDBase { /** * Boolean and array reduction operation, optionally along specified dimensions
* - * @param x Input variable (BOOL type) + * @param x Input variable (NDARRAY type) * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) * @return output reduced array of rank (input rank - num dimensions) (BOOL type) */ public INDArray all(INDArray x, int... dimensions) { - NDValidation.validateBool("all", "x", x); Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.bool.All(x, dimensions)); } @@ -47,12 +46,11 @@ public class NDBase { /** * Boolean or array reduction operation, optionally along specified dimensions
* - * @param x Input variable (BOOL type) + * @param x Input variable (NDARRAY type) * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) * @return output reduced array of rank (input rank - num dimensions) (BOOL type) */ public INDArray any(INDArray x, int... dimensions) { - NDValidation.validateBool("any", "x", x); Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.bool.Any(x, dimensions)); } @@ -114,6 +112,8 @@ public class NDBase { * keepDims = false: [a,c]
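For orientation, a minimal sketch of the relaxed all/any reductions (class name and values are illustrative; assumes an ND4J backend on the classpath and that the generated NDBase factory can be instantiated directly). With the BOOL validation removed, non-boolean inputs are accepted as-is:

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.factory.ops.NDBase;

    public class AllAnyExample {
        public static void main(String[] args) {
            NDBase base = new NDBase();
            // Assumption: non-zero values count as true when the input is not BOOL
            INDArray x = Nd4j.createFromArray(new double[][]{{1, 0, 2}, {3, 4, 5}});
            INDArray allPerRow = base.all(x, 1); // expect [false, true]
            INDArray anyPerRow = base.any(x, 1); // expect [true, true]
            System.out.println(allPerRow + " / " + anyPerRow);
        }
    }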
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param in Input variable (NUMERIC type) * @param keepDims If true: keep the dimensions that are reduced on (as size 1). False: remove the reduction dimensions @@ -138,6 +138,8 @@ public class NDBase { * keepDims = false: [a,c]
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param in Input variable (NUMERIC type) * @param dimensions Dimensions to reduce over. If dimensions are not specified, full array reduction is performed (Size: AtLeast(min=0)) @@ -369,6 +371,8 @@ public class NDBase { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -472,6 +476,8 @@ public class NDBase { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -504,6 +510,8 @@ public class NDBase { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -602,6 +610,8 @@ public class NDBase { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -634,6 +644,8 @@ public class NDBase { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -760,6 +772,8 @@ public class NDBase { * Element-wise maximum operation: out[i] = max(first[i], second[i])
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param first First input array (NUMERIC type) * @param second Second input array (NUMERIC type) @@ -812,6 +826,21 @@ public class NDBase { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.floating.Mean(x, false, dimensions)); } + /** + * The merge operation is a control operation that forwards either of the inputs to the output, when
+ * the first of them becomes available. If both are available, the output is undefined (either input could
+ * be forwarded to the output).
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output (NUMERIC type) + */ + public INDArray merge(INDArray x, INDArray y) { + NDValidation.validateNumerical("merge", "x", x); + NDValidation.validateNumerical("merge", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge(x, y))[0]; + } + /** * Minimum array reduction operation, optionally along specified dimensions. out = min(in)
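A minimal usage sketch of the new merge helper (illustrative values; same factory-instantiation assumption as above). In an eager call both inputs are already available, so which one is forwarded is intentionally unspecified:

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.factory.ops.NDBase;

    public class MergeExample {
        public static void main(String[] args) {
            INDArray x = Nd4j.createFromArray(1.0, 2.0, 3.0);
            INDArray y = Nd4j.createFromArray(4.0, 5.0, 6.0);
            // Both inputs are available here, so the result is either x or y (undefined which)
            INDArray out = new NDBase().merge(x, y);
            System.out.println(out);
        }
    }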
* @@ -857,6 +886,8 @@ public class NDBase { * Element-wise minimum operation: out[i] = min(first[i], second[i])
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * @param first First input array (NUMERIC type) * @param second Second input array (NUMERIC type) @@ -919,6 +950,8 @@ public class NDBase { * If x and y arrays have equal shape, the output shape is the same as these inputs.
* * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
* * Return boolean array with values true where satisfied, or false otherwise.
* @@ -1978,6 +2011,18 @@ public class NDBase { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.same.Sum(x, false, dimensions)); } + /** + * Switch operation
+ * Predicate - if false, values are output to left (first) branch/output; if true, to right (second) branch/output
+ * + * @param x Input variable (NDARRAY type) + * @param predicate Predictate - if false, values are output to left (first) branch/output; if true, to right (second) branch/output (BOOL type) + */ + public INDArray[] switchOp(INDArray x, INDArray predicate) { + NDValidation.validateBool("switchOp", "predicate", predicate); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch(x, predicate)); + } + /** * //TODO: Ops must be documented.
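A sketch of the new switchOp helper (illustrative values; Switch is primarily a graph-mode control op, so exactly what each returned slot holds in an eager call is an assumption here):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.factory.ops.NDBase;

    public class SwitchExample {
        public static void main(String[] args) {
            INDArray x = Nd4j.createFromArray(1.0, 2.0, 3.0);
            INDArray predicate = Nd4j.scalar(true); // BOOL scalar
            // Two outputs: index 0 = "false" branch, index 1 = "true" branch;
            // with a true predicate the values flow to index 1 (assumption for eager mode)
            INDArray[] branches = new NDBase().switchOp(x, predicate);
            System.out.println(branches.length);
        }
    }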
* diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDImage.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDImage.java index 03b9f8571..536633cd2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDImage.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDImage.java @@ -138,7 +138,7 @@ public class NDImage { /** * Resize images to size using the specified method.
* - * @param input 4D image [NHWC] (NUMERIC type) + * @param input 4D image [NCHW] (NUMERIC type) * @param size new height and width (INT type) * @param preserveAspectRatio Whether to preserve the aspect ratio. If this is set, then images will be resized to a size that fits in size while preserving the aspect ratio of the original image. Scales up the image if size is bigger than the current size of the image. Defaults to False. * @param antialis Whether to use an anti-aliasing filter when downsampling an image @@ -161,7 +161,7 @@ public class NDImage { /** * Resize images to size using the specified method.
* - * @param input 4D image [NHWC] (NUMERIC type) + * @param input 4D image [NCHW] (NUMERIC type) * @param size new height and width (INT type) * @param ImageResizeMethod ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling. * ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing. diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java index 8e8923834..1deddfd0a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java @@ -50,13 +50,13 @@ public class NDMath { * Looks up ids in a list of embedding tensors.
* * @param x Input tensor (NUMERIC type) - * @param indices A Tensor containing the ids to be looked up. (INT type) + * @param indices A Tensor containing the ids to be looked up. (NUMERIC type) * @param PartitionMode partition_mode == 0 - i.e. 'mod' , 1 - 'div' * @return output Shifted output (NUMERIC type) */ public INDArray embeddingLookup(INDArray x, INDArray indices, PartitionMode PartitionMode) { NDValidation.validateNumerical("EmbeddingLookup", "x", x); - NDValidation.validateInteger("EmbeddingLookup", "indices", indices); + NDValidation.validateNumerical("EmbeddingLookup", "indices", indices); return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup(x, indices, PartitionMode))[0]; } @@ -93,6 +93,35 @@ public class NDMath { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh(x)); } + /** + * Pairwise addition operation, out = x + y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray add(INDArray x, INDArray y) { + NDValidation.validateNumerical("add", "x", x); + NDValidation.validateNumerical("add", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp(x, y))[0]; + } + + /** + * Scalar add operation, out = in + scalar
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public INDArray add(INDArray x, double value) { + NDValidation.validateNumerical("add", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd(x, value)); + } + /** * Absolute max array reduction operation, optionally along specified dimensions: out = max(abs(x))
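A minimal sketch of the new add helpers and the [1,10] + [5,10] broadcasting case described above (illustrative names/values; assumes direct instantiation of the generated NDMath factory):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.factory.ops.NDMath;

    public class AddExample {
        public static void main(String[] args) {
            NDMath math = new NDMath();
            INDArray x = Nd4j.ones(1, 10);         // shape [1,10]
            INDArray y = Nd4j.ones(5, 10);         // shape [5,10]
            INDArray sum = math.add(x, y);         // broadcast result has shape [5,10]
            INDArray shifted = math.add(sum, 0.5); // scalar form: out = in + 0.5
            System.out.println(java.util.Arrays.toString(shifted.shape())); // [5, 10]
        }
    }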
* @@ -540,6 +569,35 @@ public class NDMath { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.DiagPart(x))[0]; } + /** + * Pairwise division operation, out = x / y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray div(INDArray x, INDArray y) { + NDValidation.validateNumerical("div", "x", x); + NDValidation.validateNumerical("div", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.DivOp(x, y))[0]; + } + + /** + * Scalar division operation, out = in / scalar
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public INDArray div(INDArray x, double value) { + NDValidation.validateNumerical("div", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.ScalarDivision(x, value)); + } + /** * Entropy reduction: -sum(x * log(x))
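The division helpers follow the same pairwise/scalar pattern; a quick sketch with hand-checked values (illustrative):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.factory.ops.NDMath;

    public class DivExample {
        public static void main(String[] args) {
            NDMath math = new NDMath();
            INDArray x = Nd4j.createFromArray(2.0, 4.0, 8.0);
            INDArray y = Nd4j.createFromArray(2.0, 2.0, 2.0);
            System.out.println(math.div(x, y));   // [1.0, 2.0, 4.0]
            System.out.println(math.div(x, 2.0)); // scalar form: [1.0, 2.0, 4.0]
        }
    }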
* @@ -739,6 +797,52 @@ public class NDMath { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.same.Floor(x)); } + /** + * Pairwise floor division operation, out = floor(x / y)
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray floorDiv(INDArray x, INDArray y) { + NDValidation.validateNumerical("floorDiv", "x", x); + NDValidation.validateNumerical("floorDiv", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorDivOp(x, y))[0]; + } + + /** + * Pairwise floor modulus operation, out = x mod y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray floorMod(INDArray x, INDArray y) { + NDValidation.validateNumerical("floorMod", "x", x); + NDValidation.validateNumerical("floorMod", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorModOp(x, y))[0]; + } + + /** + * Scalar floor modulus operation
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public INDArray floorMod(INDArray x, double value) { + NDValidation.validateNumerical("floorMod", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.ScalarFMod(x, value)); + } + /** * Hamming distance reduction operation. The output contains the cosine distance for each
* tensor/subset along the specified dimensions:
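A sketch of the floor-division helpers added above, with hand-checked values (illustrative; note the scalar variant maps to ScalarFMod, so its behaviour for negative inputs is flagged as an assumption):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.factory.ops.NDMath;

    public class FloorOpsExample {
        public static void main(String[] args) {
            NDMath math = new NDMath();
            INDArray x = Nd4j.createFromArray(7.0, -7.0);
            INDArray y = Nd4j.createFromArray(3.0, 3.0);
            System.out.println(math.floorDiv(x, y)); // [2.0, -3.0]  (floor of 7/3 and -7/3)
            System.out.println(math.floorMod(x, y)); // [1.0, 2.0]   (x - floorDiv(x,y) * y)
            // Scalar variant; for negative inputs it may follow C-style fmod semantics (assumption)
            System.out.println(math.floorMod(x, 3.0));
        }
    }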
@@ -1069,6 +1173,23 @@ public class NDMath { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixInverse(in))[0]; } + /** + * Pairwise max operation, out = max(x, y)
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x First input variable, x (NUMERIC type) + * @param y Second input variable, y (NUMERIC type) + * @return out Output (NUMERIC type) + */ + public INDArray max(INDArray x, INDArray y) { + NDValidation.validateNumerical("max", "x", x); + NDValidation.validateNumerical("max", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Max(x, y))[0]; + } + /** * Merge add function: merges an arbitrary number of equal shaped arrays using element-wise addition:
* out = sum_i in[i]
@@ -1120,6 +1241,40 @@ public class NDMath { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.MeshGrid(inputs, cartesian)); } + /** + * Pairwise min operation, out = min(x, y)
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x First input variable, x (NUMERIC type) + * @param y Second input variable, y (NUMERIC type) + * @return out Output (NUMERIC type) + */ + public INDArray min(INDArray x, INDArray y) { + NDValidation.validateNumerical("min", "x", x); + NDValidation.validateNumerical("min", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Min(x, y))[0]; + } + + /** + * Pairwise modulus (remainder) operation, out = x % y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray mod(INDArray x, INDArray y) { + NDValidation.validateNumerical("mod", "x", x); + NDValidation.validateNumerical("mod", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.ModOp(x, y))[0]; + } + /** * Calculate the mean and (population) variance for the input variable, for the specified axis
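The pairwise max/min/mod helpers in one hand-checked sketch (illustrative values, same NDMath assumption as above):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.factory.ops.NDMath;

    public class MinMaxModExample {
        public static void main(String[] args) {
            NDMath math = new NDMath();
            INDArray x = Nd4j.createFromArray(1.0, 5.0, 7.0);
            INDArray y = Nd4j.createFromArray(4.0, 2.0, 7.0);
            System.out.println(math.max(x, y)); // [4.0, 5.0, 7.0]
            System.out.println(math.min(x, y)); // [1.0, 2.0, 7.0]
            System.out.println(math.mod(x, y)); // [1.0, 1.0, 0.0]
        }
    }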
* @@ -1132,6 +1287,35 @@ public class NDMath { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.reduce.Moments(input, axes)); } + /** + * Pairwise multiplication operation, out = x * y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray mul(INDArray x, INDArray y) { + NDValidation.validateNumerical("mul", "x", x); + NDValidation.validateNumerical("mul", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp(x, y))[0]; + } + + /** + * Scalar multiplication operation, out = in * scalar
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public INDArray mul(INDArray x, double value) { + NDValidation.validateNumerical("mul", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.ScalarMultiplication(x, value)); + } + /** * Elementwise negative operation: out = -x
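A sketch of mul with two-sided broadcasting, which the NumPy-rules note above implies ([1,3] * [2,1] broadcasts to [2,3]; values illustrative):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.factory.ops.NDMath;

    public class MulExample {
        public static void main(String[] args) {
            NDMath math = new NDMath();
            INDArray row = Nd4j.createFromArray(1.0, 2.0, 3.0).reshape(1, 3);
            INDArray col = Nd4j.createFromArray(10.0, 20.0).reshape(2, 1);
            INDArray outer = math.mul(row, col);    // shape [2,3]: rows scaled by 10 and 20
            INDArray halved = math.mul(outer, 0.5); // scalar form
            System.out.println(halved);
        }
    }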
* @@ -1200,6 +1384,48 @@ public class NDMath { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.Pow(x, y))[0]; } + /** + * Rational Tanh Approximation elementwise function, as described in the paper:
+ * Compact Convolutional Neural Network Cascade for Face Detection
+ * This is a faster Tanh approximation
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray rationalTanh(INDArray x) { + NDValidation.validateNumerical("rationalTanh", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.RationalTanh(x)); + } + + /** + * Pairwise reverse division operation, out = y / x
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray rdiv(INDArray x, INDArray y) { + NDValidation.validateNumerical("rdiv", "x", x); + NDValidation.validateNumerical("rdiv", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RDivOp(x, y))[0]; + } + + /** + * Scalar reverse division operation, out = scalar / in
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public INDArray rdiv(INDArray x, double value) { + NDValidation.validateNumerical("rdiv", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseDivision(x, value)); + } + /** * Element-wise reciprocal (inverse) function: out[i] = 1 / in[i]
* @@ -1211,6 +1437,17 @@ public class NDMath { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.same.Reciprocal(x)); } + /** + * Rectified tanh operation: max(0, tanh(in))
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray rectifiedTanh(INDArray x) { + NDValidation.validateNumerical("rectifiedTanh", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.RectifiedTanh(x)); + } + /** * Element-wise round function: out = round(x).
* Rounds (up or down depending on value) to the nearest integer value.
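The two tanh variants added above (rationalTanh and rectifiedTanh) side by side, as an illustrative sketch:

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.factory.ops.NDMath;

    public class TanhVariantsExample {
        public static void main(String[] args) {
            NDMath math = new NDMath();
            INDArray x = Nd4j.createFromArray(-2.0, 0.0, 2.0);
            // Faster rational approximation of tanh
            System.out.println(math.rationalTanh(x));
            // max(0, tanh(x)): negative activations are clamped to 0
            System.out.println(math.rectifiedTanh(x));
        }
    }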
@@ -1234,6 +1471,35 @@ public class NDMath { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt(x)); } + /** + * Pairwise reverse subtraction operation, out = y - x
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray rsub(INDArray x, INDArray y) { + NDValidation.validateNumerical("rsub", "x", x); + NDValidation.validateNumerical("rsub", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RSubOp(x, y))[0]; + } + + /** + * Scalar reverse subtraction operation, out = scalar - in
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public INDArray rsub(INDArray x, double value) { + NDValidation.validateNumerical("rsub", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseSubtraction(x, value)); + } + /** * Set the diagonal value to the specified values
* If input is
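The reverse-order helpers added above (rdiv, rsub) swap operand roles, which the out = y / x and out = y - x formulas make easy to hand-check (illustrative values):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.factory.ops.NDMath;

    public class ReverseOpsExample {
        public static void main(String[] args) {
            NDMath math = new NDMath();
            INDArray x = Nd4j.createFromArray(1.0, 2.0, 4.0);
            INDArray y = Nd4j.createFromArray(8.0, 8.0, 8.0);
            System.out.println(math.rdiv(x, y));    // y / x = [8.0, 4.0, 2.0]
            System.out.println(math.rsub(x, y));    // y - x = [7.0, 6.0, 4.0]
            System.out.println(math.rsub(x, 10.0)); // scalar - in = [9.0, 8.0, 6.0]
        }
    }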
@@ -1326,6 +1592,23 @@ public class NDMath { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.same.Square(x)); } + /** + * Pairwise squared difference operation.
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray squaredDifference(INDArray x, INDArray y) { + NDValidation.validateNumerical("squaredDifference", "x", x); + NDValidation.validateNumerical("squaredDifference", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SquaredDifferenceOp(x, y))[0]; + } + /** * Standardize input variable along given axis
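squaredDifference computes (x - y)^2 elementwise; a one-liner sketch (illustrative values):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.factory.ops.NDMath;

    public class SquaredDifferenceExample {
        public static void main(String[] args) {
            INDArray x = Nd4j.createFromArray(1.0, 5.0);
            INDArray y = Nd4j.createFromArray(4.0, 3.0);
            // (1-4)^2 = 9, (5-3)^2 = 4
            System.out.println(new NDMath().squaredDifference(x, y)); // [9.0, 4.0]
        }
    }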
*


@@ -1364,6 +1647,35 @@ public class NDMath { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.Step(x, value)); } + /** + * Pairwise subtraction operation, out = x - y
+ * + * Note: supports broadcasting if x and y have different shapes and are broadcastable.
+ * For example, if X has shape [1,10] and Y has shape [5,10] then op(X,Y) has output shape [5,10]
+ * Broadcast rules are the same as NumPy: https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ * + * @param x Input variable (NUMERIC type) + * @param y Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray sub(INDArray x, INDArray y) { + NDValidation.validateNumerical("sub", "x", x); + NDValidation.validateNumerical("sub", "y", y); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SubOp(x, y))[0]; + } + + /** + * Scalar subtraction operation, out = in - scalar
+ * + * @param x Input variable (NUMERIC type) + * @param value Scalar value for op + * @return output Output variable (NUMERIC type) + */ + public INDArray sub(INDArray x, double value) { + NDValidation.validateNumerical("sub", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.scalar.ScalarSubtraction(x, value)); + } + /** * Elementwise tangent operation: out = tan(x)
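And the subtraction pair, rounding out the arithmetic family (illustrative values):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.factory.ops.NDMath;

    public class SubExample {
        public static void main(String[] args) {
            NDMath math = new NDMath();
            INDArray x = Nd4j.createFromArray(5.0, 6.0, 7.0);
            INDArray y = Nd4j.createFromArray(1.0, 2.0, 3.0);
            System.out.println(math.sub(x, y));   // [4.0, 4.0, 4.0]
            System.out.println(math.sub(x, 5.0)); // scalar form: [0.0, 1.0, 2.0]
        }
    }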
* diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java index 06fb92b64..e2a8af245 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java @@ -21,6 +21,7 @@ package org.nd4j.linalg.factory.ops; import static org.nd4j.linalg.factory.NDValidation.isSameType; import org.nd4j.base.Preconditions; +import org.nd4j.enums.PadMode; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.NDValidation; import org.nd4j.linalg.factory.Nd4j; @@ -355,6 +356,21 @@ public class NDNN { return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention(queries, keys, values, Wq, Wk, Wv, Wo, mask, scaled, false))[0]; } + /** + * Padding operation
+ * + * @param input Input tensor (NUMERIC type) + * @param padding Padding value (NUMERIC type) + * @param PadMode Padding format + * @param constant Padding constant + * @return output Padded input (NUMERIC type) + */ + public INDArray pad(INDArray input, INDArray padding, PadMode PadMode, double constant) { + NDValidation.validateNumerical("pad", "input", input); + NDValidation.validateNumerical("pad", "padding", padding); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.Pad(input, padding, PadMode, constant))[0]; + } + /** * Padding operation
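A sketch of the new PadMode-aware pad overload (illustrative; assumes the padding array is given as one [before, after] pair per dimension, TF-style, and that the generated NDNN factory can be instantiated directly):

    import org.nd4j.enums.PadMode;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.factory.ops.NDNN;

    public class PadExample {
        public static void main(String[] args) {
            INDArray in = Nd4j.createFromArray(1.0, 2.0, 3.0).reshape(1, 3);
            // One [before, after] pair per dimension: no row padding, one column each side
            INDArray padding = Nd4j.createFromArray(new int[][]{{0, 0}, {1, 1}});
            INDArray out = new NDNN().pad(in, padding, PadMode.CONSTANT, 0.0);
            System.out.println(out); // expect [[0, 1, 2, 3, 0]]
        }
    }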
* @@ -366,7 +382,20 @@ public class NDNN { public INDArray pad(INDArray input, INDArray padding, double constant) { NDValidation.validateNumerical("pad", "input", input); NDValidation.validateNumerical("pad", "padding", padding); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.Pad(input, padding, constant))[0]; + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.Pad(input, padding, PadMode.CONSTANT, constant))[0]; + } + + /** + * GELU activation function - Gaussian Error Linear Units
+ * For more details, see Gaussian Error Linear Units (GELUs) - https://arxiv.org/abs/1606.08415
+ * This method uses the precise calculation, rather than a faster approximation
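A quick sketch of the preciseGelu helper (illustrative; expected values are from the exact GELU definition x * Phi(x)):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.factory.ops.NDNN;

    public class PreciseGeluExample {
        public static void main(String[] args) {
            INDArray x = Nd4j.createFromArray(-1.0, 0.0, 1.0);
            INDArray out = new NDNN().preciseGelu(x);
            System.out.println(out); // approximately [-0.159, 0.0, 0.841]
        }
    }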
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray preciseGelu(INDArray x) { + NDValidation.validateNumerical("preciseGelu", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELU(x)); } /** diff --git a/nd4j/nd4j-backends/nd4j-tests-tensorflow/pom.xml b/nd4j/nd4j-backends/nd4j-tests-tensorflow/pom.xml index 6a3cc6eda..3e4367992 100644 --- a/nd4j/nd4j-backends/nd4j-tests-tensorflow/pom.xml +++ b/nd4j/nd4j-backends/nd4j-tests-tensorflow/pom.xml @@ -30,7 +30,6 @@ nd4j-tests-tensorflow - 1.8 1.8 @@ -216,8 +215,10 @@ **/*.java - org.nd4j.linalg.jcublas.JCublasBackend - org.nd4j.linalg.jcublas.JCublasBackend + org.nd4j.linalg.jcublas.JCublasBackend + + org.nd4j.linalg.jcublas.JCublasBackend + - + nd4j-backends org.nd4j @@ -29,7 +30,6 @@ nd4j-tests - 1.8 1.8 @@ -179,7 +179,8 @@ maven-surefire-plugin - ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/ + ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/ + src/test/java @@ -191,8 +192,10 @@ junit:junit - org.nd4j.linalg.cpu.nativecpu.CpuBackend - org.nd4j.linalg.cpu.nativecpu.CpuBackend + org.nd4j.linalg.cpu.nativecpu.CpuBackend + + org.nd4j.linalg.cpu.nativecpu.CpuBackend + + + + com.google.code.findbugs + * + + + diff --git a/nd4j/nd4j-common-tests/src/main/java/org/nd4j/AbstractAssertTestsClass.java b/nd4j/nd4j-common-tests/src/main/java/org/nd4j/AbstractAssertTestsClass.java new file mode 100644 index 000000000..f229069ae --- /dev/null +++ b/nd4j/nd4j-common-tests/src/main/java/org/nd4j/AbstractAssertTestsClass.java @@ -0,0 +1,82 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j; + +import lombok.extern.slf4j.Slf4j; +import org.junit.Test; +import org.reflections.Reflections; +import org.reflections.scanners.MethodAnnotationsScanner; +import org.reflections.util.ClasspathHelper; +import org.reflections.util.ConfigurationBuilder; + +import java.lang.reflect.Method; +import java.util.*; + +import static org.junit.Assert.assertEquals; + +/** + * This class checks that all test classes (i.e., anything with one or more methods annotated with @Test) + * extends BaseDl4jTest - either directly or indirectly. 
+ * Other than a small set of exceptions, all tests must extend this + * + * @author Alex Black + * @author Alexander Stoyakin + */ +@Slf4j +public abstract class AbstractAssertTestsClass extends BaseND4JTest { + + protected abstract Set> getExclusions(); + + protected abstract String getPackageName(); + + protected abstract Class getBaseClass(); + + @Override + public long getTimeoutMilliseconds() { + return 240000L; + } + + @Test + public void checkTestClasses(){ + + Reflections reflections = new Reflections(new ConfigurationBuilder() + .setUrls(ClasspathHelper.forPackage(getPackageName())) + .setScanners(new MethodAnnotationsScanner())); + Set methods = reflections.getMethodsAnnotatedWith(Test.class); + Set> s = new HashSet<>(); + for(Method m : methods){ + s.add(m.getDeclaringClass()); + } + + List> l = new ArrayList<>(s); + Collections.sort(l, new Comparator>() { + @Override + public int compare(Class aClass, Class t1) { + return aClass.getName().compareTo(t1.getName()); + } + }); + + int count = 0; + for(Class c : l){ + if(!getBaseClass().isAssignableFrom(c) && !getExclusions().contains(c)){ + log.error("Test {} does not extend {} (directly or indirectly). All tests must extend this class for proper memory tracking and timeouts", + c, getBaseClass()); + count++; + } + } + assertEquals("Number of tests not extending BaseND4JTest", 0, count); + } +} diff --git a/nd4j/nd4j-jdbc/nd4j-jdbc-mysql/pom.xml b/nd4j/nd4j-jdbc/nd4j-jdbc-mysql/pom.xml index 0ece4c8b0..e640ed219 100644 --- a/nd4j/nd4j-jdbc/nd4j-jdbc-mysql/pom.xml +++ b/nd4j/nd4j-jdbc/nd4j-jdbc-mysql/pom.xml @@ -15,7 +15,8 @@ ~ SPDX-License-Identifier: Apache-2.0 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~--> - + 4.0.0 @@ -34,9 +35,9 @@ ${project.version} - junit - junit - test + junit + junit + test diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-rocksdb-storage/pom.xml b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-rocksdb-storage/pom.xml index 619af0d7b..3ba5a156a 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-rocksdb-storage/pom.xml +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-rocksdb-storage/pom.xml @@ -42,9 +42,9 @@ - junit - junit - test + junit + junit + test diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/pom.xml b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/pom.xml index 4ed7c8b7b..7c2783904 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/pom.xml +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-status/pom.xml @@ -49,9 +49,9 @@ - junit - junit - test + junit + junit + test diff --git a/nd4j/nd4j-remote/nd4j-json-server/pom.xml b/nd4j/nd4j-remote/nd4j-json-server/pom.xml index d7729f179..5537216ca 100644 --- a/nd4j/nd4j-remote/nd4j-json-server/pom.xml +++ b/nd4j/nd4j-remote/nd4j-json-server/pom.xml @@ -1,185 +1,185 @@ - 4.0.0 - jar + xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> + 4.0.0 + jar - - org.nd4j - nd4j-remote - 1.0.0-SNAPSHOT - + + org.nd4j + nd4j-remote + 1.0.0-SNAPSHOT + - nd4j-json-server - nd4j-json-server + nd4j-json-server + nd4j-json-server - - UTF-8 - 1.7 - 1.7 - + + UTF-8 + 1.7 + 1.7 + - - - junit - junit - test - - - - org.nd4j - nd4j-json-client - ${project.version} - - - - org.slf4j - slf4j-api - - - - org.nd4j - nd4j-api - ${project.version} - - - - org.glassfish.jersey.core - jersey-client - ${jersey.version} - - - - org.glassfish.jersey.core - jersey-server - 
${jersey.version} - - - - org.eclipse.jetty - jetty-server - 9.4.19.v20190610 - - - - org.eclipse.jetty - jetty-servlet - 9.4.19.v20190610 - - - - org.glassfish.jersey.inject - jersey-hk2 - ${jersey.version} - - - - org.glassfish.jersey.media - jersey-media-json-processing - ${jersey.version} - - - - org.glassfish.jersey.containers - jersey-container-servlet-core - ${jersey.version} - - - - ch.qos.logback - logback-core - ${logback.version} - test - - - - ch.qos.logback - logback-classic - ${logback.version} - test - - - - javax.xml.bind - jaxb-api - 2.3.0 - - - - com.sun.xml.bind - jaxb-impl - 2.3.0 - - - - com.sun.xml.bind - jaxb-core - 2.3.0 - - - - javax.activation - activation - 1.1 - - - - com.google.code.gson - gson - ${gson.version} - test - - - - - org.nd4j - nd4j-common-tests - ${project.version} - test - - - - - - - org.apache.maven.plugins - maven-compiler-plugin - - ${maven.compiler.source} - ${maven.compiler.target} - - - - - - - - nd4j-tests-cpu - + - org.nd4j - nd4j-native - ${project.version} - test + junit + junit + test - - - - nd4j-tests-cuda - - org.nd4j - nd4j-cuda-10.2 - ${project.version} - test + org.nd4j + nd4j-json-client + ${project.version} - - - - testresources - - + + org.slf4j + slf4j-api + + + + org.nd4j + nd4j-api + ${project.version} + + + + org.glassfish.jersey.core + jersey-client + ${jersey.version} + + + + org.glassfish.jersey.core + jersey-server + ${jersey.version} + + + + org.eclipse.jetty + jetty-server + 9.4.19.v20190610 + + + + org.eclipse.jetty + jetty-servlet + 9.4.19.v20190610 + + + + org.glassfish.jersey.inject + jersey-hk2 + ${jersey.version} + + + + org.glassfish.jersey.media + jersey-media-json-processing + ${jersey.version} + + + + org.glassfish.jersey.containers + jersey-container-servlet-core + ${jersey.version} + + + + ch.qos.logback + logback-core + ${logback.version} + test + + + + ch.qos.logback + logback-classic + ${logback.version} + test + + + + javax.xml.bind + jaxb-api + 2.3.0 + + + + com.sun.xml.bind + jaxb-impl + 2.3.0 + + + + com.sun.xml.bind + jaxb-core + 2.3.0 + + + + javax.activation + activation + 1.1 + + + + com.google.code.gson + gson + ${gson.version} + test + + + + + org.nd4j + nd4j-common-tests + ${project.version} + test + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + + ${maven.compiler.source} + ${maven.compiler.target} + + + + + + + + nd4j-tests-cpu + + + org.nd4j + nd4j-native + ${project.version} + test + + + + + + nd4j-tests-cuda + + + org.nd4j + nd4j-cuda-10.2 + ${project.version} + test + + + + + + testresources + + diff --git a/nd4j/nd4j-serde/nd4j-aeron/pom.xml b/nd4j/nd4j-serde/nd4j-aeron/pom.xml index 827afb23a..c94bf86af 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/pom.xml +++ b/nd4j/nd4j-serde/nd4j-aeron/pom.xml @@ -16,179 +16,186 @@ - 4.0.0 + 4.0.0 - org.nd4j - nd4j-aeron - jar - - nd4j-aeron - - org.nd4j - nd4j-serde - 1.0.0-SNAPSHOT - - - 1.8 - 1.8 - 1.5.4 - 1.4.0 - UTF-8 - + nd4j-aeron + jar - - - jdk9 - - 1.9 - - - 8 - - - - testresources - + nd4j-aeron - - nd4j-tests-cpu - - false - - - - org.nd4j - nd4j-native - ${project.version} - - - - - - org.apache.maven.plugins - maven-surefire-plugin - - - ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/ - - src/test/java - - *.java - **/*.java - **/Test*.java - **/*Test.java - **/*TestCase.java - - junit:junit - - org.nd4j.linalg.cpu.nativecpu.CpuBackend - org.nd4j.linalg.cpu.nativecpu.CpuBackend - - - -Ddtype=float -Dfile.encoding=UTF-8 -Xmx8g - - - - - + + + jdk9 + + 1.9 + + + 8 + + + + testresources + - - 
nd4j-tests-cuda - - false - - - - org.nd4j - nd4j-cuda-10.2 - ${project.version} - - - - - - org.apache.maven.plugins - maven-surefire-plugin + + nd4j-tests-cpu + + false + - - org.apache.maven.surefire - surefire-junit47 - 2.19.1 - + + org.nd4j + nd4j-native + ${project.version} + - - - ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cuda/blas/ - - src/test/java - - *.java - **/*.java - **/Test*.java - **/*Test.java - **/*TestCase.java - - junit:junit - - org.nd4j.linalg.jcublas.JCublasBackend - org.nd4j.linalg.jcublas.JCublasBackend - - - -Ddtype=float -Xmx6g - - - - - - + + + + org.apache.maven.plugins + maven-surefire-plugin + + + ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/ + + + src/test/java + + *.java + **/*.java + **/Test*.java + **/*Test.java + **/*TestCase.java + + junit:junit + + org.nd4j.linalg.cpu.nativecpu.CpuBackend + + org.nd4j.linalg.cpu.nativecpu.CpuBackend + + + + -Ddtype=float -Dfile.encoding=UTF-8 -Xmx8g + + + + + + + + nd4j-tests-cuda + + false + + + + org.nd4j + nd4j-cuda-10.2 + ${project.version} + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + + org.apache.maven.surefire + surefire-junit47 + 2.19.1 + + + + + + ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cuda/blas/ + + + src/test/java + + *.java + **/*.java + **/Test*.java + **/*Test.java + **/*TestCase.java + + junit:junit + + org.nd4j.linalg.jcublas.JCublasBackend + + org.nd4j.linalg.jcublas.JCublasBackend + + + + -Ddtype=float -Xmx6g + + + + + + - - - org.nd4j - nd4j-api - ${project.version} - - - io.aeron - aeron-all - ${aeron.version} - - - junit - junit - test - + + + org.nd4j + nd4j-api + ${project.version} + + + io.aeron + aeron-all + ${aeron.version} + + + junit + junit + test + - - ch.qos.logback - logback-classic - ${logback.version} - test - + + ch.qos.logback + logback-classic + ${logback.version} + test + - - ch.qos.logback - logback-core - ${logback.version} - test - + + ch.qos.logback + logback-core + ${logback.version} + test + - - org.nd4j - nd4j-common-tests - ${project.version} - test - - + + org.nd4j + nd4j-common-tests + ${project.version} + test + + diff --git a/nd4j/nd4j-serde/nd4j-arrow/pom.xml b/nd4j/nd4j-serde/nd4j-arrow/pom.xml index 3a768c1a5..69879e965 100644 --- a/nd4j/nd4j-serde/nd4j-arrow/pom.xml +++ b/nd4j/nd4j-serde/nd4j-arrow/pom.xml @@ -88,7 +88,8 @@ maven-surefire-plugin - ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/ + ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/ + src/test/java @@ -100,8 +101,10 @@ junit:junit - org.nd4j.linalg.cpu.nativecpu.CpuBackend - org.nd4j.linalg.cpu.nativecpu.CpuBackend + org.nd4j.linalg.cpu.nativecpu.CpuBackend + + org.nd4j.linalg.cpu.nativecpu.CpuBackend + - + nd4j-camel-routes org.nd4j diff --git a/nd4j/nd4j-serde/nd4j-gson/pom.xml b/nd4j/nd4j-serde/nd4j-gson/pom.xml index f488bfde5..60de01b6e 100644 --- a/nd4j/nd4j-serde/nd4j-gson/pom.xml +++ b/nd4j/nd4j-serde/nd4j-gson/pom.xml @@ -15,7 +15,8 @@ ~ SPDX-License-Identifier: Apache-2.0 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~--> - + nd4j-serde org.nd4j @@ -41,9 +42,9 @@ - junit - junit - test + junit + junit + test @@ -79,7 +80,8 @@ maven-surefire-plugin - ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/ + ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/ + src/test/java @@ -91,8 +93,10 @@ junit:junit - org.nd4j.linalg.cpu.nativecpu.CpuBackend - org.nd4j.linalg.cpu.nativecpu.CpuBackend + 
org.nd4j.linalg.cpu.nativecpu.CpuBackend + + org.nd4j.linalg.cpu.nativecpu.CpuBackend + - + nd4j-serde org.nd4j @@ -101,17 +102,17 @@ ${spark.version} provided - - com.google.code.findbugs - jsr305 - + + com.google.code.findbugs + jsr305 + - junit - junit - test + junit + junit + test @@ -147,7 +148,8 @@ maven-surefire-plugin - ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/ + ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/ + src/test/java @@ -159,8 +161,10 @@ junit:junit - org.nd4j.linalg.cpu.nativecpu.CpuBackend - org.nd4j.linalg.cpu.nativecpu.CpuBackend + org.nd4j.linalg.cpu.nativecpu.CpuBackend + + org.nd4j.linalg.cpu.nativecpu.CpuBackend + com.github.oshi diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/objdetect/BoundingBoxesDeserializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/objdetect/BoundingBoxesDeserializer.java new file mode 100644 index 000000000..dec2c6c33 --- /dev/null +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/objdetect/BoundingBoxesDeserializer.java @@ -0,0 +1,65 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.deeplearning4j.nn.conf.layers.objdetect;
+
+import org.nd4j.linalg.api.buffer.DataBuffer;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.serde.jackson.shaded.NDArrayTextDeSerializer;
+import org.nd4j.shade.jackson.core.JsonParser;
+import org.nd4j.shade.jackson.core.JsonProcessingException;
+import org.nd4j.shade.jackson.databind.DeserializationContext;
+import org.nd4j.shade.jackson.databind.JsonDeserializer;
+import org.nd4j.shade.jackson.databind.JsonNode;
+
+import java.io.IOException;
+
+/**
+ * Custom deserializer to handle change in format between beta6 (and earlier) and later versions
+ *
+ * @author Alex Black
+ */
+public class BoundingBoxesDeserializer extends JsonDeserializer<INDArray> {
+    @Override
+    public INDArray deserialize(JsonParser jp, DeserializationContext deserializationContext) throws IOException, JsonProcessingException {
+        JsonNode node = jp.getCodec().readTree(jp);
+        if(node.has("dataBuffer")){
+            //Must be legacy format serialization
+            JsonNode arr = node.get("dataBuffer");
+            int rank = node.get("rankField").asInt();
+            int numElements = node.get("numElements").asInt();
+            int offset = node.get("offsetField").asInt();
+            JsonNode shape = node.get("shapeField");
+            JsonNode stride = node.get("strideField");
+            int[] shapeArr = new int[rank];
+            int[] strideArr = new int[rank];
+            DataBuffer buff = Nd4j.createBuffer(numElements);
+            for (int i = 0; i < numElements; i++) {
+                buff.put(i, arr.get(i).asDouble());
+            }
+
+            String ordering = node.get("orderingField").asText();
+            for (int i = 0; i < rank; i++) {
+                shapeArr[i] = shape.get(i).asInt();
+                strideArr[i] = stride.get(i).asInt();
+            }
+
+            return Nd4j.create(buff, shapeArr, strideArr, offset, ordering.charAt(0));
+        }
+        //Standard/new format
+        return new NDArrayTextDeSerializer().deserialize(node);
+    }
+}
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/objdetect/Yolo2OutputLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/objdetect/Yolo2OutputLayer.java
index 32f5627d5..24bda07f6 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/objdetect/Yolo2OutputLayer.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/layers/objdetect/Yolo2OutputLayer.java
@@ -34,10 +34,9 @@ import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.learning.regularization.Regularization;
 import org.nd4j.linalg.lossfunctions.ILossFunction;
 import org.nd4j.linalg.lossfunctions.impl.LossL2;
+import org.nd4j.serde.jackson.shaded.NDArrayTextSerializer;
 import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
 import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;
-import org.nd4j.shade.serde.jackson.VectorDeSerializer;
-import org.nd4j.shade.serde.jackson.VectorSerializer;

 import java.util.Arrays;
 import java.util.Collection;
@@ -77,8 +76,8 @@ public class Yolo2OutputLayer extends org.deeplearning4j.nn.conf.layers.Layer {
     private double lambdaNoObj;
     private ILossFunction lossPositionScale;
     private ILossFunction lossClassPredictions;
-    @JsonSerialize(using = VectorSerializer.class)
-    @JsonDeserialize(using = VectorDeSerializer.class)
+    @JsonSerialize(using = NDArrayTextSerializer.class)
+    @JsonDeserialize(using = BoundingBoxesDeserializer.class)
     private INDArray boundingBoxes;

     private Yolo2OutputLayer() {
diff --git a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/pom.xml b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/pom.xml
index 282867e7e..44b868ae6 100644
--- a/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/pom.xml
+++ b/deeplearning4j/deeplearning4j-ui-parent/deeplearning4j-ui/pom.xml
@@ -44,11 +44,6 @@
             deeplearning4j-nlp
             ${project.version}
-
-            org.nd4j
-            nd4j-jackson
-            ${nd4j.version}
-
             junit
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/serde/jackson/shaded/NDArrayTextDeSerializer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/serde/jackson/shaded/NDArrayTextDeSerializer.java
index d50143e77..11761a00d 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/serde/jackson/shaded/NDArrayTextDeSerializer.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/serde/jackson/shaded/NDArrayTextDeSerializer.java
@@ -36,6 +36,10 @@ public class NDArrayTextDeSerializer extends JsonDeserializer<INDArray> {
     @Override
     public INDArray deserialize(JsonParser jp, DeserializationContext deserializationContext) throws IOException {
         JsonNode n = jp.getCodec().readTree(jp);
+        return deserialize(n);
+    }
+
+    public INDArray deserialize(JsonNode n){
         //First: check for backward compatibility (RowVectorSerializer/Deserializer)
         if(!n.has("dataType")){
diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml
index 371386898..c181d4328 100644
--- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml
@@ -274,12 +274,6 @@
             junit
             test
-
-            org.nd4j
-            nd4j-jackson
-            ${project.version}
-            test
-
             org.nd4j
             nd4j-api
diff --git a/nd4j/nd4j-backends/nd4j-tests/pom.xml b/nd4j/nd4j-backends/nd4j-tests/pom.xml
index 8bfe85ea6..c621330d4 100644
--- a/nd4j/nd4j-backends/nd4j-tests/pom.xml
+++ b/nd4j/nd4j-backends/nd4j-tests/pom.xml
@@ -76,12 +76,6 @@
             ${project.version}
-
-            org.nd4j
-            nd4j-jackson
-            ${project.version}
-
-
             org.nd4j
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/serde/jackson/NdArraySerializerTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/serde/jackson/NdArraySerializerTest.java
deleted file mode 100644
index 910f535d0..000000000
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/serde/jackson/NdArraySerializerTest.java
+++ /dev/null
@@ -1,72 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.nd4j.serde.jackson;
-
-import org.junit.BeforeClass;
-import org.junit.Test;
-import org.nd4j.linalg.BaseNd4jTest;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.linalg.factory.Nd4jBackend;
-import org.nd4j.shade.jackson.databind.ObjectMapper;
-import org.nd4j.shade.jackson.databind.module.SimpleModule;
-import org.nd4j.shade.serde.jackson.shaded.NDArrayDeSerializer;
-import org.nd4j.shade.serde.jackson.shaded.NDArraySerializer;
-
-import static org.junit.Assert.assertEquals;
-
-/**
- * Created by agibsonccc on 6/23/16.
- */
-public class NdArraySerializerTest extends BaseNd4jTest {
-    private static ObjectMapper objectMapper;
-
-    public NdArraySerializerTest(Nd4jBackend backend) {
-        super(backend);
-    }
-
-    @Override
-    public char ordering() {
-        return 'c';
-    }
-
-
-    @BeforeClass
-    public static void beforeClass() {
-        objectMapper = objectMapper();
-
-    }
-
-
-    @Test
-    public void testSerde() throws Exception {
-        String json = objectMapper.writeValueAsString(Nd4j.create(2, 2));
-        INDArray assertion = Nd4j.create(2, 2);
-        INDArray test = objectMapper.readValue(json, INDArray.class);
-        assertEquals(assertion, test);
-    }
-
-    private static ObjectMapper objectMapper() {
-        ObjectMapper mapper = new ObjectMapper();
-        SimpleModule nd4j = new SimpleModule("nd4j");
-        nd4j.addDeserializer(INDArray.class, new NDArrayDeSerializer());
-        nd4j.addSerializer(INDArray.class, new NDArraySerializer());
-        mapper.registerModule(nd4j);
-        return mapper;
-
-    }
-}
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/serde/jackson/VectorSerializeTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/serde/jackson/VectorSerializeTest.java
deleted file mode 100644
index f046cb97f..000000000
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/serde/jackson/VectorSerializeTest.java
+++ /dev/null
@@ -1,72 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
-
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.nd4j.serde.jackson;
-
-import org.junit.BeforeClass;
-import org.junit.Test;
-import org.nd4j.linalg.BaseNd4jTest;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.linalg.factory.Nd4jBackend;
-import org.nd4j.shade.jackson.databind.ObjectMapper;
-import org.nd4j.shade.jackson.databind.module.SimpleModule;
-import org.nd4j.shade.serde.jackson.VectorDeSerializer;
-import org.nd4j.shade.serde.jackson.VectorSerializer;
-
-import static org.junit.Assert.assertEquals;
-
-/**
- * Created by agibsonccc on 6/23/16.
- */
-public class VectorSerializeTest extends BaseNd4jTest {
-    private static ObjectMapper objectMapper;
-
-    public VectorSerializeTest(Nd4jBackend backend) {
-        super(backend);
-    }
-
-    @Override
-    public char ordering() {
-        return 'c';
-    }
-
-    @BeforeClass
-    public static void beforeClass() {
-        objectMapper = objectMapper();
-    }
-
-
-
-    @Test
-    public void testSerde() throws Exception {
-        String json = objectMapper.writeValueAsString(Nd4j.create(2, 2));
-        INDArray assertion = Nd4j.create(2, 2);
-        INDArray test = objectMapper.readValue(json, INDArray.class);
-        assertEquals(assertion, test);
-    }
-
-
-    private static ObjectMapper objectMapper() {
-        ObjectMapper mapper = new ObjectMapper();
-        SimpleModule nd4j = new SimpleModule("nd4j");
-        nd4j.addDeserializer(INDArray.class, new VectorDeSerializer());
-        nd4j.addSerializer(INDArray.class, new VectorSerializer());
-        mapper.registerModule(nd4j);
-        return mapper;
-
-    }
-}
diff --git a/nd4j/nd4j-uberjar/pom.xml b/nd4j/nd4j-uberjar/pom.xml
index 84f1c0d4a..bb0d45b5e 100644
--- a/nd4j/nd4j-uberjar/pom.xml
+++ b/nd4j/nd4j-uberjar/pom.xml
@@ -180,11 +180,6 @@
             nd4j-aeron
             ${project.version}
-
-            org.nd4j
-            nd4j-jackson
-            ${project.version}
-
             org.nd4j
             nd4j-kryo_2.11
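
Note (illustration only, not part of the patch series): the two deleted round-trip tests above have no direct replacement in these hunks. Below is a minimal sketch of the equivalent round-trip using the surviving shaded serde classes that this patch wires into Yolo2OutputLayer. The class name NDArrayTextSerdeSketch and the main-method wrapper are hypothetical; the registration pattern mirrors the deleted tests' objectMapper() helper, and the imports match those used in the hunks above.

package org.nd4j.serde.jackson;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.serde.jackson.shaded.NDArrayTextDeSerializer;
import org.nd4j.serde.jackson.shaded.NDArrayTextSerializer;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.databind.module.SimpleModule;

public class NDArrayTextSerdeSketch {

    public static void main(String[] args) throws Exception {
        // Same registration pattern the deleted tests used for the old
        // Vector*/NDArray* serde classes, with the shaded text-format pair:
        ObjectMapper mapper = new ObjectMapper();
        SimpleModule nd4j = new SimpleModule("nd4j");
        nd4j.addDeserializer(INDArray.class, new NDArrayTextDeSerializer());
        nd4j.addSerializer(INDArray.class, new NDArrayTextSerializer());
        mapper.registerModule(nd4j);

        // Round-trip: per the NDArrayTextDeSerializer hunk above, the
        // deserializer checks for a "dataType" field and falls back to the
        // legacy (RowVectorSerializer/Deserializer) path when it is absent.
        INDArray in = Nd4j.create(2, 2);
        String json = mapper.writeValueAsString(in);
        INDArray out = mapper.readValue(json, INDArray.class);
        System.out.println(in.equals(out));   // expected: true
    }
}

The same serializer/deserializer pairing is what the @JsonSerialize/@JsonDeserialize annotations on the boundingBoxes field now reference, so a round-trip along these lines should cover both the new text format and the pre-beta6 fallback path.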