diff --git a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/ValidateCuDNN.java b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/ValidateCuDNN.java index 1d393aaf4..eed9f2efa 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/ValidateCuDNN.java +++ b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/ValidateCuDNN.java @@ -248,7 +248,7 @@ public class ValidateCuDNN extends BaseDL4JTest { Nd4j.getRandom().setSeed(12345); INDArray features = Nd4j.rand(fShape); INDArray labels = Nd4j.rand(lShape); - labels = Nd4j.exec(new IsMax(labels, 1)); + labels = Nd4j.exec(new IsMax(labels, 1))[0].castTo(features.dataType()); List testCaseList = new ArrayList<>(); @@ -256,7 +256,7 @@ public class ValidateCuDNN extends BaseDL4JTest { for (int i = 0; i < 6; i++) { INDArray f = Nd4j.rand(fShape); INDArray l = Nd4j.rand(lShape); - Nd4j.exec(new IsMax(l, 1))[0]; + l = Nd4j.exec(new IsMax(l, 1))[0].castTo(features.dataType()); dataSets.add(new DataSet(f, l)); } DataSetIterator iter = new ExistingDataSetIterator(dataSets); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TextPipelineTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TextPipelineTest.java index d12b9626c..2cae12e61 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TextPipelineTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/test/java/org/deeplearning4j/spark/text/TextPipelineTest.java @@ -25,7 +25,6 @@ import org.deeplearning4j.models.word2vec.Huffman; import org.deeplearning4j.models.word2vec.VocabWord; import org.deeplearning4j.models.word2vec.wordstore.VocabCache; import org.deeplearning4j.spark.models.embeddings.word2vec.FirstIterationFunction; -import org.deeplearning4j.spark.models.embeddings.word2vec.FirstIterationFunctionAdapter; import org.deeplearning4j.spark.models.embeddings.word2vec.MapToPairFunction; import org.deeplearning4j.spark.models.embeddings.word2vec.Word2Vec; import org.deeplearning4j.spark.text.functions.CountCumSum; @@ -470,11 +469,11 @@ public class TextPipelineTest extends BaseSparkTest { Iterator, Long>> iterator = vocabWordListSentenceCumSumRDD.collect().iterator(); - FirstIterationFunctionAdapter firstIterationFunction = new FirstIterationFunctionAdapter( + FirstIterationFunction firstIterationFunction = new FirstIterationFunction( word2vecVarMapBroadcast, expTableBroadcast, pipeline.getBroadCastVocabCache()); - Iterable> ret = firstIterationFunction.call(iterator); - assertTrue(ret.iterator().hasNext()); + Iterator> ret = firstIterationFunction.call(iterator); + assertTrue(ret.hasNext()); } @Test diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml index 8bc852fbd..8b31872c5 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/pom.xml @@ -70,6 +70,13 @@ deeplearning4j-play_2.11 ${deeplearning4j.version} test + + + + net.jpountz.lz4 + lz4 + + diff --git a/libnd4j/blas/NDArray.h b/libnd4j/blas/NDArray.h index 3ef3716b3..3a57fc92b 100644 --- a/libnd4j/blas/NDArray.h +++ b/libnd4j/blas/NDArray.h @@ -273,9 +273,11 @@ namespace nd4j { * @param writeList * @param readList */ + // TODO: it would be nice to have 
NDArray::registerSpecialUse signature that accepts something else beyond initializer_list static void registerSpecialUse(const std::initializer_list& writeList, const std::initializer_list& readList); static void prepareSpecialUse(const std::initializer_list& writeList, const std::initializer_list& readList, bool synchronizeWritables = false); + // TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list static void registerPrimaryUse(const std::initializer_list& writeList, const std::initializer_list& readList); static void preparePrimaryUse(const std::initializer_list& writeList, const std::initializer_list& readList, bool synchronizeWritables = false); diff --git a/libnd4j/blas/cuda/NativeOps.cu b/libnd4j/blas/cuda/NativeOps.cu index 7e74c3237..a29613b61 100755 --- a/libnd4j/blas/cuda/NativeOps.cu +++ b/libnd4j/blas/cuda/NativeOps.cu @@ -931,13 +931,13 @@ void initializeFunctions(Nd4jPointer *functions) { Nd4jPointer mallocHost(Nd4jLong memorySize, int flags) { Nd4jPointer pointer; // cudaHostAllocMapped |cudaHostAllocPortable - auto res = cudaHostAlloc(reinterpret_cast(&pointer), memorySize, cudaHostAllocDefault); + auto res = cudaHostAlloc(reinterpret_cast(&pointer), memorySize + 8, cudaHostAllocDefault); if (res != 0) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(res); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaHostAlloc failed"); } - return pointer; + return reinterpret_cast(pointer); } /** @@ -950,13 +950,13 @@ Nd4jPointer mallocHost(Nd4jLong memorySize, int flags) { */ Nd4jPointer mallocDevice(Nd4jLong memorySize, int deviceId, int flags) { Nd4jPointer pointer; - auto res = cudaMalloc(reinterpret_cast(&pointer), memorySize); + auto res = cudaMalloc(reinterpret_cast(&pointer), memorySize + 8); if (res != 0) { nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(res); nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaMalloc failed"); } - return pointer; + return reinterpret_cast(pointer); } /** diff --git a/libnd4j/include/loops/cuda/transform/transform_bool.cu b/libnd4j/include/loops/cuda/transform/transform_bool.cu index a01221cfa..52e6b4a10 100644 --- a/libnd4j/include/loops/cuda/transform/transform_bool.cu +++ b/libnd4j/include/loops/cuda/transform/transform_bool.cu @@ -96,13 +96,13 @@ namespace functions { } else { if(vx == vz) { - for (Nd4jLong i = tid; i < length; i+= gridDim.x * blockDim.x) { + for (Nd4jLong i = tid; i < length; i+= totalThreads) { auto xOffset = shape::getIndexOffset(i, xShapeInfo, length); z[xOffset] = OpType::op(x[xOffset], params); } } else { - for (Nd4jLong i = tid; i < length; i+= gridDim.x * blockDim.x) { + for (Nd4jLong i = tid; i < length; i+= totalThreads) { auto xOffset = shape::getIndexOffset(i, xShapeInfo, length); auto zOffset = shape::getIndexOffset(i, zShapeInfo, length); z[zOffset] = OpType::op(x[xOffset], params); diff --git a/libnd4j/include/loops/cuda/transform/transform_same.cu b/libnd4j/include/loops/cuda/transform/transform_same.cu index a0d137d64..6c533ac3a 100644 --- a/libnd4j/include/loops/cuda/transform/transform_same.cu +++ b/libnd4j/include/loops/cuda/transform/transform_same.cu @@ -94,13 +94,13 @@ namespace functions { } else { if(vx == vz) { - for (Nd4jLong i = tid; i < length; i+= gridDim.x * blockDim.x) { + for (Nd4jLong i = tid; i < length; i+= totalThreads) { auto xOffset = shape::getIndexOffset(i, xShapeInfo, length); z[xOffset] = OpType::op(x[xOffset], 
params); } } else { - for (Nd4jLong i = tid; i < length; i+= gridDim.x * blockDim.x) { + for (Nd4jLong i = tid; i < length; i+= totalThreads) { auto xOffset = shape::getIndexOffset(i, xShapeInfo, length); auto zOffset = shape::getIndexOffset(i, zShapeInfo, length); z[zOffset] = OpType::op(x[xOffset], params); diff --git a/libnd4j/include/loops/cuda/transform/transform_strict.cu b/libnd4j/include/loops/cuda/transform/transform_strict.cu index 10385812d..a0989b0e6 100644 --- a/libnd4j/include/loops/cuda/transform/transform_strict.cu +++ b/libnd4j/include/loops/cuda/transform/transform_strict.cu @@ -96,13 +96,13 @@ namespace functions { } else { if(vx == vz) { - for (Nd4jLong i = tid; i < length; i+= gridDim.x * blockDim.x) { + for (Nd4jLong i = tid; i < length; i+= totalThreads) { auto xOffset = shape::getIndexOffset(i, xShapeInfo, length); z[xOffset] = OpType::op(x[xOffset], params); } } else { - for (Nd4jLong i = tid; i < length; i+= gridDim.x * blockDim.x) { + for (Nd4jLong i = tid; i < length; i+= totalThreads) { auto xOffset = shape::getIndexOffset(i, xShapeInfo, length); auto zOffset = shape::getIndexOffset(i, zShapeInfo, length); z[zOffset] = OpType::op(x[xOffset], params); diff --git a/libnd4j/include/loops/legacy_ops.h b/libnd4j/include/loops/legacy_ops.h index b0d891287..c298dde3a 100644 --- a/libnd4j/include/loops/legacy_ops.h +++ b/libnd4j/include/loops/legacy_ops.h @@ -116,7 +116,6 @@ #define TRANSFORM_STRICT_OPS \ - (3, ELUDerivative), \ (4, TanhDerivative), \ (5, HardTanhDerivative), \ (6, SigmoidDerivative), \ @@ -148,7 +147,6 @@ (32, ATan), \ (33, HardTanh), \ (34, SoftSign), \ - (35, ELU), \ (36, HardSigmoid), \ (37, RationalTanh) ,\ (38, RectifiedTanh) ,\ @@ -211,6 +209,8 @@ (4, ReverseDivide),\ (5, ReverseSubtract),\ (6, MaxPairwise),\ + (7, ELU), \ + (8, ELUDerivative), \ (13, MinPairwise),\ (14, CopyPws),\ (15, Mod),\ diff --git a/libnd4j/include/ops/declarable/generic/activations/elu.cpp b/libnd4j/include/ops/declarable/generic/activations/elu.cpp index 03c0cf834..03670ddab 100644 --- a/libnd4j/include/ops/declarable/generic/activations/elu.cpp +++ b/libnd4j/include/ops/declarable/generic/activations/elu.cpp @@ -25,12 +25,14 @@ #include namespace nd4j { namespace ops { - CONFIGURABLE_OP_IMPL(elu, 1, 1, true, 0, 0) { + CONFIGURABLE_OP_IMPL(elu, 1, 1, true, -2, 0) { + auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); - input->applyTransform(nd4j::transform::ELU, output, nullptr); - STORE_RESULT(output); + const auto alpha = block.numT() > 0 ? T_ARG(0) : 1.f; + + input->applyScalar(nd4j::scalar::ELU, alpha, output); return Status::OK(); } @@ -41,14 +43,18 @@ namespace nd4j { ->setAllowedOutputTypes(0, {ALL_FLOATS}); } - CONFIGURABLE_OP_IMPL(elu_bp, 2, 1, true, 0, 0) { + CONFIGURABLE_OP_IMPL(elu_bp, 2, 1, true, -2, 0) { + auto input = INPUT_VARIABLE(0); auto epsilon = INPUT_VARIABLE(1); - auto z = OUTPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + const auto alpha = block.numT() > 0 ? 
T_ARG(0) : 1.f; + + // input->applyPairwiseTransform(pairwise::ELUDerivativeE, epsilon, output); + helpers::eluDerivative(block.launchContext(), input, epsilon, output, alpha); - //input->applyPairwiseTransform(pairwise::ELUDerivativeE, epsilon, z, nullptr); - helpers::eluDerivative(block.launchContext(), input, epsilon, z); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/activations/lrelu.cpp b/libnd4j/include/ops/declarable/generic/activations/lrelu.cpp index 68a460b56..ef65c4822 100644 --- a/libnd4j/include/ops/declarable/generic/activations/lrelu.cpp +++ b/libnd4j/include/ops/declarable/generic/activations/lrelu.cpp @@ -25,15 +25,15 @@ #include namespace nd4j { namespace ops { - CONFIGURABLE_OP_IMPL(lrelu, 1, 1, true, 1, 0) { + CONFIGURABLE_OP_IMPL(lrelu, 1, 1, true, -2, 0) { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); - float t = block.numT() > 0 ? T_ARG(0) : 0.0f; + float alpha = block.numT() > 0 ? T_ARG(0) : 0.01f; - input->applyScalar(nd4j::scalar::LeakyRELU, t, output); + input->applyScalar(nd4j::scalar::LeakyRELU, alpha, output); STORE_RESULT(output); - + return Status::OK(); } @@ -42,15 +42,17 @@ namespace nd4j { ->setAllowedInputTypes(0, DataType::ANY) ->setAllowedOutputTypes(0, {ALL_FLOATS}); } - - CONFIGURABLE_OP_IMPL(lrelu_bp, 2, 1, true, 0, 0) { + + CONFIGURABLE_OP_IMPL(lrelu_bp, 2, 1, true, -2, 0) { auto input = INPUT_VARIABLE(0); auto epsilon = INPUT_VARIABLE(1); auto z = OUTPUT_VARIABLE(0); + float alpha = block.numT() > 0 ? T_ARG(0) : 0.01f; + //input->applyPairwiseTransform(pairwise::LRELUDerivativeE, epsilon, z, nullptr); - helpers::leakyReluDerivative(block.launchContext(), input, epsilon, z); + helpers::leakyReluDerivative(block.launchContext(), input, epsilon, z, alpha); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/headers/activations.h b/libnd4j/include/ops/declarable/headers/activations.h index ecc55351a..9d0b22198 100644 --- a/libnd4j/include/ops/declarable/headers/activations.h +++ b/libnd4j/include/ops/declarable/headers/activations.h @@ -82,8 +82,8 @@ namespace nd4j { * Math is: x < 0 ? alpha * x : x; */ #if NOT_EXCLUDED(OP_lrelu) - DECLARE_CONFIGURABLE_OP(lrelu, 1, 1, true, 0, 0); - DECLARE_CONFIGURABLE_OP(lrelu_bp, 2, 1, true, 0, 0); + DECLARE_CONFIGURABLE_OP(lrelu, 1, 1, true, -2, 0); + DECLARE_CONFIGURABLE_OP(lrelu_bp, 2, 1, true, -2, 0); #endif /** @@ -91,8 +91,8 @@ namespace nd4j { * Math is: x >= 0 ? x : exp(x) - 1; */ #if NOT_EXCLUDED(OP_elu) - DECLARE_CONFIGURABLE_OP(elu, 1, 1, true, 0, 0); - DECLARE_CONFIGURABLE_OP(elu_bp, 2, 1, true, 0, 0); + DECLARE_CONFIGURABLE_OP(elu, 1, 1, true, -2, 0); + DECLARE_CONFIGURABLE_OP(elu_bp, 2, 1, true, -2, 0); #endif /** @@ -157,7 +157,7 @@ namespace nd4j { /** * This is Concatenated RELU implementation. * What happens inside: RELU(Concat((x, -x, {-1}))) - * + * * PLEASE NOTE: Concatenation will double amount of features available in input */ #if NOT_EXCLUDED(OP_crelu) diff --git a/libnd4j/include/ops/declarable/helpers/cpu/legacy_helper.cpp b/libnd4j/include/ops/declarable/helpers/cpu/legacy_helper.cpp index d673e64bd..09cb2df2e 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/legacy_helper.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/legacy_helper.cpp @@ -81,29 +81,35 @@ namespace helpers { } template - static void leakyReluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - return x >= (T)0.f? 
y : T(0.f); + static void leakyReluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output, const float alpha) { + + const T alphaT = static_cast(alpha); + + auto functor = LAMBDA_TT(x, y, alphaT) { + return x < 0 ? alphaT * y : y; }; input->applyPairwiseLambda(epsilon, functor, output); } - void leakyReluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), leakyReluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); + void leakyReluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput, const float alpha) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), leakyReluDerivative_, (theFirst, theSecond, theOutput, alpha), FLOAT_TYPES); } template - static void eluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - return y * nd4j::math::nd4j_eluderivative(x); + static void eluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output, const float alpha) { + + const T alphaT = static_cast(alpha); + + auto functor = LAMBDA_TT(x, y, alphaT){ + return y * nd4j::math::nd4j_eluderivative(x, alphaT); }; input->applyPairwiseLambda(epsilon, functor, output); } - void eluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), eluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); + void eluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput, const float alpha) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), eluDerivative_, (theFirst, theSecond, theOutput, alpha), FLOAT_TYPES); } template diff --git a/libnd4j/include/ops/declarable/helpers/cuda/compare_elem.cu b/libnd4j/include/ops/declarable/helpers/cuda/compare_elem.cu index 54f518ad9..545d7c668 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/compare_elem.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/compare_elem.cu @@ -30,7 +30,7 @@ namespace nd4j { shared[threadIdx.x] = 0; - + // each thread will compare 2 elements: E and E+1 for (int e = tid; e < length - 1; e += blockDim.x * gridDim.x) { auto val0 = x[shape::getIndexOffset(e, xShapeInfo, length)]; auto val1 = x[shape::getIndexOffset(e+1, xShapeInfo, length)]; @@ -41,11 +41,12 @@ namespace nd4j { else v = val1 >= val0; + // store comparison result in shared memory shared[threadIdx.x] += v ? 
0 : 1; } __syncthreads(); - // aggregate sum + // aggregate sums in shared memory for (uint activeThreads = blockDim.x / 2; activeThreads > 0; activeThreads /= 2) { if (threadIdx.x < activeThreads) shared[threadIdx.x] += shared[threadIdx.x + activeThreads]; @@ -53,7 +54,7 @@ namespace nd4j { } - // store over the grid + // store over the grid if we have more than 1 block if (gridDim.x > 1) { auto tc = reinterpret_cast(reductionBuffer); @@ -96,7 +97,7 @@ namespace nd4j { } } else { - + // if we have only 1 block, we just store results right away if (threadIdx.x == 0) { auto tc = reinterpret_cast(reductionBuffer); tc[16384] = 0; diff --git a/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu b/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu index 98ab86dec..87e7c4f08 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/convolutions.cu @@ -424,7 +424,7 @@ static __global__ void avgPooling2dCuda(const void *vx, const Nd4jLong *xShapeIn } __syncthreads(); - int tid = blockIdx.x * gridDim.x + threadIdx.x; + int tid = blockIdx.x * blockDim.x + threadIdx.x; for (int index = tid; index < length; index += blockDim.x * gridDim.x) { @@ -519,7 +519,7 @@ static __global__ void pnormPooling2dCuda(const void *vx, const Nd4jLong *xShape } __syncthreads(); - int tid = blockIdx.x * gridDim.x + threadIdx.x; + int tid = blockIdx.x * blockDim.x + threadIdx.x; for (int index = tid; index < length; index += blockDim.x * gridDim.x) { @@ -610,7 +610,7 @@ static __global__ void maxPooling2dCuda(const void *vx, const Nd4jLong *xShapeIn } __syncthreads(); - int tid = blockIdx.x * gridDim.x + threadIdx.x; + int tid = blockIdx.x * blockDim.x + threadIdx.x; for (int index = tid; index < length; index += blockDim.x * gridDim.x) { @@ -908,7 +908,7 @@ __global__ static void pooling2dBPCuda(const void* vx, const Nd4jLong* xShapeInf /*** max ***/ case 0: { coord2 = hstart; - coord3 = hend; + coord3 = wstart; T max = -DataTypeUtils::max(); for (coords[2] = hstart; coords[2] < hend; coords[2] += dH) { @@ -923,8 +923,9 @@ __global__ static void pooling2dBPCuda(const void* vx, const Nd4jLong* xShapeInf } coords[2] = coord2; coords[3] = coord3; - nd4j::math::atomics::nd4j_atomicAdd(&z[shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank)], y[yOffset]); - + auto zOffset = shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank); + nd4j::math::atomics::nd4j_atomicAdd(&z[zOffset], y[yOffset]); + //z[zOffset] += y[yOffset]; } break; @@ -987,7 +988,7 @@ void ConvolutionUtils::pooling2dBP(nd4j::graph::Context& block, const NDArray& i PointersManager manager(block.launchContext(), "pooling2dBP"); - const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int threadsPerBlock = 256; const int blocksPerGrid = (gradO.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; const int sharedMem = gradO.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128; diff --git a/libnd4j/include/ops/declarable/helpers/cuda/diag.cu b/libnd4j/include/ops/declarable/helpers/cuda/diag.cu index f4dff2279..0e861b866 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/diag.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/diag.cu @@ -39,7 +39,7 @@ static __global__ void diagFunctorKernel(void* outputBuffer, Nd4jLong* outputSha } __syncthreads(); - const auto tid = blockIdx.x * gridDim.x + threadIdx.x; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; const auto step = gridDim.x * blockDim.x; for (int t = tid; t < inputLength; t 
+= step) { z[shape::getIndexOffset(t * (inputLength + 1), outputShape, outputLength)] = x[shape::getIndexOffset(t, inputShape, inputLength)]; //tX]; @@ -59,7 +59,7 @@ static __global__ void diagFunctorKernel(void* outputBuffer, Nd4jLong* outputSha } __syncthreads(); - const auto tid = blockIdx.x * gridDim.x + threadIdx.x; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; const auto step = gridDim.x * blockDim.x; Nd4jLong i = threadIdx.x * (outputLength + 1); for (int t = tid; t < outputLength && i < inputLength; t += step) { diff --git a/libnd4j/include/ops/declarable/helpers/cuda/dropout.cu b/libnd4j/include/ops/declarable/helpers/cuda/dropout.cu index 40251fb82..5b4c27bd0 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/dropout.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/dropout.cu @@ -35,9 +35,11 @@ namespace helpers { T const* input = reinterpret_cast(inputBuf); T* output = reinterpret_cast(outputBuf); + // trivial idea: loop through all elements, get independent probability for each element to be nullified for (Nd4jLong e = 0; e < inLen; ++e) { T val = nodeRng->relativeT(e, T(0.f), T(1.f)); + // if probability is ok - we're saving scaled value if (double(val) < probVal) output[shape::getIndexOffset(e, outputShape, inLen)] = T(input[shape::getIndexOffset(e, inputShape, inLen)] / probVal); } @@ -80,7 +82,7 @@ namespace helpers { std::vector dims(reduceShape->lengthOf()); reduceShape->syncToHost(); // to ensure that follows are actual bool fit = true; -// PRAGMA_OMP_PARALLEL_FOR_ARGS(firstprivate(fit)) + for( int i = 0; i < dims.size(); i++ ) { if (fit) { dims[i] = reduceShape->e(i); @@ -96,8 +98,7 @@ namespace helpers { REQUIRE_TRUE(fit, 0, "dropout: Noise shape should fit to input rank."); std::unique_ptr chunk(new NDArray('c', dims, output->dataType(), context.launchContext())); chunk->assign(1.f); - //chunk->applyRandom>(rng, nullptr, chunk.get(), &probValue); - //NativeOpExecutioner::execRandom(random::DropOutInverted, rng, chunk->buffer(), chunk->shapeInfo(), chunk->buffer(), chunk->shapeInfo(), &prob); + dropoutSimple(context.launchContext(), chunk.get(), chunk.get(), probValue, seed); // broadcast chunk to full matrix std::unique_ptr dropOutMultiplier(new NDArray(*input)); @@ -105,6 +106,7 @@ namespace helpers { *dropOutMultiplier += *chunk; + // FIXME: we could do this in one step, aren't we? 
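            // [editorial note, not part of this patch] The scaling applied by dropoutSimple above is the
            // standard "inverted dropout": an element survives with probability p = probVal and is divided
            // by p, so the expected activation is unchanged: E[out] = p * (in / p) + (1 - p) * 0 = in.
            // Per-element rule, using the kernel's own names (sketch only; dropped elements assumed zero):
            //     T val = nodeRng->relativeT(e, T(0.f), T(1.f));   // independent uniform sample in [0, 1)
            //     if (double(val) < probVal)                       // element is kept...
            //         out = T(in / probVal);                       // ...and rescaled so E[out] == in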
output->assign(*input * *dropOutMultiplier); //input->applyPairwiseTransform(pairwise::Multiply, dropOutMultiplier.get(), output, nullptr); } @@ -113,8 +115,11 @@ namespace helpers { int dropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue) { auto xType = input->dataType(); + NDArray::prepareSpecialUse({output}, {input}); BUILD_SINGLE_SELECTOR(xType, return _dropOutFunctor, (context, input, output, reduceShape, seed, probValue), FLOAT_TYPES); + + NDArray::registerSpecialUse({output}, {input}); } /////////////////////////////////// backrpopagations /////////////////////////////////////////////// @@ -136,6 +141,8 @@ namespace helpers { for (int e = tid; e < len; e += step) { const auto zOffset = shape::getIndexOffset(e, outputShape, len); + + // if probability was non-zero on FF step, we'll scale grads back if (output[zOffset] != T(0.)) output[zOffset] = T(input[shape::getIndexOffset(e, gradOutShape, len)] / probValue); @@ -143,12 +150,17 @@ namespace helpers { } template static int dropOutFunctorBP_(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue) { + // we're making additional FF run to see how probabilities played out with given seeds int res = dropOutFunctor(context, input, output, reduceShape, seed, probValue); auto stream = context.launchContext()->getCudaStream(); + NDArray::prepareSpecialUse({output}, {input, gradOut}); + if (ND4J_STATUS_OK == res) dropoutBPKernel<<<128, 256, 1024, *stream>>>(output->specialBuffer(), output->specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), probValue); + NDArray::registerSpecialUse({output}, {input, gradOut}); + return res; } @@ -239,6 +251,7 @@ namespace helpers { int res = alphaDropOutFunctor(context, input, output, reduceShape, seed, probValue, alpha, alpha1, beta); if (res == ND4J_STATUS_OK) { + // FIXME: can we make it single-loop? 
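            // [editorial sketch, not part of this patch] One possible shape for the "single-loop" version of
            // the two in-place scalings below, mirroring the LAMBDA_TT / applyPairwiseLambda pattern this
            // patch introduces in legacy_helper.cpp and legacy/relu.cu; it assumes applyPairwiseLambda
            // tolerates writing into its own target, as the chained *= calls below effectively do:
            //     const T alphaT = static_cast<T>(alpha);
            //     auto functor = LAMBDA_TT(x, y, alphaT) { return x * y * alphaT; };   // out = out * gradOut * alpha in one pass
            //     output->applyPairwiseLambda(gradOut, functor, output);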
(*output) *= alpha; (*output) *= (*gradOut); //->applyPairwiseTransform(gradOut, output, nullptr); } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/dynamic.cu b/libnd4j/include/ops/declarable/helpers/cuda/dynamic.cu index 92e5b38b4..7d520478e 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/dynamic.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/dynamic.cu @@ -43,7 +43,7 @@ namespace nd4j { } __syncthreads(); - + // we run things in blocks, 1 partition per block of threads for (Nd4jLong o = blockIdx.x; o < numOutputs; o += gridDim.x) { auto z = reinterpret_cast(vz[o]); @@ -89,9 +89,11 @@ namespace nd4j { auto x = reinterpret_cast(vx); auto indices = reinterpret_cast(vindices); + // we run things in blocks, 1 partition per block of threads for (int i = blockIdx.x; i < numOutputs; i += gridDim.x) { auto z = reinterpret_cast(vz[i]); + // each thread has own counter for partitions int outCnt = 0; for (Nd4jLong e = 0; e < iLength; e++) { @@ -145,6 +147,7 @@ namespace nd4j { tadOffsets[i] = packZ.platformOffsets(); } + // we copy pointers to device auto dOutBuffers = reinterpret_cast(pm.replicatePointer(outBuffers.data(), outBuffers.size() * sizeof(void *))); auto dOutTadShapes = reinterpret_cast(pm.replicatePointer(tadShapes.data(), tadShapes.size() * sizeof(Nd4jLong *))); auto dOutTadOffsets = reinterpret_cast(pm.replicatePointer(tadOffsets.data(), tadOffsets.size() * sizeof(Nd4jLong *))); @@ -248,6 +251,7 @@ namespace nd4j { indicesShapes[e] = indices.at(e)->getSpecialShapeInfo(); } + // copying pointers to buffers to device auto dInputBuffers = reinterpret_cast(pm.replicatePointer(inputBuffers.data(), inputSize * sizeof(void *))); auto dIndicesBuffers = reinterpret_cast(pm.replicatePointer(indicesBuffers.data(), inputSize * sizeof(void *))); auto dInputShapes = reinterpret_cast(pm.replicatePointer(inputShapes.data(), inputSize * sizeof(Nd4jLong *))); @@ -283,6 +287,7 @@ namespace nd4j { inputTadOffsets[e] = packX.platformOffsets(); } + // copying pointers to buffers to device auto dInputBuffers = reinterpret_cast(pm.replicatePointer(inputBuffers.data(), inputSize * sizeof(void *))); auto dInputTadShapes = reinterpret_cast(pm.replicatePointer(inputTadShapes.data(), inputSize * sizeof(Nd4jLong *))); auto dInputTadOffsets = reinterpret_cast(pm.replicatePointer(inputTadOffsets.data(), inputSize * sizeof(Nd4jLong *))); @@ -313,6 +318,7 @@ namespace nd4j { NDArray::registerSpecialUse({}, {indices, input}); + // TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list for (auto v:outputList) { v->tickWriteDevice(); } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/flatten.cu b/libnd4j/include/ops/declarable/helpers/cuda/flatten.cu index 9baab5b36..6a818a2cd 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/flatten.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/flatten.cu @@ -29,6 +29,7 @@ namespace nd4j { Nd4jLong xCoord[MAX_RANK]; + // each block of threads works on 1 input array for (Nd4jLong e = blockIdx.x; e < numInputs; e += gridDim.x) { auto z = reinterpret_cast(zBuffer) + offsets[e]; @@ -39,6 +40,7 @@ namespace nd4j { auto xRank = shape::rank(xShapeInfo); auto xLength = shape::length(xShapeInfo); + // each element of this input array has own place within common output array for (uint i = threadIdx.x; i < xLength; i += blockDim.x) { shape::index2coords(xRank, xShape, i, xLength, xCoord, order); auto xOffset = shape::getOffset(0, xShape, xStride, xCoord, xRank); @@ -65,6 +67,7 @@ 
namespace nd4j { hdShapes[e] = inputs[e]->specialShapeInfo(); } + // copying pointers to device auto dBuffers = (void **) pm.replicatePointer(hdBuffers.data(), inputs.size() * sizeof(void*)); auto dShapes = (Nd4jLong **)pm.replicatePointer(hdShapes.data(), inputs.size() * sizeof(Nd4jLong*)); auto dOffsets = (Nd4jLong *) pm.replicatePointer(hOffsets.data(), inputs.size() * sizeof(Nd4jLong)); @@ -76,6 +79,7 @@ namespace nd4j { } void flatten(nd4j::LaunchContext *context, std::vector &inputs, NDArray *output, char order) { + // FIXME: we want NDArrayFactory::prepareSpecialUse here eventually for (auto v:inputs) v->syncToDevice(); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/gradient.cu b/libnd4j/include/ops/declarable/helpers/cuda/gradient.cu index 86f5b4d5a..9d0e5e55b 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/gradient.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/gradient.cu @@ -26,6 +26,7 @@ namespace ops { namespace helpers { template void applyGradientDescent_(LaunchContext* context, NDArray* input, NDArray* step, double weight, NDArray* output) { + // classic one auto lambda = LAMBDA_TT(_x, _y, weight) { return _x - (_y * weight); }; diff --git a/libnd4j/include/ops/declarable/helpers/cuda/histogram.cu b/libnd4j/include/ops/declarable/helpers/cuda/histogram.cu index 52b059dad..a4bcbb311 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/histogram.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/histogram.cu @@ -44,6 +44,7 @@ namespace nd4j { X binSize = X((*max_val - *min_val) / numBins); + // nullify bins for (int e = threadIdx.x; e < numBins; e += blockDim.x) { bins[e] = (Z) 0; } @@ -53,14 +54,12 @@ namespace nd4j { int idx = int((dx[e] - *min_val) / binSize); idx = math::nd4j_max(idx, 0); //atomicMax(&idx, 0);//atomicMax(&idx, 0); idx = math::nd4j_min(idx, int(numBins - 1)); //atomicMin(&idx, int(numBins - 1)); - nd4j::math::atomics::nd4j_atomicAdd(&bins[idx], (Z)1); -// bins[idx]++; + nd4j::math::atomics::nd4j_atomicAdd(&bins[idx], (Z)1); } __syncthreads(); + // at this point all bins in shared memory are calculated, so we aggregate them now via threadfence trick // transfer shared memory to reduction memory - - if (gridDim.x > 1) { unsigned int *tc = (unsigned int *)reductionPointer; __shared__ bool amLast; diff --git a/libnd4j/include/ops/declarable/helpers/cuda/ismax.cu b/libnd4j/include/ops/declarable/helpers/cuda/ismax.cu index bc7ea1caa..a5d686dc2 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/ismax.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/ismax.cu @@ -64,6 +64,7 @@ static void ismax_(nd4j::LaunchContext * context, const NDArray* input, NDArray* auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), copy.data(), copy.size()); + // we launch legacy IndexMax op, to get indices of max values along dimension auto indexMaxArr = input->applyIndexReduce(indexreduce::IndexMax, dimensions); dim3 launchDims(256, 256, 16384); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/legacy/relu.cu b/libnd4j/include/ops/declarable/helpers/cuda/legacy/relu.cu index a0f30a116..c2dd4919d 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/legacy/relu.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/legacy/relu.cu @@ -66,29 +66,35 @@ namespace nd4j { } template - linkage void leakyReluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - return x >= (T)0.f? 
y : T(0.f); + linkage void leakyReluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output, const float alpha) { + + const T alphaT = static_cast(alpha); + + auto functor = LAMBDA_TT(x, y, alphaT) { + return x < 0 ? alphaT * y : y; }; input->applyPairwiseLambda(epsilon, functor, output); } - void leakyReluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), leakyReluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); + void leakyReluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput, const float alpha) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), leakyReluDerivative_, (theFirst, theSecond, theOutput, alpha), FLOAT_TYPES); } template - linkage void eluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) { - auto functor = LAMBDA_TT(x, y){ - return y * nd4j::math::nd4j_eluderivative(x); + linkage void eluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output, const float alpha) { + + const T alphaT = static_cast(alpha); + + auto functor = LAMBDA_TT(x, y, alphaT){ + return y * nd4j::math::nd4j_eluderivative(x, alphaT); }; input->applyPairwiseLambda(epsilon, functor, output); } - void eluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) { - BUILD_SINGLE_SELECTOR(theFirst->dataType(), eluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES); + void eluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput, const float alpha) { + BUILD_SINGLE_SELECTOR(theFirst->dataType(), eluDerivative_, (theFirst, theSecond, theOutput, alpha), FLOAT_TYPES); } template diff --git a/libnd4j/include/ops/declarable/helpers/cuda/lrn.cu b/libnd4j/include/ops/declarable/helpers/cuda/lrn.cu index 239d280dc..ca2bda4a4 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/lrn.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/lrn.cu @@ -41,12 +41,12 @@ namespace helpers { const T tbeta = static_cast(beta); const T talpha = static_cast(alpha); - + // one block of threads processes 1 example within batch for (uint i = blockIdx.x; i < numTads; i += gridDim.x) { auto x = reinterpret_cast(vx) + xTadOffsets[i]; auto z = reinterpret_cast(vz) + zTadOffsets[i]; - // load everything into shared memory + // load everything into shared memory, so we'll operate on shared memory from now on shared[threadIdx.x] = x[threadIdx.x * xEws]; __syncthreads(); @@ -94,7 +94,7 @@ namespace helpers { sharedY[threadIdx.x] = 0.f; __syncthreads(); - + // we're operating in shared memory for (int s = begin; s < end; s++) sharedY[threadIdx.x] = sharedY[threadIdx.x] + sharedX[s] * sharedX[s]; __syncthreads(); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/merge.cu b/libnd4j/include/ops/declarable/helpers/cuda/merge.cu index 082472fce..27c8fc630 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/merge.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/merge.cu @@ -37,7 +37,7 @@ namespace nd4j { static __global__ void global_mergeMaxIndex_(void **inArrs, void **inShapes, const int numArrays, void *voutput, Nd4jLong *outputShape, Nd4jLong length) { auto output = reinterpret_cast(voutput); - const auto tid = blockIdx.x * gridDim.x + threadIdx.x; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; const auto step = gridDim.x * blockDim.x; for (Nd4jLong e = tid; e < length; e += step) { @@ -81,7 +81,13 @@ namespace nd4j { } void 
mergeMaxIndex(nd4j::LaunchContext * context, const std::vector& inArrs, NDArray& output) { + NDArray::prepareSpecialUse({&output}, {}); + for (auto v:inArrs) + v->syncToDevice(); + BUILD_DOUBLE_SELECTOR(inArrs[0]->dataType(), output.dataType(), mergeMaxIndex_, (context, inArrs, output), LIBND4J_TYPES, INDEXING_TYPES); + + NDArray::registerSpecialUse({&output}, {}); } @@ -90,7 +96,7 @@ namespace nd4j { static __global__ void global_mergeMax_(void **inArrs, void **inShapes, const int numArrays, void *voutput, Nd4jLong *outputShape, Nd4jLong length) { auto output = reinterpret_cast(voutput); - const auto tid = blockIdx.x * gridDim.x + threadIdx.x; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; const auto step = gridDim.x * blockDim.x; for (Nd4jLong e = tid; e < length; e += step) { @@ -131,7 +137,12 @@ namespace nd4j { } void mergeMax(nd4j::LaunchContext * context, const std::vector& inArrs, NDArray& output) { + NDArray::prepareSpecialUse({&output}, {}); + for (auto v:inArrs) + v->syncToDevice(); + BUILD_SINGLE_SELECTOR(output.dataType(), mergeMax_, (context, inArrs, output), LIBND4J_TYPES); + NDArray::registerSpecialUse({&output}, {}); } ////////////////////////////////////////////////////////////////////////// @@ -139,7 +150,7 @@ namespace nd4j { static __global__ void global_mergeAvg_(void **inArrs, void **inShapes, const int numArrays, void *voutput, Nd4jLong *outputShape, Nd4jLong length) { auto output = reinterpret_cast(voutput); - const auto tid = blockIdx.x * gridDim.x + threadIdx.x; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; const auto step = gridDim.x * blockDim.x; for (Nd4jLong e = tid; e < length; e += step) { @@ -178,7 +189,13 @@ namespace nd4j { } void mergeAvg(nd4j::LaunchContext * context, const std::vector& inArrs, NDArray& output) { + NDArray::prepareSpecialUse({&output}, {}); + for (auto v:inArrs) + v->syncToDevice(); + BUILD_SINGLE_SELECTOR(output.dataType(), mergeAvg_, (context, inArrs, output), FLOAT_TYPES); + + NDArray::registerSpecialUse({&output}, {}); } ////////////////////////////////////////////////////////////////////////// @@ -186,7 +203,7 @@ namespace nd4j { static __global__ void global_mergeAdd_(void **inArrs, void **inShapes, const int numArrays, void *voutput, Nd4jLong *outputShape, Nd4jLong length) { auto output = reinterpret_cast(voutput); - const auto tid = blockIdx.x * gridDim.x + threadIdx.x; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; const auto step = gridDim.x * blockDim.x; for (Nd4jLong e = tid; e < length; e += step) { @@ -226,7 +243,13 @@ namespace nd4j { BUILD_SINGLE_TEMPLATE(template void mergeAdd_, (nd4j::LaunchContext * context, const std::vector& inArrs, NDArray& output), NUMERIC_TYPES); void mergeAdd(nd4j::LaunchContext * context, const std::vector& inArrs, NDArray& output) { + NDArray::prepareSpecialUse({&output}, {}); + for (auto v:inArrs) + v->syncToDevice(); + BUILD_SINGLE_SELECTOR(output.dataType(), mergeAdd_, (context, inArrs, output), NUMERIC_TYPES); + + NDArray::registerSpecialUse({&output}, {}); } } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/nth_element.cu b/libnd4j/include/ops/declarable/helpers/cuda/nth_element.cu index aeddd3b97..3b80f3df9 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/nth_element.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/nth_element.cu @@ -31,18 +31,18 @@ namespace helpers { template static __global__ void fillUpElementKernel(void* outputBuffer, Nd4jLong* outputShapeInfo, void* inputBuffer, Nd4jLong* inputShapeInfo, Nd4jLong* pTadShape, 
Nd4jLong* pTadOffsets, Nd4jLong n) { - __shared__ T *z, *x; __shared__ Nd4jLong bufferLength, arrLen; + auto z = reinterpret_cast(outputBuffer); + auto x = reinterpret_cast(inputBuffer); + if (threadIdx.x == 0) { - z = reinterpret_cast(outputBuffer); - x = reinterpret_cast(inputBuffer); arrLen = shape::length(pTadShape); bufferLength = shape::length(outputShapeInfo); } __syncthreads(); - const auto tid = blockIdx.x * gridDim.x + threadIdx.x; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; const auto step = gridDim.x * blockDim.x; for (int t = tid; t < bufferLength; t += step) { auto tX = x + pTadOffsets[t]; @@ -77,8 +77,6 @@ namespace helpers { // manager.synchronize(); sortedVals.tickWriteDevice(); sortedVals.syncToHost(); - sortedVals.printIndexedBuffer("Hello"); - sortedVals.printBuffer("Hello line"); auto stream = context->getCudaStream(); fillUpElementKernel<<<32, 64, 1024, *stream>>>(output->specialBuffer(), output->specialShapeInfo(), sortedVals.specialBuffer(), sortedVals.specialShapeInfo(), pTadShape, pTadOffsets, n); } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/polyGamma.cu b/libnd4j/include/ops/declarable/helpers/cuda/polyGamma.cu index 94d3c02ea..bddaf65e3 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/polyGamma.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/polyGamma.cu @@ -74,17 +74,14 @@ static void polyGammaCudaLauncher(const int blocksPerGrid, const int threadsPerB /////////////////////////////////////////////////////////////////// void polyGamma(nd4j::LaunchContext * context, const NDArray& n, const NDArray& x, NDArray& z) { - if(!n.isActualOnDeviceSide()) n.syncToDevice(); - if(!x.isActualOnDeviceSide()) x.syncToDevice(); + NDArray::prepareSpecialUse({&z}, {&n, &x}); int threadsPerBlock = MAX_NUM_THREADS; int blocksPerGrid = (z.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; BUILD_SINGLE_SELECTOR(n.dataType(), polyGammaCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), n.getSpecialBuffer(), n.getSpecialShapeInfo(), x.getSpecialBuffer(), x.getSpecialShapeInfo(), z.getSpecialBuffer(), z.getSpecialShapeInfo()), FLOAT_TYPES); - n.tickReadHost(); - x.tickReadHost(); - z.tickWriteDevice(); + NDArray::registerSpecialUse({&z}, {&n, &x}); } BUILD_SINGLE_TEMPLATE(template void polyGammaCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void *vn, const Nd4jLong *nShapeInfo, const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo), FLOAT_TYPES); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/range.cu b/libnd4j/include/ops/declarable/helpers/cuda/range.cu index 7e8ddb2a7..3a5504905 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/range.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/range.cu @@ -28,7 +28,7 @@ namespace helpers { template static __global__ void global_range(void *output, Nd4jLong length, T start, T delta) { auto buff = reinterpret_cast(output); - const auto tid = blockIdx.x * gridDim.x + threadIdx.x; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; const auto step = gridDim.x * blockDim.x; for(Nd4jLong i = tid; i < length; i += step) @@ -43,10 +43,11 @@ namespace helpers { } void range(nd4j::LaunchContext * context, const NDArray& start, const NDArray& delta, NDArray& outVector) { + NDArray::prepareSpecialUse({&outVector}, {&start, &delta}); BUILD_SINGLE_SELECTOR(outVector.dataType(), _range, (context, start, delta, outVector), LIBND4J_TYPES); + NDArray::registerSpecialUse({&outVector}, {&start, 
&delta}); } - BUILD_SINGLE_TEMPLATE(template void _range, (nd4j::LaunchContext * context, const NDArray& start, const NDArray& delta, NDArray& outVector), NUMERIC_TYPES); } } } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/toggle_bits.cu b/libnd4j/include/ops/declarable/helpers/cuda/toggle_bits.cu index f90c9f77f..8c67cbf1b 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/toggle_bits.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/toggle_bits.cu @@ -26,13 +26,11 @@ namespace nd4j { namespace helpers { template void toggle_bits__(NDArray &in, NDArray &out) { - NDArray::prepareSpecialUse({&out}, {&in}); auto lambda = LAMBDA_T(_x) { return ~_x;//eUtils::flip_bits(_x); }; in.applyLambda(lambda, &out); - NDArray::registerSpecialUse({&out}, {&in}); } BUILD_SINGLE_TEMPLATE(template void toggle_bits__, (NDArray &in, NDArray &out), INTEGER_TYPES); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu b/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu index 7e35ec819..c3e4f497e 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu @@ -906,7 +906,7 @@ void clipByNormBP(nd4j::LaunchContext* context, const NDArray& input, const NDAr linearBuffers = shape::elementWiseStride(inputShape) == shape::elementWiseStride(outputShape) && shape::elementWiseStride(inputShape) == 1; } __syncthreads(); - const auto tid = blockIdx.x * gridDim.x + threadIdx.x; + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; const auto step = gridDim.x * blockDim.x; for (Nd4jLong e = tid; e < length; e += step) { diff --git a/libnd4j/include/ops/declarable/helpers/legacy_helpers.h b/libnd4j/include/ops/declarable/helpers/legacy_helpers.h index 476c743ea..dfe338864 100644 --- a/libnd4j/include/ops/declarable/helpers/legacy_helpers.h +++ b/libnd4j/include/ops/declarable/helpers/legacy_helpers.h @@ -46,8 +46,8 @@ namespace helpers { void reluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond); void reluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); void relu6Derivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - void leakyReluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); - void eluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); + void leakyReluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput, const float alpha); + void eluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput, const float alpha); void seluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); void cubeDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); void reduceNorm1(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput); diff --git a/libnd4j/include/ops/ops.h b/libnd4j/include/ops/ops.h index e4fef2c3c..a80e274ca 100644 --- a/libnd4j/include/ops/ops.h +++ b/libnd4j/include/ops/ops.h @@ -2271,26 +2271,26 @@ namespace simdOps { } }; - template + template class ELU { public: no_op_exec_special_same no_op_exec_special_same_cuda - op_def static X op(X d1, X *params) { - return nd4j::math::nd4j_elu(d1); + op_def static Z op(X d1, Y d2, Z *params) 
{ + return nd4j::math::nd4j_elu(d1, static_cast(d2)); } }; - template + template class ELUDerivative { public: no_op_exec_special_same no_op_exec_special_same_cuda - op_def static X op(X d1, X *params) { - return nd4j::math::nd4j_eluderivative(d1); + op_def static Z op(X d1, Y d2, Z *params) { + return nd4j::math::nd4j_eluderivative(d1, static_cast(d2)); } }; @@ -3716,7 +3716,7 @@ namespace simdOps { return reduction; } - op_def static Z op(X d1, X d2, Z *extraParamsRef) { + op_def static Z op(X d1, X d2, Z *extraParamsRef) { double eps = nd4j::math::nd4j_abs(extraParamsRef[2]); return static_cast(!nd4j::math::nd4j_eq(d1, d2, eps)); } @@ -4540,4 +4540,4 @@ namespace simdOps { } #endif - + diff --git a/libnd4j/include/templatemath.h b/libnd4j/include/templatemath.h index 6a543b35d..68b59f2d4 100644 --- a/libnd4j/include/templatemath.h +++ b/libnd4j/include/templatemath.h @@ -130,13 +130,12 @@ namespace nd4j { } template - math_def inline Z nd4j_elu(T val) { - if (val >= (T) 0.f) return val; - else return nd4j_exp(val) - (Z) 1.0f; - //return val >= 0.0 ? val : (nd4j_exp(val) - 1.0); + math_def inline Z nd4j_elu(T val, T alpha) { + if (val >= (T) 0.f) + return val; + return static_cast(alpha) * (nd4j_exp(val) - static_cast(1.0f)); } - template math_def inline Z nd4j_leakyrelu(T val,T alpha) { if (val < (T) 0.0f) @@ -145,13 +144,14 @@ namespace nd4j { return val; } - template - math_def inline Z nd4j_eluderivative(T val) { - if (val >= (T) 0.0f) return (Z) 1.0f; - else return nd4j_exp(val); + math_def inline Z nd4j_eluderivative(T val, T alpha) { + if (val >= static_cast(0.0f)) + return static_cast(1.0f); + return static_cast(alpha) * nd4j_exp(val); //return val >= 0.0 ? 1.0 : nd4j_exp(val); } + template math_def inline Z nd4j_sin(T val); @@ -283,7 +283,7 @@ namespace nd4j { #ifdef NATIVE_HALFS if (value < (float16) 0.f) { return float16(__hneg(value.data)); - } else + } else return value; #else return (float16) fabsf((float) value); @@ -904,13 +904,13 @@ inline __device__ int16_t nd4j_atomicMax(int16_t* address, int16_t val) template <> inline __device__ float16 nd4j_atomicMax(float16* address, float16 val) { - int* address_as_ull = (int*) address; + auto address_as_ull = (int*) address; long addr = (long) address; bool misaligned = addr & 0x3; if (misaligned) - address_as_ull = (int *) (addr - 2); + address_as_ull = (int *) (address - 1); PAIR old, assumed, fresh; @@ -937,13 +937,13 @@ inline __device__ float16 nd4j_atomicMax(float16* address, float16 val) template <> inline __device__ bfloat16 nd4j_atomicMax(bfloat16* address, bfloat16 val) { - int* address_as_ull = (int*) address; + auto address_as_ull = (int*) address; long addr = (long)(address); bool misaligned = addr & 0x3; if (misaligned) - address_as_ull = (int *) (addr - 2); + address_as_ull = (int *) (address - 1); BPAIR old, assumed, fresh; @@ -1060,13 +1060,13 @@ inline __device__ float16 nd4j_atomicAdd(float16* address, float16 val) #if __CUDA_ARCH__ >= 700 atomicAdd(reinterpret_cast<__half*>(address), val.data); #else - int* address_as_ull = (int*) address; + auto address_as_ull = (int*) address; long addr = (long) address; bool misaligned = addr & 0x3; if (misaligned) - address_as_ull = (int *) (addr - 2); + address_as_ull = (int *) (address - 1); PAIR old, assumed, fresh; @@ -1094,13 +1094,13 @@ inline __device__ float16 nd4j_atomicAdd(float16* address, float16 val) template <> inline __device__ bfloat16 nd4j_atomicAdd(bfloat16* address, bfloat16 val) { - int* address_as_ull = (int*) address; + auto address_as_ull = (int*) 
address; - long addr = (long)(address); + auto addr = (long)(address); bool misaligned = addr & 0x3; if (misaligned) - address_as_ull = (int *) (addr - 2); + address_as_ull = (int *) (address - 1); BPAIR old, assumed, fresh; @@ -1367,13 +1367,13 @@ inline __device__ Nd4jLong nd4j_atomicMul(Nd4jLong* address, Nd4jLong template <> inline __device__ bfloat16 nd4j_atomicMul(bfloat16* address, bfloat16 val) { - int* address_as_ull = (int*) address; + auto address_as_ull = (int*) address; long addr = (long)(address); bool misaligned = addr & 0x3; if (misaligned) - address_as_ull = (int *) (addr - 2); + address_as_ull = (int *) (address - 1); BPAIR old, assumed, fresh; @@ -1400,13 +1400,13 @@ inline __device__ bfloat16 nd4j_atomicMul(bfloat16* address, bfloat16 template <> inline __device__ float16 nd4j_atomicMul(float16* address, float16 val) { - int* address_as_ull = (int*) address; + auto address_as_ull = (int*) address; long addr = (long)(address); bool misaligned = addr & 0x3; if (misaligned) - address_as_ull = (int *) (addr - 2); + address_as_ull = (int *) (address - 1); BPAIR old, assumed, fresh; diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp index 7fbc309d5..3f868c45c 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp @@ -905,6 +905,25 @@ TEST_F(DeclarableOpsTests12, softmax_9) { delete arrF; } +TEST_F(DeclarableOpsTests12, maxpool_bp_half_1) { + auto x = NDArrayFactory::create('c', {2, 3, 10, 1}, {0.2019043f, 0.6464844f, 0.9116211f, 0.60058594f, 0.34033203f, 0.7036133f, 0.6772461f, 0.3815918f, 0.87353516f, 0.04650879f, 0.67822266f, 0.8618164f, 0.88378906f, 0.7573242f, 0.66796875f, 0.63427734f, 0.33764648f, 0.46923828f, 0.62939453f, 0.76464844f, -0.8618164f, -0.94873047f, -0.9902344f, -0.88916016f, -0.86572266f, -0.92089844f, -0.90722656f, -0.96533203f, -0.97509766f, -0.4975586f, -0.84814453f, -0.984375f, -0.98828125f, -0.95458984f, -0.9472656f, -0.91064453f, -0.80859375f, -0.83496094f, -0.9140625f, -0.82470703f, 0.4802246f, 0.45361328f, 0.28125f, 0.28320312f, 0.79345703f, 0.44604492f, -0.30273438f, 0.11730957f, 0.56396484f, 0.73583984f, 0.1418457f, -0.44848633f, 0.6923828f, -0.40234375f, 0.40185547f, 0.48632812f, 0.14538574f, 0.4638672f, 0.13000488f, 0.5058594f}); + auto y = NDArrayFactory::create('c', {2, 3, 10, 1}, {0.0f, -0.13391113f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, -0.1751709f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.51904297f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.5107422f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}); + auto z = NDArrayFactory::create('c', {2, 3, 10, 1}); + + nd4j::ops::maxpool2d_bp op; + Context ctx(1); + Nd4jLong iArgs[] = {5,1,1, 2,2,0, 1,1,1, 0,0}; + ctx.setIArguments(iArgs, 11); + ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo()); + ctx.setInputArray(1, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo()); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + + + auto status = op.execute(&ctx); + ASSERT_EQ(Status::OK(), status); + +} + ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, lrn_bp_1) { diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp 
b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp index 3b4ff6cd0..1ec9650f9 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp @@ -2794,53 +2794,42 @@ TEST_F(DeclarableOpsTests3, svd_test11) { TEST_F(DeclarableOpsTests3, elu_test1) { auto x = NDArrayFactory::create('c', {3,3}, {0.1, .2, .3, -.4,-.5,-.6, .7, .8, .9}); -// auto expS = NDArrayFactory::create('c', {3}); -// auto expU = NDArrayFactory::create('c', {3,3}); - auto exp = NDArrayFactory::create('c', {3,3}, {.1, .2, .3, -0.32968, -0.393469, -0.451188, .7, .8, .9}); + auto exp = NDArrayFactory::create('c', {3,3}, {.1, .2, .3, 0.5*-0.32968, 0.5*-0.393469, 0.5*-0.451188, .7, .8, .9}); nd4j::ops::elu op; - auto results = op.execute({&x}, {}, {}); + auto results = op.execute({&x}, {0.5}, {}); ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto s = results->at(0); -// auto u = results->at(1); -// auto v = results->at(2); -// s->printIndexedBuffer("ELU"); ASSERT_TRUE(exp.equalsTo(s)); delete results; } /////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests3, elu_test2) { +TEST_F(DeclarableOpsTests3, elu_bp_test1) { - auto x = NDArrayFactory::create('c', {3, 3}, {0.1, .2, .3, -.4, -.5, -.6, .7, .8, .9}); - auto eps = NDArrayFactory::create('c', {3,3}); - eps.assign(2.); -// auto expU = NDArrayFactory::create('c', {3,3}); - auto exp = NDArrayFactory::create('c', {3, 3}, {2, 2, 2, 1.34064, 1.213061, 1.097623, 2, 2, 2}); + auto x = NDArrayFactory::create('c', {3, 3}, {0.1, .2, .3, -.4, -.5, -.6, .7, .8, .9}); + auto eps = NDArrayFactory::create('c', {3,3}); + eps.assign(2.); + auto exp = NDArrayFactory::create('c', {3, 3}, {2, 2, 2, 0.5*1.34064, 0.5*1.213061, 0.5*1.097623, 2, 2, 2}); - nd4j::ops::elu_bp op; - auto results = op.execute({ &x, &eps }, {}, {}); + nd4j::ops::elu_bp op; + auto results = op.execute({ &x, &eps }, {0.5}, {}); - ASSERT_EQ(ND4J_STATUS_OK, results->status()); + ASSERT_EQ(ND4J_STATUS_OK, results->status()); - auto s = results->at(0); -// auto u = results->at(1); -// auto v = results->at(2); -// s->printIndexedBuffer("ELU_BP"); - ASSERT_TRUE(exp.equalsTo(s)); + auto s = results->at(0); + ASSERT_TRUE(exp.equalsTo(s)); - delete results; + delete results; } /////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests3, lrelu_test1) { auto x = NDArrayFactory::create('c', {3,3}, {1, 2, 3, -4,-5,-6, 7, 8, 9}); -// auto expS = NDArrayFactory::create('c', {3}); -// auto expU = NDArrayFactory::create('c', {3,3}); auto exp = NDArrayFactory::create('c', {3,3}, {1, 2, 3, -0.8, -1., -1.2, 7, 8, 9}); nd4j::ops::lrelu op; @@ -2849,20 +2838,16 @@ TEST_F(DeclarableOpsTests3, lrelu_test1) { ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto s = results->at(0); -// auto u = results->at(1); -// auto v = results->at(2); -// s->printIndexedBuffer("LRELU"); ASSERT_TRUE(exp.equalsTo(s)); delete results; } -TEST_F(DeclarableOpsTests3, lrelu_test2) { +TEST_F(DeclarableOpsTests3, lrelu_bp_test1) { auto x = NDArrayFactory::create('c', {3,3}, {1, 2, 3, -4,-5,-6, 7, 8, 9}); -// auto expS = NDArrayFactory::create('c', {3}); auto eps = NDArrayFactory::create('c', {3,3}, {2,2,2,2,2,2,2, 2,2}); - auto exp = NDArrayFactory::create('c', {3,3}, {2, 2, 2, 0, 0, 0, 2, 2, 2}); + auto exp = NDArrayFactory::create('c', {3,3}, {2, 2, 2, 0.4, 0.4, 0.4, 2, 2, 2}); nd4j::ops::lrelu_bp op; auto results = op.execute({&x, &eps}, {0.2}, {}); @@ -2870,9 +2855,6 @@ TEST_F(DeclarableOpsTests3, lrelu_test2) { 
ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto s = results->at(0); -// auto u = results->at(1); -// auto v = results->at(2); -// s->printIndexedBuffer("LRELU_BP"); ASSERT_TRUE(exp.equalsTo(s)); delete results; @@ -2882,8 +2864,6 @@ TEST_F(DeclarableOpsTests3, lrelu_test2) { TEST_F(DeclarableOpsTests3, selu_test1) { auto x = NDArrayFactory::create('c', {3,3}, {1, 2, 3, -4,-5,-6, 7, 8, 9}); -// auto expS = NDArrayFactory::create('c', {3}); -// auto expU = NDArrayFactory::create('c', {3,3}); auto exp = NDArrayFactory::create('c', {3,3}, {1.050701, 2.101402, 3.152103, -1.725899, -1.746253, -1.753742, 7.354907, 8.405608, 9.456309}); nd4j::ops::selu op; @@ -2892,7 +2872,6 @@ TEST_F(DeclarableOpsTests3, selu_test1) { ASSERT_EQ(ND4J_STATUS_OK, results->status()); auto s = results->at(0); -// s->printIndexedBuffer("SELU"); ASSERT_TRUE(exp.equalsTo(s)); delete results; diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp index 3af53bad0..86acca29c 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp @@ -2761,7 +2761,7 @@ TEST_F(DeclarableOpsTests5, ELU_1) { auto exp = NDArrayFactory::create('c', {2, 2, 2}, { -0.63212055, 2. , 1.5, -0.753403, 1., 2., 2., 1.}); auto res = NDArrayFactory::create('c', {2, 2, 2}); - input.applyTransform(transform::ELU, &res); + input.applyScalar(nd4j::scalar::ELU, 1.f, &res); ASSERT_TRUE(res.equalsTo(&exp)); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java index 7ffeca762..49e760961 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java @@ -204,20 +204,29 @@ import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentSum; import org.nd4j.linalg.api.ops.impl.transforms.dtype.Cast; import org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt; import org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp; import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative; import org.nd4j.linalg.api.ops.impl.transforms.gradient.DynamicPartitionBp; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.ELUDerivative; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp; import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhBp; import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp; import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative; import org.nd4j.linalg.api.ops.impl.transforms.gradient.LogSoftMaxDerivative; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhBp; import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhDerivative; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhBp; import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhDerivative; import 
org.nd4j.linalg.api.ops.impl.transforms.gradient.Relu6Derivative; import org.nd4j.linalg.api.ops.impl.transforms.gradient.SELUDerivative; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp; import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp; import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative; import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.ThresholdReluBp; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.*; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.*; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.And; @@ -1126,10 +1135,26 @@ public class DifferentialFunctionFactory { return new org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative(sameDiff(), iX, wrt).outputVariable(); } + public SDVariable tanhRationalBp(SDVariable in, SDVariable epsilon) { + return new RationalTanhBp(sameDiff(), in, epsilon).outputVariable(); + } + + public SDVariable tanhRectifiedBp(SDVariable in, SDVariable epsilon) { + return new RectifiedTanhBp(sameDiff(), in, epsilon).outputVariable(); + } + + /** + * Use {@link #tanhRationalBp(SDVariable, SDVariable)} + */ + @Deprecated public SDVariable tanhRationalDerivative(SDVariable in) { return new RationalTanhDerivative(sameDiff(), in, false).outputVariable(); } + /** + * Use {@link #tanhRectifiedBp(SDVariable, SDVariable)} + */ + @Deprecated public SDVariable tanhRectifiedDerivative(SDVariable in) { return new RectifiedTanhDerivative(sameDiff(), in, false).outputVariable(); } @@ -1280,6 +1305,14 @@ public class DifferentialFunctionFactory { return new Cube(sameDiff(), iX, false).outputVariable(); } + public SDVariable cubeBp(SDVariable in, SDVariable epsilon) { + return new CubeBp(sameDiff(), in, epsilon).outputVariable(); + } + + /** + * @deprecated Use {@link #cubeBp(SDVariable, SDVariable)} + */ + @Deprecated public SDVariable cubeDerivative(SDVariable iX) { return new CubeDerivative(sameDiff(), iX, false).outputVariable(); } @@ -1329,6 +1362,14 @@ public class DifferentialFunctionFactory { return new RectifiedLinearDerivative(sameDiff(), input, grad).outputVariable(); } + public SDVariable thresholdRelu(SDVariable in, SDVariable epsilon, double cutoff){ + return new ThresholdRelu(sameDiff(), in, cutoff).outputVariable(); + } + + public SDVariable thresholdReluBp(SDVariable in, SDVariable epsilon, double cutoff){ + return new ThresholdReluBp(sameDiff(), in, epsilon, cutoff).outputVariable(); + } + public SDVariable relu6(SDVariable iX, double cutoff) { return new Relu6(sameDiff(), iX, false, cutoff).outputVariable(); } @@ -1350,6 +1391,14 @@ public class DifferentialFunctionFactory { return new HardTanh(sameDiff(), iX, false).outputVariable(); } + public SDVariable hardTanhBp(SDVariable in, SDVariable epsilon) { + return new HardTanhBp(sameDiff(), in, epsilon).outputVariable(); + } + + /** + * @deprecated Use {@link #hardTanhBp(SDVariable, SDVariable)} + */ + @Deprecated public SDVariable hardTanhDerivative(SDVariable iX) { return new HardTanhDerivative(sameDiff(), iX, false).outputVariable(); } @@ -1358,6 +1407,9 @@ public class DifferentialFunctionFactory { return new HardSigmoid(sameDiff(), in, false).outputVariable(); } + public SDVariable hardSigmoidBp(SDVariable in, SDVariable epsilon){ + return new HardSigmoidBp(sameDiff(), in, epsilon).outputVariable(); + } public SDVariable 
sigmoid(SDVariable iX) { return new Sigmoid(sameDiff(), iX, false).outputVariable(); @@ -1486,10 +1538,16 @@ public class DifferentialFunctionFactory { } + public SDVariable softsignBp(SDVariable in, SDVariable epsilon) { + return new SoftSignBp(sameDiff(), in, epsilon).outputVariable(); + } + /** + * @deprecated Use {@link #softsignBp(SDVariable, SDVariable)} + */ + @Deprecated public SDVariable softsignDerivative(SDVariable iX) { return new SoftSignDerivative(sameDiff(), iX, false).outputVariable(); - } @@ -1500,14 +1558,12 @@ public class DifferentialFunctionFactory { public SDVariable elu(SDVariable iX) { - return new ELU(sameDiff(), iX, false).outputVariable(); + return new ELU(sameDiff(), iX).outputVariable(); } - - public SDVariable eluDerivative(SDVariable iX) { - return new ELUDerivative(sameDiff(), iX, false).outputVariable(); - + public SDVariable eluBp(SDVariable in, SDVariable epsilon) { + return new EluBp(sameDiff(), in, epsilon).outputVariable(); } @@ -1516,6 +1572,14 @@ public class DifferentialFunctionFactory { } + public SDVariable leakyReluBp(SDVariable in, SDVariable epsilon, double cutoff) { + return new LeakyReLUBp(sameDiff(), in, epsilon, cutoff).outputVariable(); + } + + /** + * @deprecated Use {@link #leakyReluBp(SDVariable, SDVariable, double)} + */ + @Deprecated public SDVariable leakyReluDerivative(SDVariable iX, double cutoff) { return new LeakyReLUDerivative(sameDiff(), iX, false, cutoff).outputVariable(); } @@ -1832,7 +1896,15 @@ public class DifferentialFunctionFactory { return new SELU(sameDiff(), arg, false).outputVariable(); } + public SDVariable seluBp(SDVariable in, SDVariable epsilon) { + validateDifferentialFunctionsameDiff(in); + return new SeluBp(sameDiff(), in, epsilon).outputVariable(); + } + /** + * @deprecated Use {@link #seluBp(SDVariable, SDVariable)} + */ + @Deprecated public SDVariable seluDerivative(SDVariable arg) { validateDifferentialFunctionsameDiff(arg); return new SELUDerivative(sameDiff(), arg, false).outputVariable(); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java index eb89a0f3a..cd9d7ffd2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java @@ -163,31 +163,6 @@ public class SDNN extends SDOps { return updateVariableNameAndReference(result, name); } - /** - * Element-wise derivative exponential linear unit (ELU) function, dOut/dIn given input. - * {@link #elu(SDVariable)} - * - * @param x Input variable - * @return Output variable - */ - public SDVariable eluDerivative(SDVariable x) { - return eluDerivative(null, x); - } - - /** - * Element-wise derivative exponential linear unit (ELU) function, dOut/dIn given input. - * {@link #elu(SDVariable)} - * - * @param name Output variable name - * @param x Input variable - * @return Output variable - */ - public SDVariable eluDerivative(String name, SDVariable x) { - validateFloatingPoint("eluDerivative", x); - SDVariable result = f().eluDerivative(x); - return updateVariableNameAndReference(result, name); - } - /** * GELU activation function - Gaussian Error Linear Units
* For more details, see Gaussian Error Linear Units (GELUs) - https://arxiv.org/abs/1606.08415 diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/LegacyOpMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/LegacyOpMapper.java index c69295cc6..eeb6b1b78 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/LegacyOpMapper.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/LegacyOpMapper.java @@ -255,8 +255,6 @@ public class LegacyOpMapper { return Abs.class; case 2: return LogSoftMax.class; - case 3: - return ELUDerivative.class; case 4: return org.nd4j.linalg.api.ops.impl.transforms.strict.TanhDerivative.class; case 5: diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java index 5bc175952..541b0a545 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java @@ -881,7 +881,6 @@ public class OpValidation { SoftmaxBp.class, CubeDerivative.class, - ELUDerivative.class, GELUDerivative.class, PreciseGELUDerivative.class, HardSigmoidDerivative.class, @@ -901,6 +900,17 @@ public class OpValidation { TanhDerivative.class, org.nd4j.linalg.api.ops.impl.transforms.strict.TanhDerivative.class, PowDerivative.class, + org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp.class, + org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp.class, + org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp.class, + org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhBp.class, + org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp.class, + org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhBp.class, + org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhBp.class, + org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp.class, + org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftPlusBp.class, + org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp.class, + org.nd4j.linalg.api.ops.impl.transforms.gradient.ThresholdReluBp.class, BiasAddGrad.class, ConcatBp.class, diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java index da580b748..5bfba7a48 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java @@ -229,6 +229,7 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.scalar.PowDerivative.class, org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear.class, org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinearDerivative.class, + org.nd4j.linalg.api.ops.impl.transforms.custom.ThresholdRelu.class, org.nd4j.linalg.api.ops.impl.scalar.Relu6.class, org.nd4j.linalg.api.ops.impl.scalar.ReplaceNans.class, org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd.class, @@ -421,7 +422,6 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt.class, 
org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative.class, org.nd4j.linalg.api.ops.impl.transforms.gradient.DynamicPartitionBp.class, - org.nd4j.linalg.api.ops.impl.transforms.gradient.ELUDerivative.class, org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker.class, org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidDerivative.class, org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative.class, @@ -433,6 +433,17 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.transforms.gradient.SELUDerivative.class, org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative.class, org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative.class, + org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp.class, + org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp.class, + org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp.class, + org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhBp.class, + org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp.class, + org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhBp.class, + org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhBp.class, + org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp.class, + org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftPlusBp.class, + org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp.class, + org.nd4j.linalg.api.ops.impl.transforms.gradient.ThresholdReluBp.class, org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp.class, org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.BinaryMinimalRelativeError.class, diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationCube.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationCube.java index 79fbcea5a..766f4a4db 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationCube.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationCube.java @@ -21,6 +21,8 @@ import lombok.Getter; import lombok.NonNull; import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp; import org.nd4j.linalg.api.ops.impl.transforms.same.Cube; import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative; import org.nd4j.linalg.factory.Nd4j; @@ -42,9 +44,9 @@ public class ActivationCube extends BaseActivationFunction { @Override public Pair backprop(@NonNull INDArray in, @NonNull INDArray epsilon) { assertShape(in, epsilon); - INDArray dLdz = Nd4j.getExecutioner().exec(new CubeDerivative(in)); - dLdz.muli(epsilon); - return new Pair<>(dLdz, null); + Nd4j.getExecutioner().execAndReturn(new CubeBp(in, epsilon, in)); + + return new Pair<>(in, null); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationELU.java index 48e118d06..b7ac3887c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationELU.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationELU.java @@ -18,11 +18,11 @@ package org.nd4j.linalg.activations.impl; import lombok.EqualsAndHashCode; import lombok.Getter; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp; import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.strict.ELU; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.ELUDerivative; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.indexing.BooleanIndexing; import org.nd4j.linalg.indexing.conditions.Conditions; @@ -57,7 +57,7 @@ public class ActivationELU extends BaseActivationFunction { public INDArray getActivation(INDArray in, boolean training) { // no support in ELU native to override alpha if (this.alpha != 1.00) { - INDArray alphaMultiple = Nd4j.getExecutioner().exec(new ELU(in.dup())); + INDArray alphaMultiple = Nd4j.getExecutioner().exec(new ELU(in.dup()))[0]; alphaMultiple.muli(alpha); BooleanIndexing.replaceWhere(in, alphaMultiple, Conditions.lessThan(0)); } else { @@ -74,21 +74,8 @@ public class ActivationELU extends BaseActivationFunction { @Override public Pair backprop(INDArray in, INDArray epsilon) { assertShape(in, epsilon); - // no support in ELU native to override alpha - if (alpha != 1.00) { - INDArray dLdz = Nd4j.getExecutioner().exec(new ELUDerivative(in.dup())); - dLdz.muli(alpha); - BooleanIndexing.replaceWhere(dLdz, 1, Conditions.equals(alpha)); - - dLdz.muli(epsilon); - return new Pair<>(dLdz, null); - } - - else { - INDArray dLdz = Nd4j.getExecutioner().exec(new ELUDerivative(in)); - dLdz.muli(epsilon); - return new Pair<>(dLdz, null); - } + Nd4j.getExecutioner().execAndReturn(new EluBp(in, epsilon, in)); + return new Pair<>(in, null); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationHardSigmoid.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationHardSigmoid.java index b2b73be0e..4076b40e3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationHardSigmoid.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationHardSigmoid.java @@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl; import lombok.EqualsAndHashCode; import lombok.Getter; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp; import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.api.ndarray.INDArray; @@ -41,9 +43,9 @@ public class ActivationHardSigmoid extends BaseActivationFunction { @Override public Pair backprop(INDArray in, INDArray epsilon) { assertShape(in, epsilon); - INDArray dLdz = Nd4j.getExecutioner().exec(new HardSigmoidDerivative(in)); - dLdz.muli(epsilon); - return new Pair<>(dLdz, null); + Nd4j.getExecutioner().execAndReturn(new HardSigmoidBp(in, epsilon, in)); + + return new Pair<>(in, null); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationHardTanH.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationHardTanH.java index 7cf80ffc3..f8c405d38 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationHardTanH.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationHardTanH.java @@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl; import lombok.EqualsAndHashCode; import lombok.Getter; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhBp; import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.api.ndarray.INDArray; @@ -43,9 +45,10 @@ public class ActivationHardTanH extends BaseActivationFunction { @Override public Pair backprop(INDArray in, INDArray epsilon) { assertShape(in, epsilon); - INDArray dLdz = Nd4j.getExecutioner().exec(new HardTanhDerivative(in)); - dLdz.muli(epsilon); - return new Pair<>(dLdz, null); + + Nd4j.getExecutioner().execAndReturn(new HardTanhBp(in, epsilon, in)); + + return new Pair<>(in, null); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationLReLU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationLReLU.java index f59a7ddb0..864f16901 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationLReLU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationLReLU.java @@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl; import lombok.EqualsAndHashCode; import lombok.Getter; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp; import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative; import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.activations.BaseActivationFunction; @@ -54,9 +56,10 @@ public class ActivationLReLU extends BaseActivationFunction { @Override public Pair backprop(INDArray in, INDArray epsilon) { assertShape(in, epsilon); - INDArray dLdz = Nd4j.getExecutioner().exec(new LeakyReLUDerivative(in, alpha)); - dLdz.muli(epsilon); - return new Pair<>(dLdz, null); + + Nd4j.getExecutioner().execAndReturn(new LeakyReLUBp(in, epsilon, in, alpha)); + + return new Pair<>(in, null); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationRReLU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationRReLU.java index f2e01508f..8d5bb3ddd 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationRReLU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationRReLU.java @@ -63,7 +63,7 @@ public class ActivationRReLU extends BaseActivationFunction { public INDArray getActivation(INDArray in, boolean training) { if (training) { try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { - this.alpha = Nd4j.rand(in.shape(), l, u, Nd4j.getRandom()); + this.alpha = Nd4j.rand(l, u, Nd4j.getRandom(), in.shape()); } INDArray inTimesAlpha = in.mul(alpha); BooleanIndexing.replaceWhere(in, inTimesAlpha, Conditions.lessThan(0)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationRationalTanh.java 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationRationalTanh.java index 84c8878bc..0e6cc2a51 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationRationalTanh.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationRationalTanh.java @@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl; import lombok.EqualsAndHashCode; import lombok.Getter; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhBp; import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.api.ndarray.INDArray; @@ -48,9 +50,10 @@ public class ActivationRationalTanh extends BaseActivationFunction { @Override public Pair backprop(INDArray in, INDArray epsilon) { assertShape(in, epsilon); - INDArray dLdz = Nd4j.getExecutioner().exec(new RationalTanhDerivative(in)); - dLdz.muli(epsilon); - return new Pair<>(dLdz, null); + + Nd4j.getExecutioner().execAndReturn(new RationalTanhBp(in, epsilon, in)); + + return new Pair<>(in, null); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationReLU6.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationReLU6.java index f7bc24966..611f2c9ee 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationReLU6.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationReLU6.java @@ -20,8 +20,10 @@ import lombok.EqualsAndHashCode; import lombok.Getter; import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.impl.scalar.Relu6; import org.nd4j.linalg.api.ops.impl.scalar.Step; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.Relu6Derivative; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.primitives.Pair; @@ -41,9 +43,10 @@ public class ActivationReLU6 extends BaseActivationFunction { @Override public Pair backprop(INDArray in, INDArray epsilon) { assertShape(in, epsilon); - INDArray dLdz = Nd4j.getExecutioner().exec(new Step(in)); - dLdz.muli(epsilon); - return new Pair<>(dLdz, null); + + Nd4j.getExecutioner().execAndReturn(new Relu6Derivative(in, epsilon, in)); + + return new Pair<>(in, null); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationRectifiedTanh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationRectifiedTanh.java index 58ff1bc7b..ccd4cafe2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationRectifiedTanh.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationRectifiedTanh.java @@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl; import lombok.EqualsAndHashCode; import lombok.Getter; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhBp; import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.api.ndarray.INDArray; @@ -45,9 +47,10 
@@ public class ActivationRectifiedTanh extends BaseActivationFunction { @Override public Pair backprop(INDArray in, INDArray epsilon) { assertShape(in, epsilon); - INDArray dLdz = Nd4j.getExecutioner().exec(new RectifiedTanhDerivative(in)); - dLdz.muli(epsilon); - return new Pair<>(dLdz, null); + + Nd4j.getExecutioner().execAndReturn(new RectifiedTanhBp(in, epsilon, in)); + + return new Pair<>(in, null); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSELU.java index 773a2578a..3eed5ac9c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSELU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSELU.java @@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl; import lombok.EqualsAndHashCode; import lombok.Getter; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp; import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.api.ndarray.INDArray; @@ -41,9 +43,10 @@ public class ActivationSELU extends BaseActivationFunction { @Override public Pair backprop(INDArray in, INDArray epsilon) { assertShape(in, epsilon); - INDArray dLdz = Nd4j.getExecutioner().exec(new SELUDerivative(in)); - dLdz.muli(epsilon); - return new Pair<>(dLdz, null); + + Nd4j.getExecutioner().execAndReturn(new SeluBp(in, epsilon, in)); + + return new Pair<>(in, null); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSigmoid.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSigmoid.java index aa7e7a1c6..72500e677 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSigmoid.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSigmoid.java @@ -18,7 +18,8 @@ package org.nd4j.linalg.activations.impl; import lombok.EqualsAndHashCode; import lombok.Getter; -import org.nd4j.linalg.api.ops.impl.transforms.strict.SigmoidDerivative; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative; import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.api.ndarray.INDArray; @@ -41,9 +42,10 @@ public class ActivationSigmoid extends BaseActivationFunction { @Override public Pair backprop(INDArray in, INDArray epsilon) { assertShape(in, epsilon); - INDArray dLdz = Nd4j.getExecutioner().exec(new SigmoidDerivative(in)); - dLdz.muli(epsilon); - return new Pair<>(dLdz, null); + + Nd4j.getExecutioner().execAndReturn(new SigmoidDerivative(in, epsilon, in)); + + return new Pair<>(in, null); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSoftPlus.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSoftPlus.java index fa5fe3ef8..0eb7781f2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSoftPlus.java +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSoftPlus.java @@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl; import lombok.EqualsAndHashCode; import lombok.Getter; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftPlusBp; import org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid; import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.activations.BaseActivationFunction; @@ -41,9 +43,10 @@ public class ActivationSoftPlus extends BaseActivationFunction { @Override public Pair backprop(INDArray in, INDArray epsilon) { assertShape(in, epsilon); - INDArray dLdz = Nd4j.getExecutioner().exec(new Sigmoid(in)); - dLdz.muli(epsilon); - return new Pair<>(dLdz, null); + + Nd4j.getExecutioner().execAndReturn(new SoftPlusBp(in, epsilon, in)); + + return new Pair<>(in, null); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSoftSign.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSoftSign.java index ff3b298e4..3857ff084 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSoftSign.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSoftSign.java @@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl; import lombok.EqualsAndHashCode; import lombok.Getter; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp; import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.api.ndarray.INDArray; @@ -41,9 +43,10 @@ public class ActivationSoftSign extends BaseActivationFunction { @Override public Pair backprop(INDArray in, INDArray epsilon) { assertShape(in, epsilon); - INDArray dLdz = Nd4j.getExecutioner().exec(new SoftSignDerivative(in)); - dLdz.muli(epsilon); - return new Pair<>(dLdz, null); + + Nd4j.getExecutioner().execAndReturn(new SoftSignBp(in, epsilon, in)); + + return new Pair<>(in, null); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSoftmax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSoftmax.java index 2fd8b439a..095c2548a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSoftmax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationSoftmax.java @@ -21,7 +21,9 @@ import lombok.Getter; import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.CustomOp; +import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.primitives.Pair; @@ -42,10 +44,10 @@ public class ActivationSoftmax extends BaseActivationFunction { @Override public Pair backprop(INDArray in, INDArray epsilon) { assertShape(in, epsilon); - INDArray out = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(in, in.ulike()))[0]; - INDArray x = out.mul(epsilon).sum(1); - INDArray dLdz = 
out.mul(epsilon.subColumnVector(x)); - return new Pair<>(dLdz, null); + + Nd4j.getExecutioner().execAndReturn(new SoftmaxBp(in, epsilon, in, -1)); + + return new Pair<>(in, null); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationTanH.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationTanH.java index 038d6032d..a30b8e303 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationTanH.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationTanH.java @@ -18,11 +18,11 @@ package org.nd4j.linalg.activations.impl; import lombok.EqualsAndHashCode; import lombok.Getter; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative; import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh; -import org.nd4j.linalg.api.ops.impl.transforms.strict.TanhDerivative; import org.nd4j.linalg.factory.Nd4j; /** @@ -41,9 +41,10 @@ public class ActivationTanH extends BaseActivationFunction { @Override public Pair backprop(INDArray in, INDArray epsilon) { assertShape(in, epsilon); - INDArray dLdz = Nd4j.getExecutioner().exec(new TanhDerivative(in)); - dLdz.muli(epsilon); - return new Pair<>(dLdz, null); + + Nd4j.getExecutioner().execAndReturn(new TanhDerivative(in, epsilon, in)); + + return new Pair<>(in, null); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java index 46dd786c6..0d0af0788 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java @@ -2836,156 +2836,78 @@ public abstract class BaseNDArray implements INDArray, Iterable { return putScalar(i, element.getDouble(0)); } - /** - * In place addition of a column vector - * - * @param columnVector the column vector to add - * @return the result of the addition - */ @Override public INDArray diviColumnVector(INDArray columnVector) { validateNumericalArray("diviColumnVector", false); return doColumnWise(columnVector, 'd'); } - /** - * In place addition of a column vector - * - * @param columnVector the column vector to add - * @return the result of the addition - */ @Override public INDArray divColumnVector(INDArray columnVector) { validateNumericalArray("divColumnVector", false); return dup().diviColumnVector(columnVector); } - /** - * In place addition of a column vector - * - * @param rowVector the row vector to add - * @return the result of the addition - */ @Override public INDArray diviRowVector(INDArray rowVector) { validateNumericalArray("diviRowVector", false); return doRowWise(rowVector, 'd'); } - /** - * In place addition of a column vector - * - * @param rowVector the row vector to add - * @return the result of the addition - */ @Override public INDArray divRowVector(INDArray rowVector) { validateNumericalArray("divRowVector", false); return dup().diviRowVector(rowVector); } - /** - * In place addition of a column vector - * - * @param columnVector the column vector to add - * @return the result of the addition - 
*/ @Override public INDArray muliColumnVector(INDArray columnVector) { validateNumericalArray("muliColumnVector", false); return doColumnWise(columnVector, 'm'); } - /** - * In place addition of a column vector - * - * @param columnVector the column vector to add - * @return the result of the addition - */ @Override public INDArray mulColumnVector(INDArray columnVector) { validateNumericalArray("mulColumnVector", false); return dup().muliColumnVector(columnVector); } - /** - * In place addition of a column vector - * - * @param rowVector the row vector to add - * @return the result of the addition - */ @Override public INDArray muliRowVector(INDArray rowVector) { validateNumericalArray("muliRowVector", false); return doRowWise(rowVector, 'm'); } - /** - * In place addition of a column vector - * - * @param rowVector the row vector to add - * @return the result of the addition - */ @Override public INDArray mulRowVector(INDArray rowVector) { validateNumericalArray("mulRowVector", false); return dup().muliRowVector(rowVector); } - /** - * In place addition of a column vector - * - * @param columnVector the column vector to add - * @return the result of the addition - */ @Override public INDArray subiColumnVector(INDArray columnVector) { validateNumericalArray("subiColumnVector", false); return doColumnWise(columnVector, 's'); } - /** - * In place addition of a column vector - * - * @param columnVector the column vector to add - * @return the result of the addition - */ @Override public INDArray subColumnVector(INDArray columnVector) { validateNumericalArray("subColumnVector", false); return dup().subiColumnVector(columnVector); } - /** - * In place addition of a column vector - * - * @param rowVector the row vector to add - * @return the result of the addition - */ @Override public INDArray subiRowVector(INDArray rowVector) { validateNumericalArray("subiRowVector", false); return doRowWise(rowVector, 's'); } - /** - * In place addition of a column vector - * - * @param rowVector the row vector to add - * @return the result of the addition - */ @Override public INDArray subRowVector(INDArray rowVector) { validateNumericalArray("subRowVector", false); return dup().subiRowVector(rowVector); } - /** - * In place addition of a column vector - * - * @param columnVector the column vector to add - * @return the result of the addition - */ @Override public INDArray addiColumnVector(INDArray columnVector) { validateNumericalArray("addiColumnVector", false); @@ -2997,24 +2919,12 @@ public abstract class BaseNDArray implements INDArray, Iterable { return doColumnWise(columnVector, 'p'); } - /** - * In place addition of a column vector - * - * @param columnVector the column vector to add - * @return the result of the addition - */ @Override public INDArray addColumnVector(INDArray columnVector) { validateNumericalArray("addColumnVector", false); return dup().addiColumnVector(columnVector); } - /** - * In place addition of a column vector - * - * @param rowVector the row vector to add - * @return the result of the addition - */ @Override public INDArray addiRowVector(INDArray rowVector) { validateNumericalArray("addiRowVector", false); @@ -3027,47 +2937,22 @@ public abstract class BaseNDArray implements INDArray, Iterable { return doRowWise(rowVector, 'p'); } - /** - * In place addition of a column vector - * - * @param rowVector the row vector to add - * @return the result of the addition - */ @Override public INDArray addRowVector(INDArray rowVector) { validateNumericalArray("addRowVector", false); 
return dup().addiRowVector(rowVector); } - - /** - * Perform a copy matrix multiplication - * - * @param other the other matrix to perform matrix multiply with - * @return the result of the matrix multiplication - */ @Override public INDArray mmul(INDArray other, INDArray result, MMulTranspose mMulTranspose) { return mMulTranspose.exec(this, other, result); } - /** - * Perform a copy matrix multiplication - * - * @param other the other matrix to perform matrix multiply with - * @return the result of the matrix multiplication - */ @Override public INDArray mmul(INDArray other, MMulTranspose mMulTranspose) { return mMulTranspose.exec(this, other, null); } - /** - * Perform a copy matrix multiplication - * - * @param other the other matrix to perform matrix multiply with - * @return the result of the matrix multiplication - */ @Override public INDArray mmul(INDArray other) { Preconditions.checkState(this.dataType() == other.dataType(), "Matrix multiplication: arrays must have same dtype: %s vs. %s", this.dataType(), other.dataType()); @@ -3105,8 +2990,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { if(!isVectorOrScalar()) { throw new ND4JIllegalStateException("Unable to create a 1d array from a non vector! Shape: " + Shape.shapeToStringShort(this)); } - - return dup().data().asDouble(); } @@ -3115,7 +2998,6 @@ public abstract class BaseNDArray implements INDArray, Iterable { if(!isVectorOrScalar()) { throw new ND4JIllegalStateException("Unable to create a 1d array from a non vector! Shape: " + Shape.shapeToStringShort(this)); } - return dup().data().asFloat(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCOO.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCOO.java index da5cd3f60..116a4b4f7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCOO.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseSparseNDArrayCOO.java @@ -1123,14 +1123,6 @@ public class BaseSparseNDArrayCOO extends BaseSparseNDArray { return null; } - /** - * Perform an copy matrix multiplication - * - * @param other the other matrix to perform matrix multiply with - * @param result the result ndarray - * @param mMulTranspose the transpose status of each array - * @return the result of the matrix multiplication - */ @Override public INDArray mmul(INDArray other, INDArray result, MMulTranspose mMulTranspose) { return null; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java index 49cabf6bb..b842797f9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/INDArray.java @@ -1,4 +1,4 @@ -/******************************************************************************* +/* ***************************************************************************** * Copyright (c) 2015-2018 Skymind, Inc. 
* * This program and the accompanying materials are made available under the @@ -16,9 +16,6 @@ package org.nd4j.linalg.api.ndarray; -import static org.nd4j.linalg.factory.Nd4j.compressDebug; -import static org.nd4j.linalg.factory.Nd4j.preventUnpack; - import com.google.flatbuffers.FlatBufferBuilder; import lombok.NonNull; import org.nd4j.linalg.api.blas.params.MMulTranspose; @@ -52,6 +49,7 @@ public interface INDArray extends Serializable, AutoCloseable { */ DataBuffer shapeInfoDataBuffer(); + // TODO: Unused untested method. /** * Sparse info * @return Sparse info. @@ -110,12 +108,13 @@ public interface INDArray extends Serializable, AutoCloseable { */ int elementWiseStride(); + // TODO: Unused untested method. /** * Get a double at the given linear offset unsafe, without checks. * @param offset the offset to get at * @return double value at offset */ - double getDoubleUnsafe(long offset); //TODO: consider deleting. + double getDoubleUnsafe(long offset); /** * Get string value at given index. @@ -124,13 +123,14 @@ public interface INDArray extends Serializable, AutoCloseable { */ String getString(long index); + // TODO: Unused untested method. /** * Insert a scalar at the given linear offset * @param offset the offset to insert at * @param value the value to insert * @return this */ - INDArray putScalarUnsafe(long offset, double value); //TODO: consider deleting. + INDArray putScalarUnsafe(long offset, double value); /** * Returns the number of possible vectors for a given dimension @@ -190,6 +190,7 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray assign(INDArray arr); + // TODO: Unused untested method. /** * Assign all elements from given ndarray that are matching given condition, * ndarray to this ndarray @@ -553,7 +554,7 @@ public interface INDArray extends Serializable, AutoCloseable { * * @param n the number to subtract by * @param result the result ndarray - * @return + * @return the result ndarray */ INDArray rsub(Number n, INDArray result); @@ -1041,7 +1042,6 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray divRowVector(INDArray rowVector); - /** * In place reverse divison of a column vector * @@ -1066,6 +1066,7 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray rdiviRowVector(INDArray rowVector); + //TODO: unused / untested method. /** * Reverse division of a column vector (copy) * @@ -1074,7 +1075,6 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray rdivRowVector(INDArray rowVector); - /** * In place multiplication of a column vector * @@ -1107,7 +1107,6 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray mulRowVector(INDArray rowVector); - /** * In place reverse subtraction of a column vector * @@ -1132,6 +1131,7 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray rsubiRowVector(INDArray rowVector); + //TODO: unused / untested method. 
/** * Reverse subtraction of a row vector (copy) * @@ -1180,7 +1180,6 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray addiColumnVector(INDArray columnVector); - /** * In place assignment of a column vector * @@ -1221,6 +1220,12 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray addRowVector(INDArray rowVector); + /** + * Perform a copy matrix multiplication + * + * @param other the other matrix to perform matrix multiply with + * @return the result of the matrix multiplication + */ INDArray mmul(INDArray other, MMulTranspose mMulTranspose); /** @@ -1231,8 +1236,6 @@ public interface INDArray extends Serializable, AutoCloseable { */ INDArray mmul(INDArray other); - - /** * Convert this ndarray to a 2d double matrix. * Note that THIS SHOULD NOT BE USED FOR SPEED. @@ -1283,6 +1286,14 @@ public interface INDArray extends Serializable, AutoCloseable { */ int[] toIntVector(); + /** + * Convert this ndarray to a 1d long matrix. + * Note that THIS SHOULD NOT BE USED FOR SPEED. + * This is mainly used for integrations with other libraries. + * Due to nd4j's off heap nature, moving data on heap is very expensive + * and should not be used if possible. + * @return a copy of this array as a 1d long array + */ long[] toLongVector(); /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/LeakyReLU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/LeakyReLU.java index 9f600b29b..fe70de288 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/LeakyReLU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/LeakyReLU.java @@ -16,17 +16,13 @@ package org.nd4j.linalg.api.ops.impl.scalar; -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.graph.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseScalarOp; -import org.nd4j.linalg.api.ops.BaseTransformOp; - -import java.util.Arrays; -import java.util.LinkedHashMap; +import java.util.Collections; import java.util.List; import java.util.Map; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.BaseScalarOp; import org.nd4j.linalg.factory.Nd4j; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -108,8 +104,7 @@ public class LeakyReLU extends BaseScalarOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().leakyReluDerivative(arg(), alpha).mul(i_v.get(0)); - return Arrays.asList(ret); + return Collections.singletonList(f().leakyReluBp(arg(), i_v.get(0), alpha)); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinear.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinear.java index 98fa587b5..ca8cee2f1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinear.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinear.java @@ -75,15 +75,8 @@ public class RectifiedLinear extends BaseScalarOp { return "Relu"; } - @Override public List doDiff(List i_v) { - 
if(scalarValue.getDouble(0) == 0.0){ - return Collections.singletonList(f().reluDerivative(arg(), i_v.get(0))); - } else { - SDVariable step = new Step(sameDiff, arg(), false, scalarValue.getDouble(0)).outputVariables()[0]; - SDVariable ret = step.mul(i_v.get(0)); - return Collections.singletonList(ret); - } + return Collections.singletonList(f().thresholdReluBp(arg(), i_v.get(0), scalarValue.getDouble(0))); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinearDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinearDerivative.java index 3af7a4190..7e4d0fa09 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinearDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/scalar/RectifiedLinearDerivative.java @@ -3,10 +3,10 @@ package org.nd4j.linalg.api.ops.impl.scalar; import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; -import org.nd4j.shade.guava.base.Preconditions; import java.util.Collections; import java.util.List; @@ -30,7 +30,8 @@ public class RectifiedLinearDerivative extends DynamicCustomOp { @Override public List calculateOutputDataTypes(List dataTypes) { - Preconditions.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes); + Preconditions + .checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes); Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes); return Collections.singletonList(dataTypes.get(0)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ThresholdRelu.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ThresholdRelu.java new file mode 100644 index 000000000..82e2ae6e3 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ThresholdRelu.java @@ -0,0 +1,77 @@ +/* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import java.util.Collections; +import java.util.List; +import lombok.Getter; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear; + +/** + * Threshold ReLU op. The general case of {@link RectifiedLinear}. + */ +public class ThresholdRelu extends DynamicCustomOp { + + @Getter + private double cutoff = 0.0; + + public ThresholdRelu(){ } + + public ThresholdRelu(SameDiff sd, SDVariable input, boolean inPlace, double cutoff){ + super(sd, new SDVariable[]{input}, inPlace); + this.cutoff = cutoff; + addTArgument(cutoff); + } + + public ThresholdRelu(SameDiff sd, SDVariable input, double cutoff){ + super(sd, new SDVariable[]{input}); + this.cutoff = cutoff; + addTArgument(cutoff); + } + + public ThresholdRelu(@NonNull INDArray input, INDArray output, double cutoff){ + super(new INDArray[]{input}, wrapOrNull(output)); + this.cutoff = cutoff; + addTArgument(cutoff); + } + + @Override + public String opName(){ + return "thresholdedrelu"; + } + + @Override + public List calculateOutputDataTypes(List dataTypes) { + Preconditions + .checkArgument(dataTypes != null && dataTypes.size() == 1, "Expected exactly 1 input datatypes, got %s", dataTypes); + Preconditions.checkArgument(dataTypes.get(0).isFPType(), "Input datatype must be floating point, got %s", dataTypes); + + return Collections.singletonList(dataTypes.get(0)); + } + + @Override + public List doDiff(List f1) { + return Collections.singletonList(f().thresholdReluBp(arg(), f1.get(0), cutoff)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/CubeBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/CubeBp.java new file mode 100644 index 000000000..16f67c910 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/CubeBp.java @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.nd4j.linalg.api.ops.impl.transforms.gradient; + +import java.util.Collections; +import java.util.List; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +/** + * Cube backpropagation op - dL/dIn from in and dL/dOut + */ +public class CubeBp extends DynamicCustomOp { + + public CubeBp(){ } + + public CubeBp(SameDiff sd, SDVariable input, SDVariable gradient){ + super(sd, new SDVariable[]{input, gradient}); + } + + public CubeBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){ + super(new INDArray[]{input, gradient}, wrapOrNull(output)); + } + + @Override + public String opName(){ + return "cube_bp"; + } + + @Override + public List calculateOutputDataTypes(List dataTypes) { + Preconditions + .checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes); + Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes); + + return Collections.singletonList(dataTypes.get(0)); + } + + @Override + public List doDiff(List f1) { + throw new UnsupportedOperationException("Not supported"); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/CubeDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/CubeDerivative.java index 9c6a000c1..af9985e89 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/CubeDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/CubeDerivative.java @@ -27,7 +27,11 @@ import java.util.List; /** * Cube derivative, e.g. 3x^2 + * + * @deprecated Use {@link CubeBp} + * */ +@Deprecated public class CubeDerivative extends BaseTransformStrictOp { public CubeDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/ELUDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/ELUDerivative.java deleted file mode 100644 index 45357fb0b..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/ELUDerivative.java +++ /dev/null @@ -1,84 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.transforms.gradient; - -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseTransformOp; -import org.nd4j.linalg.api.ops.BaseTransformStrictOp; - -import java.util.Arrays; -import java.util.Collections; -import java.util.List; - -/** - * - * Derivative of ELU: Exponential Linear Unit (alpha=1.0)
- * Introduced in paper:<br>
- * Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs)<br>
- * Djork-Arné Clevert, Thomas Unterthiner, Sepp Hochreiter (2015)<br>
- * http://arxiv.org/abs/1511.07289 - * - * @author Alex Black - */ -public class ELUDerivative extends BaseTransformStrictOp { - public ELUDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { - super(sameDiff, i_v, inPlace); - } - - public ELUDerivative() { - - } - - public ELUDerivative(INDArray x, INDArray z) { - super(x, z); - } - - public ELUDerivative(INDArray x) { - super(x); - } - - @Override - public int opNum() { - return 3; - } - - @Override - public String opName() { - return "eluderivative"; - } - - @Override - public String onnxName() { - throw new NoOpNameFoundException("No onnx op opName found for " + opName()); - } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("No tensorflow op opName found for " + opName()); - } - - - - @Override - public List doDiff(List i_v) { - SDVariable ret = sameDiff.zerosLike(arg()); - return Collections.singletonList(ret); - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/EluBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/EluBp.java new file mode 100644 index 000000000..f4624a6ee --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/EluBp.java @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.nd4j.linalg.api.ops.impl.transforms.gradient; + +import java.util.Collections; +import java.util.List; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +/** + * ELU backpropagation op - dL/dIn from in and dL/dOut + */ +public class EluBp extends DynamicCustomOp { + + public EluBp(){ } + + public EluBp(SameDiff sd, SDVariable input, SDVariable gradient){ + super(sd, new SDVariable[]{input, gradient}); + } + + public EluBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output) { + this(input, gradient, output, 1.0); + } + + public EluBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output, double alpha){ + super(new INDArray[]{input, gradient}, wrapOrNull(output)); + addTArgument(alpha); + } + + @Override + public String opName(){ + return "elu_bp"; + } + + @Override + public List calculateOutputDataTypes(List dataTypes) { + Preconditions + .checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes); + Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes); + + return Collections.singletonList(dataTypes.get(0)); + } + + @Override + public List doDiff(List f1) { + throw new UnsupportedOperationException("Not supported"); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/HardSigmoidBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/HardSigmoidBp.java new file mode 100644 index 000000000..7bf905c5d --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/HardSigmoidBp.java @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
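The new backprop ops above all follow the same calling convention: pass the forward-pass input and the incoming gradient, get dL/dIn back. A rough sketch using the EluBp constructors just introduced (shapes and values are illustrative; the executioner call mirrors the Transforms change later in this patch):

    INDArray in     = Nd4j.rand(3, 4).subi(0.5);   // forward-pass input
    INDArray dLdOut = Nd4j.rand(3, 4);             // gradient arriving from the next layer
    // exec returns the op's output arrays; index 0 is dL/dIn
    INDArray dLdIn  = Nd4j.getExecutioner().exec(new EluBp(in, dLdOut, in.ulike()))[0];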
+ * + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.nd4j.linalg.api.ops.impl.transforms.gradient; + +import java.util.Collections; +import java.util.List; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +/** + * Hard Sigmoid backpropagation op - dL/dIn from in and dL/dOut + */ +public class HardSigmoidBp extends DynamicCustomOp { + + public HardSigmoidBp(){ } + + public HardSigmoidBp(SameDiff sd, SDVariable input, SDVariable gradient){ + super(sd, new SDVariable[]{input, gradient}); + } + + public HardSigmoidBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){ + super(new INDArray[]{input, gradient}, wrapOrNull(output)); + } + + @Override + public String opName(){ + return "hardsigmoid_bp"; + } + + @Override + public List calculateOutputDataTypes(List dataTypes) { + Preconditions + .checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes); + Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes); + + return Collections.singletonList(dataTypes.get(0)); + } + + @Override + public List doDiff(List f1) { + throw new UnsupportedOperationException("Not supported"); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/HardSigmoidDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/HardSigmoidDerivative.java index 01420d98a..c7328a92b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/HardSigmoidDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/HardSigmoidDerivative.java @@ -29,8 +29,11 @@ import java.util.List; /** * HardSigmoid derivative * + * @deprecated Use {@link HardSigmoidBp} + * * @author raver119@gmail.com */ +@Deprecated public class HardSigmoidDerivative extends BaseTransformStrictOp { public HardSigmoidDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/HardTanhBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/HardTanhBp.java new file mode 100644 index 000000000..10102eb2b --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/HardTanhBp.java @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.nd4j.linalg.api.ops.impl.transforms.gradient; + +import java.util.Collections; +import java.util.List; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +/** + * Hard Tanh backpropagation op - dL/dIn from in and dL/dOut + */ +public class HardTanhBp extends DynamicCustomOp { + + public HardTanhBp(){ } + + public HardTanhBp(SameDiff sd, SDVariable input, SDVariable gradient){ + super(sd, new SDVariable[]{input, gradient}); + } + + public HardTanhBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){ + super(new INDArray[]{input, gradient}, wrapOrNull(output)); + } + + @Override + public String opName(){ + return "hardtanh_bp"; + } + + @Override + public List calculateOutputDataTypes(List dataTypes) { + Preconditions + .checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes); + Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes); + + return Collections.singletonList(dataTypes.get(0)); + } + + @Override + public List doDiff(List f1) { + throw new UnsupportedOperationException("Not supported"); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/HardTanhDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/HardTanhDerivative.java index eb0a28e09..e5322c02f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/HardTanhDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/HardTanhDerivative.java @@ -31,8 +31,11 @@ import java.util.List; /** * Hard tanh elementwise derivative function * + * @deprecated Use {@link HardTanhBp} + * * @author Adam Gibson */ +@Deprecated public class HardTanhDerivative extends BaseTransformStrictOp { public HardTanhDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/LeakyReLUBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/LeakyReLUBp.java new file mode 100644 index 000000000..60ef40423 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/LeakyReLUBp.java @@ -0,0 +1,68 @@ +/* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.nd4j.linalg.api.ops.impl.transforms.gradient; + +import java.util.Collections; +import java.util.List; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +/** + * LReLU backpropagation op - dL/dIn from in and dL/dOut + */ +public class LeakyReLUBp extends DynamicCustomOp { + public static final double DEFAULT_ALPHA = 0.01; + private double alpha = DEFAULT_ALPHA; + + public LeakyReLUBp(){ } + + public LeakyReLUBp(SameDiff sd, SDVariable input, SDVariable gradient, double alpha){ + super(sd, new SDVariable[]{input, gradient}); + this.alpha = alpha; + addTArgument(alpha); + } + + public LeakyReLUBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output, double alpha){ + super(new INDArray[]{input, gradient}, wrapOrNull(output)); + this.alpha = alpha; + addTArgument(alpha); + } + + @Override + public String opName(){ + return "lrelu_bp"; + } + + @Override + public List calculateOutputDataTypes(List dataTypes) { + Preconditions + .checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes); + Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes); + + return Collections.singletonList(dataTypes.get(0)); + } + + @Override + public List doDiff(List f1) { + throw new UnsupportedOperationException("Not supported"); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/RationalTanhBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/RationalTanhBp.java new file mode 100644 index 000000000..f70d79ae1 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/RationalTanhBp.java @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
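LeakyReLUBp works the same way, with the negative-slope alpha passed as a t-argument (the class defaults its field to DEFAULT_ALPHA = 0.01). A short sketch, with in and dLdOut as in the EluBp example above and an illustrative alpha:

    INDArray dLdIn = Nd4j.getExecutioner()
            .exec(new LeakyReLUBp(in, dLdOut, in.ulike(), 0.1))[0];   // alpha = 0.1, illustrative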
+ * + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.nd4j.linalg.api.ops.impl.transforms.gradient; + +import java.util.Collections; +import java.util.List; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +/** + * Rational Tanh backpropagation op - dL/dIn from in and dL/dOut + */ +public class RationalTanhBp extends DynamicCustomOp { + + public RationalTanhBp(){ } + + public RationalTanhBp(SameDiff sd, SDVariable input, SDVariable gradient){ + super(sd, new SDVariable[]{input, gradient}); + } + + public RationalTanhBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){ + super(new INDArray[]{input, gradient}, wrapOrNull(output)); + } + + @Override + public String opName(){ + return "rationaltanh_bp"; + } + + @Override + public List calculateOutputDataTypes(List dataTypes) { + Preconditions + .checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes); + Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes); + + return Collections.singletonList(dataTypes.get(0)); + } + + @Override + public List doDiff(List f1) { + throw new UnsupportedOperationException("Not supported"); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/RationalTanhDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/RationalTanhDerivative.java index 4f0f5915e..18443e9bb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/RationalTanhDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/RationalTanhDerivative.java @@ -31,9 +31,12 @@ import java.util.List; * Rational Tanh Derivative, as described at https://github.com/deeplearning4j/libnd4j/issues/351 * Calculates dOut/dIn given input, not dL/dIn given dL/dOut and input * + * @deprecated Use {@link RationalTanhBp} + * * @author raver119@gmail.com * @author AlexDBlack */ +@Deprecated public class RationalTanhDerivative extends BaseTransformStrictOp { public RationalTanhDerivative(SameDiff sameDiff, SDVariable in, boolean inPlace) { super(sameDiff, in, inPlace); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/RectifiedTanhBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/RectifiedTanhBp.java new file mode 100644 index 000000000..c10d6071a --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/RectifiedTanhBp.java @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.nd4j.linalg.api.ops.impl.transforms.gradient; + +import java.util.Collections; +import java.util.List; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +/** + * Rectified Tanh backpropagation op - dL/dIn from in and dL/dOut + */ +public class RectifiedTanhBp extends DynamicCustomOp { + + public RectifiedTanhBp(){ } + + public RectifiedTanhBp(SameDiff sd, SDVariable input, SDVariable gradient){ + super(sd, new SDVariable[]{input, gradient}); + } + + public RectifiedTanhBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){ + super(new INDArray[]{input, gradient}, wrapOrNull(output)); + } + + @Override + public String opName(){ + return "rectifiedtanh_bp"; + } + + @Override + public List calculateOutputDataTypes(List dataTypes) { + Preconditions + .checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes); + Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes); + + return Collections.singletonList(dataTypes.get(0)); + } + + @Override + public List doDiff(List f1) { + throw new UnsupportedOperationException("Not supported"); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/RectifiedTanhDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/RectifiedTanhDerivative.java index 37acf94ac..8c896fb10 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/RectifiedTanhDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/RectifiedTanhDerivative.java @@ -30,9 +30,12 @@ import java.util.List; /** * Rectified Tanh Derivative * + * @deprecated Use {@link RectifiedTanhBp} + * * @author raver119@gmail.com * @author AlexDBlack */ +@Deprecated public class RectifiedTanhDerivative extends BaseTransformStrictOp { public RectifiedTanhDerivative(SameDiff sameDiff, SDVariable in, boolean inPlace) { super(sameDiff, in, inPlace); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/Relu6Derivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/Relu6Derivative.java index c915658b8..3477b4e71 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/Relu6Derivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/Relu6Derivative.java @@ -16,15 +16,18 @@ package org.nd4j.linalg.api.ops.impl.transforms.gradient; +import lombok.NonNull; import 
org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.Collections; import java.util.List; +import org.nd4j.linalg.api.ops.impl.transforms.same.Identity; /** * Derivative of Rectified linear unit 6, i.e. min(max(input, cutoff), 6), where cutoff can be chosen. @@ -33,7 +36,9 @@ import java.util.List; */ public class Relu6Derivative extends DynamicCustomOp { - private double cutoff = 0.0; + private static final double DEFAULT_CUTOFF = 0.0; + + private double cutoff = DEFAULT_CUTOFF; public Relu6Derivative(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, double cutoff) { super("relu6_bp", sameDiff, new SDVariable[]{i_v1, i_v2}); @@ -45,6 +50,16 @@ public class Relu6Derivative extends DynamicCustomOp { this.extraArgs = new Object[]{cutoff}; } + public Relu6Derivative(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){ + this(input, gradient, output, DEFAULT_CUTOFF); + } + + public Relu6Derivative(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output, double cutoff){ + super(new INDArray[]{input, gradient}, wrapOrNull(output)); + this.cutoff = cutoff; + this.extraArgs = new Object[]{cutoff}; + } + @Override public int opNum() { return 0; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SELUDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SELUDerivative.java index 58877f041..b00b29b75 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SELUDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SELUDerivative.java @@ -31,8 +31,11 @@ import java.util.List; * * https://arxiv.org/pdf/1706.02515.pdf * + * @deprecated Use {@link SeluBp} + * * @author raver119@gmail.com */ +@Deprecated public class SELUDerivative extends BaseTransformStrictOp { private static final double SELU_ALPHA = 1.6732632423543772848170429916717; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SeluBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SeluBp.java new file mode 100644 index 000000000..a13171e10 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SeluBp.java @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.nd4j.linalg.api.ops.impl.transforms.gradient; + +import java.util.Collections; +import java.util.List; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +/** + * SELU backpropagation op - dL/dIn from in and dL/dOut + */ +public class SeluBp extends DynamicCustomOp { + + public SeluBp(){ } + + public SeluBp(SameDiff sd, SDVariable input, SDVariable gradient){ + super(sd, new SDVariable[]{input, gradient}); + } + + public SeluBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){ + super(new INDArray[]{input, gradient}, wrapOrNull(output)); + } + + @Override + public String opName(){ + return "selu_bp"; + } + + @Override + public List calculateOutputDataTypes(List dataTypes) { + Preconditions + .checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes); + Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes); + + return Collections.singletonList(dataTypes.get(0)); + } + + @Override + public List doDiff(List f1) { + throw new UnsupportedOperationException("Not supported"); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftPlusBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftPlusBp.java new file mode 100644 index 000000000..be8c1b702 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftPlusBp.java @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.nd4j.linalg.api.ops.impl.transforms.gradient; + +import java.util.Collections; +import java.util.List; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +/** + * SoftPlus backpropagation op - dL/dIn from in and dL/dOut + */ +public class SoftPlusBp extends DynamicCustomOp { + + public SoftPlusBp(){ } + + public SoftPlusBp(SameDiff sd, SDVariable input, SDVariable gradient){ + super(sd, new SDVariable[]{input, gradient}); + } + + public SoftPlusBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){ + super(new INDArray[]{input, gradient}, wrapOrNull(output)); + } + + @Override + public String opName(){ + return "softplus_bp"; + } + + @Override + public List calculateOutputDataTypes(List dataTypes) { + Preconditions + .checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes); + Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes); + + return Collections.singletonList(dataTypes.get(0)); + } + + @Override + public List doDiff(List f1) { + throw new UnsupportedOperationException("Not supported"); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftSignBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftSignBp.java new file mode 100644 index 000000000..c636361e6 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftSignBp.java @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.nd4j.linalg.api.ops.impl.transforms.gradient; + +import java.util.Collections; +import java.util.List; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.imports.NoOpNameFoundException; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +/** + * SoftSign backpropagation op - dL/dIn from in and dL/dOut + */ +public class SoftSignBp extends DynamicCustomOp { + + public SoftSignBp(){ } + + public SoftSignBp(SameDiff sd, SDVariable input, SDVariable gradient){ + super(sd, new SDVariable[]{input, gradient}); + } + + public SoftSignBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){ + super(new INDArray[]{input, gradient}, wrapOrNull(output)); + } + + @Override + public String opName(){ + return "softsign_bp"; + } + + @Override + public List calculateOutputDataTypes(List dataTypes) { + Preconditions + .checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes); + Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes); + + return Collections.singletonList(dataTypes.get(0)); + } + + @Override + public List doDiff(List f1) { + throw new UnsupportedOperationException("Not supported"); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftSignDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftSignDerivative.java index 3741cfe90..4ae26e585 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftSignDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftSignDerivative.java @@ -29,7 +29,10 @@ import java.util.List; /** * SoftSign derivative. 
+ * + * @deprecated Use {@link SoftSignBp} */ +@Deprecated public class SoftSignDerivative extends BaseTransformStrictOp { public SoftSignDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftmaxBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftmaxBp.java index 1a018d2e0..dbbdb8dde 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftmaxBp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/SoftmaxBp.java @@ -16,10 +16,12 @@ package org.nd4j.linalg.api.ops.impl.transforms.gradient; +import lombok.NonNull; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import java.util.Collections; @@ -40,6 +42,12 @@ public class SoftmaxBp extends DynamicCustomOp { addIArgument(dimension); } + public SoftmaxBp(@NonNull INDArray input, @NonNull INDArray grad, INDArray output, Integer dimension){ + super(new INDArray[]{input, grad}, wrapOrNull(output)); + if(dimension != null) + addIArgument(dimension); + } + @Override public String opName() { return "softmax_bp"; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/TanhDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/TanhDerivative.java index be0f7d85f..4d8209b8a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/TanhDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/TanhDerivative.java @@ -35,15 +35,15 @@ public class TanhDerivative extends DynamicCustomOp { super(sameDiff, new SDVariable[]{i_v1, i_v2}); } - public TanhDerivative(INDArray x, INDArray z) { - super(null, x, z, null, null); + public TanhDerivative(INDArray x, INDArray y, INDArray z) { + super(null, new INDArray[]{x, y}, new INDArray[]{z}); } public TanhDerivative() { } - public TanhDerivative(INDArray x) { - this(x, null); + public TanhDerivative(INDArray x, INDArray y) { + this(x, y, null); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/ThresholdReluBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/ThresholdReluBp.java new file mode 100644 index 000000000..8d04a7118 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/ThresholdReluBp.java @@ -0,0 +1,74 @@ +/* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
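One calling-convention change worth flagging: org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative now takes the upstream gradient as its second argument (and SoftmaxBp gains an (input, grad, output, dimension) constructor in the same spirit). A hedged sketch of the new three-argument form, with illustrative arrays:

    INDArray x      = Nd4j.rand(3, 4);             // forward-pass input
    INDArray dLdOut = Nd4j.rand(3, 4);             // incoming gradient
    INDArray dLdIn  = Nd4j.getExecutioner().exec(new TanhDerivative(x, dLdOut, x.ulike()))[0];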
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.nd4j.linalg.api.ops.impl.transforms.gradient; + +import java.util.Collections; +import java.util.List; +import lombok.Getter; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear; +import org.nd4j.linalg.api.ops.impl.transforms.custom.ThresholdRelu; + +/** + * Threshold ReLU Backprop op - dL/dIn from in and dL/dOut + * + * For {@link RectifiedLinear} as well as {@link ThresholdRelu}. + */ +public class ThresholdReluBp extends DynamicCustomOp { + + @Getter + private double cutoff = 0; + + public ThresholdReluBp(){ } + + public ThresholdReluBp(SameDiff sd, SDVariable input, SDVariable gradient, double cutoff){ + super(sd, new SDVariable[]{input, gradient}); + this.cutoff = cutoff; + addTArgument(cutoff); + } + + public ThresholdReluBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output, double cutoff){ + super(new INDArray[]{input, gradient}, wrapOrNull(output)); + this.cutoff = cutoff; + addTArgument(cutoff); + } + + @Override + public String opName(){ + return "thresholdedrelu_bp"; + } + + @Override + public List calculateOutputDataTypes(List dataTypes) { + Preconditions + .checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes); + Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes); + + return Collections.singletonList(dataTypes.get(0)); + } + + @Override + public List doDiff(List f1) { + throw new UnsupportedOperationException("Not supported"); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Cube.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Cube.java index 9e1e05693..d58ad8f3f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Cube.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Cube.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.same; +import java.util.Collections; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; @@ -70,7 +71,6 @@ public class Cube extends BaseTransformSameOp { @Override public List doDiff(List f1) { - SDVariable g = f().mul(f().cubeDerivative(arg()),f1.get(0)); - return Arrays.asList(g); + return Collections.singletonList(f().cubeBp(arg(), f1.get(0))); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java index 8266add39..a144e868b 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java @@ -18,14 +18,18 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; import org.nd4j.imports.NoOpNameFoundException; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseTransformFloatOp; -import org.nd4j.linalg.api.ops.BaseTransformOp; -import org.nd4j.linalg.api.ops.BaseTransformStrictOp; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.GraphDef; +import org.tensorflow.framework.NodeDef; -import java.util.Arrays; +import java.util.Collections; import java.util.List; +import java.util.Map; /** * ELU: Exponential Linear Unit (alpha=1.0)
@@ -36,25 +40,20 @@ import java.util.List; * * @author Alex Black */ -public class ELU extends BaseTransformStrictOp { - public ELU(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { - super(sameDiff, i_v, inPlace); +public class ELU extends DynamicCustomOp { + public ELU(SameDiff sameDiff, SDVariable i_v) { + super(sameDiff, new SDVariable[]{i_v}); } public ELU() { } public ELU(INDArray x, INDArray z) { - super(x, z); + super(null, wrapOrNull(x), wrapOrNull(z)); } public ELU(INDArray x) { - super(x); - } - - @Override - public int opNum() { - return 35; + this(x, null); } @Override @@ -76,8 +75,14 @@ public class ELU extends BaseTransformStrictOp { public List doDiff(List i_v) { //ELU: e^x-1 if x<0, x otherwise //dL/dIn = dL/Out * dOut/dIn - SDVariable ret = f().eluDerivative(arg()).mul(i_v.get(0)); - return Arrays.asList(ret); + return Collections.singletonList(f().eluBp(arg(), i_v.get(0))); } + @Override + public List calculateOutputDataTypes(List dataTypes) { + Preconditions.checkState(dataTypes != null && dataTypes.size() == 1, "Expected exactly 1 datatype for ELU, got %s", dataTypes); + Preconditions.checkState(dataTypes.get(0).isFPType(), "Expected floating point input type for ELU, got %s", dataTypes); + + return dataTypes; + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardSigmoid.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardSigmoid.java index a1703d221..ddca48d4c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardSigmoid.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardSigmoid.java @@ -69,9 +69,7 @@ public class HardSigmoid extends BaseTransformStrictOp { @Override public List doDiff(List f1) { - SDVariable in = arg(); - SDVariable dOutdIn = new HardSigmoidDerivative(sameDiff, in, false).outputVariables()[0]; - return Collections.singletonList(dOutdIn.mul(f1.get(0))); + return Collections.singletonList(f().hardSigmoidBp(arg(), f1.get(0))); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardTanh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardTanh.java index a2452443b..4237e72de 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardTanh.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/HardTanh.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; +import java.util.Collections; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; @@ -70,7 +71,6 @@ public class HardTanh extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().hardTanhDerivative(arg()).mul(i_v.get(0)); - return Arrays.asList(ret); + return Collections.singletonList(f().hardTanhBp(arg(), i_v.get(0))); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/RationalTanh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/RationalTanh.java index 2de0e90a7..a05e34637 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/RationalTanh.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/RationalTanh.java @@ -16,13 +16,10 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; -import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseTransformFloatOp; -import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import java.util.Collections; @@ -71,6 +68,6 @@ public class RationalTanh extends BaseTransformStrictOp { @Override public List doDiff(List f1) { - return Collections.singletonList(f().tanhRationalDerivative(arg()).mul(f1.get(0))); + return Collections.singletonList(f().tanhRationalBp(arg(), f1.get(0))); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/RectifiedTanh.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/RectifiedTanh.java index da439cec7..d5fbf1294 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/RectifiedTanh.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/RectifiedTanh.java @@ -17,13 +17,10 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; import onnx.Onnx; -import org.nd4j.autodiff.functions.DifferentialFunction; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseTransformFloatOp; -import org.nd4j.linalg.api.ops.BaseTransformOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -88,6 +85,6 @@ public class RectifiedTanh extends BaseTransformStrictOp { @Override public List doDiff(List f1) { - return Collections.singletonList(f().tanhRectifiedDerivative(arg()).mul(f1.get(0))); + return Collections.singletonList(f().tanhRectifiedBp(arg(), f1.get(0))); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SELU.java index 159b0b170..f72676f86 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SELU.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SELU.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; +import java.util.Collections; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; @@ -76,8 +77,7 @@ public class SELU extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().seluDerivative(arg()).mul(i_v.get(0)); - return Arrays.asList(ret); + return Collections.singletonList(f().seluBp(arg(), i_v.get(0))); } } diff --git 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SigmoidDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SigmoidDerivative.java index 7213b50b0..08a97bae7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SigmoidDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SigmoidDerivative.java @@ -28,8 +28,11 @@ import java.util.List; /** * Sigmoid derivative * + * @deprecated Use {@link org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative} + * * @author Adam Gibson */ +@Deprecated public class SigmoidDerivative extends BaseTransformStrictOp { public SigmoidDerivative(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) { super(sameDiff, i_v1, i_v2); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftSign.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftSign.java index 0b2782860..c7c90b201 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftSign.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/SoftSign.java @@ -16,6 +16,7 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict; +import java.util.Collections; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.linalg.api.ndarray.INDArray; @@ -73,8 +74,7 @@ public class SoftSign extends BaseTransformStrictOp { @Override public List doDiff(List i_v) { - SDVariable ret = f().softsignDerivative(arg()).mul(i_v.get(0)); - return Arrays.asList(ret); + return Collections.singletonList(f().softsignBp(arg(), i_v.get(0))); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/TanhDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/TanhDerivative.java index fc9a9581f..fad63e73b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/TanhDerivative.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/TanhDerivative.java @@ -27,7 +27,10 @@ import java.util.List; /** * Tanh derivative + * + * @deprecated Use {@link org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative}. 
*/ +@Deprecated public class TanhDerivative extends BaseTransformStrictOp { public TanhDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { super(sameDiff, i_v, inPlace); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java index 1f4004cb2..ad887789d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java @@ -37,7 +37,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse; import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; import org.nd4j.linalg.api.ops.impl.transforms.floating.*; import org.nd4j.linalg.api.ops.impl.transforms.comparison.*; -import org.nd4j.linalg.api.ops.impl.transforms.gradient.ELUDerivative; +import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp; import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative; import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative; import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative; @@ -438,16 +438,16 @@ public class Transforms { public static INDArray elu(INDArray in, boolean copy) { - return Nd4j.getExecutioner().exec(new ELU(in, (copy ? in.ulike() : in))); + return Nd4j.getExecutioner().exec(new ELU(in, (copy ? in.ulike() : in)))[0]; } - public static INDArray eluDerivative(INDArray arr) { - return eluDerivative(arr, true); + public static INDArray eluDerivative(INDArray arr, INDArray grad) { + return eluDerivative(arr, grad,true); } - public static INDArray eluDerivative(INDArray in, boolean copy) { - return Nd4j.getExecutioner().exec(new ELUDerivative(in, (copy ? in.ulike() : in))); + public static INDArray eluDerivative(INDArray in, INDArray grad, boolean copy) { + return Nd4j.getExecutioner().exec(new EluBp(in, grad, (copy ? 
in.ulike() : in)))[0]; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/LecunUniformInitScheme.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/LecunUniformInitScheme.java index 67ff1a114..73b471535 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/LecunUniformInitScheme.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/LecunUniformInitScheme.java @@ -42,7 +42,7 @@ public class LecunUniformInitScheme extends BaseWeightInitScheme { @Override public INDArray doCreate(DataType dataType, long[] shape, INDArray paramsView) { double b = 3.0 / Math.sqrt(fanIn); - return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-b, b)); + return Nd4j.rand(Nd4j.getDistributions().createUniform(-b, b), shape); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/ReluUniformInitScheme.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/ReluUniformInitScheme.java index 9561953e0..07eeadeb7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/ReluUniformInitScheme.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/ReluUniformInitScheme.java @@ -43,7 +43,7 @@ public class ReluUniformInitScheme extends BaseWeightInitScheme { @Override public INDArray doCreate(DataType dataType, long[] shape, INDArray paramsView) { double u = Math.sqrt(6.0 / fanIn); - return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-u, u)); //U(-sqrt(6/fanIn), sqrt(6/fanIn) + return Nd4j.rand(Nd4j.getDistributions().createUniform(-u, u), shape); //U(-sqrt(6/fanIn), sqrt(6/fanIn) } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/SigmoidUniformInitScheme.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/SigmoidUniformInitScheme.java index 58809a095..4c7420b38 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/SigmoidUniformInitScheme.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/SigmoidUniformInitScheme.java @@ -46,7 +46,7 @@ public class SigmoidUniformInitScheme extends BaseWeightInitScheme { @Override public INDArray doCreate(DataType dataType, long[] shape, INDArray paramsView) { double r = 4.0 * Math.sqrt(6.0 / (fanIn + fanOut)); - return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-r, r)); + return Nd4j.rand(Nd4j.getDistributions().createUniform(-r, r), shape); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/UniformInitScheme.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/UniformInitScheme.java index c8744d69e..ca44b3a84 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/UniformInitScheme.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/UniformInitScheme.java @@ -43,7 +43,7 @@ public class UniformInitScheme extends BaseWeightInitScheme { @Override public INDArray doCreate(DataType dataType, long[] shape, INDArray paramsView) { double a = 1.0 / Math.sqrt(fanIn); - return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-a, a)); + return Nd4j.rand(Nd4j.getDistributions().createUniform(-a, a), shape); } diff 
--git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/VarScalingNormalUniformFanInInitScheme.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/VarScalingNormalUniformFanInInitScheme.java index 1ed0e4efb..3ad193140 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/VarScalingNormalUniformFanInInitScheme.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/VarScalingNormalUniformFanInInitScheme.java @@ -43,7 +43,7 @@ public class VarScalingNormalUniformFanInInitScheme extends BaseWeightInitScheme @Override public INDArray doCreate(DataType dataType, long[] shape, INDArray paramsView) { double scalingFanIn = 3.0 / Math.sqrt(fanIn); - return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-scalingFanIn, scalingFanIn)); + return Nd4j.rand(Nd4j.getDistributions().createUniform(-scalingFanIn, scalingFanIn), shape); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/VarScalingNormalUniformFanOutInitScheme.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/VarScalingNormalUniformFanOutInitScheme.java index 2405dea88..bafa2170d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/VarScalingNormalUniformFanOutInitScheme.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/VarScalingNormalUniformFanOutInitScheme.java @@ -42,7 +42,7 @@ public class VarScalingNormalUniformFanOutInitScheme extends BaseWeightInitSchem @Override public INDArray doCreate(DataType dataType, long[] shape, INDArray paramsView) { double scalingFanOut = 3.0 / Math.sqrt(fanOut); - return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-scalingFanOut, scalingFanOut)); + return Nd4j.rand(Nd4j.getDistributions().createUniform(-scalingFanOut, scalingFanOut), shape); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/VarScalingUniformFanAvgInitScheme.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/VarScalingUniformFanAvgInitScheme.java index 8f3dc2a7d..2e3d85093 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/VarScalingUniformFanAvgInitScheme.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/VarScalingUniformFanAvgInitScheme.java @@ -46,7 +46,7 @@ public class VarScalingUniformFanAvgInitScheme extends BaseWeightInitScheme { @Override public INDArray doCreate(DataType dataType, long[] shape, INDArray paramsView) { double scalingFanAvg = 3.0 / Math.sqrt((fanIn + fanOut) / 2); - return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-scalingFanAvg, scalingFanAvg)); + return Nd4j.rand(Nd4j.getDistributions().createUniform(-scalingFanAvg, scalingFanAvg), shape); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/XavierUniformInitScheme.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/XavierUniformInitScheme.java index ddc0d7428..d1f156b01 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/XavierUniformInitScheme.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/weightinit/impl/XavierUniformInitScheme.java @@ -46,7 +46,7 @@ public class 
@@ -46,7 +46,7 @@ public class XavierUniformInitScheme extends BaseWeightInitScheme {
         //As per Glorot and Bengio 2010: Uniform distribution U(-s,s) with s = sqrt(6/(fanIn + fanOut))
         //Eq 16: http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf
         double s = Math.sqrt(6.0) / Math.sqrt(fanIn + fanOut);
-        return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-s, s));
+        return Nd4j.rand(Nd4j.getDistributions().createUniform(-s, s), shape);
     }
diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JcusparseNDArrayCSR.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JcusparseNDArrayCSR.java
index c9266e9e7..2698c299f 100644
--- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JcusparseNDArrayCSR.java
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/JcusparseNDArrayCSR.java
@@ -68,14 +68,6 @@ public class JcusparseNDArrayCSR extends BaseSparseNDArrayCSR {
         return null;
     }
 
-    /**
-     * Perform an copy matrix multiplication
-     *
-     * @param other the other matrix to perform matrix multiply with
-     * @param result the result ndarray
-     * @param mMulTranspose the transpose status of each array
-     * @return the result of the matrix multiplication
-     */
     @Override
     public INDArray mmul(INDArray other, INDArray result, MMulTranspose mMulTranspose) {
         return null;
diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/SparseNDArrayCSR.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/SparseNDArrayCSR.java
index 0135ef2c5..b35662b98 100644
--- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/SparseNDArrayCSR.java
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/SparseNDArrayCSR.java
@@ -79,14 +79,6 @@ public class SparseNDArrayCSR extends BaseSparseNDArrayCSR {
         return null;
     }
 
-    /**
-     * Perform an copy matrix multiplication
-     *
-     * @param other the other matrix to perform matrix multiply with
-     * @param result the result ndarray
-     * @param mMulTranspose the transpose status of each array
-     * @return the result of the matrix multiplication
-     */
     @Override
     public INDArray mmul(INDArray other, INDArray result, MMulTranspose mMulTranspose) {
         return null;
diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java
index eeb4d38c3..6983e20f0 100644
--- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java
@@ -12859,7 +12859,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
 /**
  * This is Concatenated RELU implementation.
  * What happens inside: RELU(Concat((x, -x, {-1})))
- * 
+ *
  * PLEASE NOTE: Concatenation will double amount of features available in input
  */
// #if NOT_EXCLUDED(OP_crelu)
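The comment touched up above documents crelu as RELU applied to a concatenation of x and -x along the last axis, which is why the output has twice as many features as the input. A rough illustration of that behaviour with plain ND4J ops (this only emulates the description; it is not the native crelu op):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.ops.transforms.Transforms;

    public class CReluSketch {
        public static void main(String[] args) {
            INDArray x = Nd4j.createFromArray(-2.0, -0.5, 0.5, 2.0).reshape(1, 4);
            // Concat(x, -x) along the feature axis, then RELU: 1 x 4 input -> 1 x 8 output
            INDArray crelu = Transforms.relu(Nd4j.concat(1, x, x.neg()));
            System.out.println(crelu);   // [[0, 0, 0.5, 2.0,   2.0, 0.5, 0, 0]]
        }
    }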
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java
index a2dd3ff5d..1da31d863 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java
@@ -52,7 +52,8 @@ public class TFGraphTestList {
     public TemporaryFolder testDir = new TemporaryFolder();
 
     public static String[] modelNames = new String[]{
-            "cnn2d_nn/nhwc_b1_k12_s12_d12_SAME"
+//            "cnn2d_nn/nhwc_b1_k12_s12_d12_SAME"
+            "cnn2d_layers/channels_last_b1_k2_s1_d1_SAME_elu"
     };
 
     @After
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java
index f325348fb..ded23f810 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java
@@ -788,4 +788,23 @@ public class CustomOpsTests extends BaseNd4jTest {
         Nd4j.exec(op);
         Nd4j.getExecutioner().commit();
     }
+
+    @Test
+    public void test() throws Exception {
+
+        INDArray in1 = Nd4j.create(DataType.BFLOAT16, 2, 3, 10, 1);//Nd4j.createFromArray(0.2019043,0.6464844,0.9116211,0.60058594,0.34033203,0.7036133,0.6772461,0.3815918,0.87353516,0.04650879,0.67822266,0.8618164,0.88378906,0.7573242,0.66796875,0.63427734,0.33764648,0.46923828,0.62939453,0.76464844,-0.8618164,-0.94873047,-0.9902344,-0.88916016,-0.86572266,-0.92089844,-0.90722656,-0.96533203,-0.97509766,-0.4975586,-0.84814453,-0.984375,-0.98828125,-0.95458984,-0.9472656,-0.91064453,-0.80859375,-0.83496094,-0.9140625,-0.82470703,0.4802246,0.45361328,0.28125,0.28320312,0.79345703,0.44604492,-0.30273438,0.11730957,0.56396484,0.73583984,0.1418457,-0.44848633,0.6923828,-0.40234375,0.40185547,0.48632812,0.14538574,0.4638672,0.13000488,0.5058594)
+                //.castTo(DataType.BFLOAT16).reshape(2,3,10,1);
+        INDArray in2 = Nd4j.create(DataType.BFLOAT16, 2, 3, 10, 1); //Nd4j.createFromArray(0.0,-0.13391113,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,-0.1751709,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.51904297,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.5107422,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0)
+                //.castTo(DataType.BFLOAT16).reshape(2,3,10,1);
+
+        INDArray out = in1.ulike();
+
+        Nd4j.exec(DynamicCustomOp.builder("maxpool2d_bp")
+                .addInputs(in1, in2)
+                .addOutputs(out)
+                .addIntegerArguments(5,1,1,2,2,0,1,1,1,0,0)
+                .build());
+
+        Nd4j.getExecutioner().commit();
+    }
 }
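In the new test above, the output buffer handed to maxpool2d_bp is allocated with ulike() rather than with an explicit shape. A short sketch of that pattern, relying only on the Nd4j.create(DataType, long...) and INDArray.ulike() calls already used in the hunk; ulike() is taken here to produce an array with the same shape and data type as its source, with no assumption about its initial contents:

    import java.util.Arrays;

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class UlikeSketch {
        public static void main(String[] args) {
            INDArray in = Nd4j.create(DataType.BFLOAT16, 2, 3, 10, 1);
            INDArray out = in.ulike();   // same shape and dtype as 'in', suitable as an op output buffer
            System.out.println(out.dataType() + " " + Arrays.toString(out.shape()));   // BFLOAT16 [2, 3, 10, 1]
        }
    }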
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/DerivativeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/DerivativeTests.java
index ff9582378..5a51b847d 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/DerivativeTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/DerivativeTests.java
@@ -305,44 +305,6 @@ public class DerivativeTests extends BaseNd4jTest {
         }
     }
 
-    @Test
-    public void testELUDerivative() {
-
-        //f(x) = x if x>=0
-        //f(x) = 1.0*(exp(x)-1)
-        INDArray z = Nd4j.zeros(100);
-        double[] out = new double[100];
-        double[] outDeriv = new double[100];
-        for (int i = 0; i < 100; i++) {
-            double x = 0.1 * (i - 50);
-            z.putScalar(i, x);
-            if (x >= 0) {
-                out[i] = x;
-                outDeriv[i] = 1.0;
-            } else {
-                out[i] = FastMath.exp(x) - 1.0;
-                outDeriv[i] = FastMath.exp(x);
-            }
-        }
-
-        INDArray act = Transforms.elu(z, true);
-        INDArray actDeriv = Nd4j.getExecutioner().exec(new ELUDerivative(z.dup()));
-
-        System.out.println(act);
-
-        for (int i = 0; i < 100; i++) {
-            double relError1 = Math.abs(out[i] - act.getDouble(i)) / (Math.abs(out[i]) + Math.abs(act.getDouble(i)));
-            if (out[i] == 0.0 && act.getDouble(i) == 0.0)
-                relError1 = 0.0;
-            double relError2 = Math.abs(outDeriv[i] - actDeriv.getDouble(i))
-                            / (Math.abs(outDeriv[i]) + Math.abs(actDeriv.getDouble(i)));
-            if (outDeriv[i] == 0.0 && actDeriv.getDouble(i) == 0.0)
-                relError2 = 0.0;
-            assertTrue(relError1 < REL_ERROR_TOLERANCE);
-            assertTrue(relError2 < REL_ERROR_TOLERANCE);
-        }
-    }
-
     @Override
     public char ordering() {
         return 'f';