Merge remote-tracking branch 'fork/master' (commit 7ded4416cb)

The merged changes touch the DL4J ValidateCuDNN test, the Spark word2vec text pipeline and a POM that excludes the clashing net.jpountz.lz4 dependency, and a number of libnd4j CUDA files: the elu and lrelu ops gain a configurable alpha, several kernels get corrected thread indexing and grid-stride loops, and NDArray prepare/registerSpecialUse bookkeeping is added around device launches.
|
@ -248,7 +248,7 @@ public class ValidateCuDNN extends BaseDL4JTest {
|
||||||
Nd4j.getRandom().setSeed(12345);
|
Nd4j.getRandom().setSeed(12345);
|
||||||
INDArray features = Nd4j.rand(fShape);
|
INDArray features = Nd4j.rand(fShape);
|
||||||
INDArray labels = Nd4j.rand(lShape);
|
INDArray labels = Nd4j.rand(lShape);
|
||||||
labels = Nd4j.exec(new IsMax(labels, 1));
|
labels = Nd4j.exec(new IsMax(labels, 1))[0].castTo(features.dataType());
|
||||||
|
|
||||||
List<CuDNNValidationUtil.TestCase> testCaseList = new ArrayList<>();
|
List<CuDNNValidationUtil.TestCase> testCaseList = new ArrayList<>();
|
||||||
|
|
||||||
|
@ -256,7 +256,7 @@ public class ValidateCuDNN extends BaseDL4JTest {
|
||||||
for (int i = 0; i < 6; i++) {
|
for (int i = 0; i < 6; i++) {
|
||||||
INDArray f = Nd4j.rand(fShape);
|
INDArray f = Nd4j.rand(fShape);
|
||||||
INDArray l = Nd4j.rand(lShape);
|
INDArray l = Nd4j.rand(lShape);
|
||||||
Nd4j.exec(new IsMax(l, 1))[0];
|
l = Nd4j.exec(new IsMax(l, 1))[0].castTo(features.dataType());
|
||||||
dataSets.add(new DataSet(f, l));
|
dataSets.add(new DataSet(f, l));
|
||||||
}
|
}
|
||||||
DataSetIterator iter = new ExistingDataSetIterator(dataSets);
|
DataSetIterator iter = new ExistingDataSetIterator(dataSets);
|
||||||
|
|
|
@ -25,7 +25,6 @@ import org.deeplearning4j.models.word2vec.Huffman;
|
||||||
import org.deeplearning4j.models.word2vec.VocabWord;
|
import org.deeplearning4j.models.word2vec.VocabWord;
|
||||||
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
|
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
|
||||||
import org.deeplearning4j.spark.models.embeddings.word2vec.FirstIterationFunction;
|
import org.deeplearning4j.spark.models.embeddings.word2vec.FirstIterationFunction;
|
||||||
import org.deeplearning4j.spark.models.embeddings.word2vec.FirstIterationFunctionAdapter;
|
|
||||||
import org.deeplearning4j.spark.models.embeddings.word2vec.MapToPairFunction;
|
import org.deeplearning4j.spark.models.embeddings.word2vec.MapToPairFunction;
|
||||||
import org.deeplearning4j.spark.models.embeddings.word2vec.Word2Vec;
|
import org.deeplearning4j.spark.models.embeddings.word2vec.Word2Vec;
|
||||||
import org.deeplearning4j.spark.text.functions.CountCumSum;
|
import org.deeplearning4j.spark.text.functions.CountCumSum;
|
||||||
|
@ -470,11 +469,11 @@ public class TextPipelineTest extends BaseSparkTest {
|
||||||
|
|
||||||
Iterator<Tuple2<List<VocabWord>, Long>> iterator = vocabWordListSentenceCumSumRDD.collect().iterator();
|
Iterator<Tuple2<List<VocabWord>, Long>> iterator = vocabWordListSentenceCumSumRDD.collect().iterator();
|
||||||
|
|
||||||
FirstIterationFunctionAdapter firstIterationFunction = new FirstIterationFunctionAdapter(
|
FirstIterationFunction firstIterationFunction = new FirstIterationFunction(
|
||||||
word2vecVarMapBroadcast, expTableBroadcast, pipeline.getBroadCastVocabCache());
|
word2vecVarMapBroadcast, expTableBroadcast, pipeline.getBroadCastVocabCache());
|
||||||
|
|
||||||
Iterable<Map.Entry<VocabWord, INDArray>> ret = firstIterationFunction.call(iterator);
|
Iterator<Map.Entry<VocabWord, INDArray>> ret = firstIterationFunction.call(iterator);
|
||||||
assertTrue(ret.iterator().hasNext());
|
assertTrue(ret.hasNext());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
|
@ -70,6 +70,13 @@
|
||||||
<artifactId>deeplearning4j-play_2.11</artifactId>
|
<artifactId>deeplearning4j-play_2.11</artifactId>
|
||||||
<version>${deeplearning4j.version}</version>
|
<version>${deeplearning4j.version}</version>
|
||||||
<scope>test</scope>
|
<scope>test</scope>
|
||||||
|
<exclusions>
|
||||||
|
<!-- To avoid clashing net.jpountz.lz4:lz4:1.3.0 and org.lz4:lz4-java:jar:1.4.0 - -->
|
||||||
|
<exclusion>
|
||||||
|
<groupId>net.jpountz.lz4</groupId>
|
||||||
|
<artifactId>lz4</artifactId>
|
||||||
|
</exclusion>
|
||||||
|
</exclusions>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
|
||||||
<dependency>
|
<dependency>
|
||||||
|
|
|
@ -273,9 +273,11 @@ namespace nd4j {
|
||||||
* @param writeList
|
* @param writeList
|
||||||
* @param readList
|
* @param readList
|
||||||
*/
|
*/
|
||||||
|
// TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list
|
||||||
static void registerSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList);
|
static void registerSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList);
|
||||||
static void prepareSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables = false);
|
static void prepareSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables = false);
|
||||||
|
|
||||||
|
// TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list
|
||||||
static void registerPrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList);
|
static void registerPrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList);
|
||||||
static void preparePrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables = false);
|
static void preparePrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables = false);
|
||||||
|
|
||||||
|
|
|
@ -931,13 +931,13 @@ void initializeFunctions(Nd4jPointer *functions) {
|
||||||
Nd4jPointer mallocHost(Nd4jLong memorySize, int flags) {
|
Nd4jPointer mallocHost(Nd4jLong memorySize, int flags) {
|
||||||
Nd4jPointer pointer;
|
Nd4jPointer pointer;
|
||||||
// cudaHostAllocMapped |cudaHostAllocPortable
|
// cudaHostAllocMapped |cudaHostAllocPortable
|
||||||
auto res = cudaHostAlloc(reinterpret_cast<void **>(&pointer), memorySize, cudaHostAllocDefault);
|
auto res = cudaHostAlloc(reinterpret_cast<void **>(&pointer), memorySize + 8, cudaHostAllocDefault);
|
||||||
if (res != 0) {
|
if (res != 0) {
|
||||||
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(res);
|
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(res);
|
||||||
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaHostAlloc failed");
|
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaHostAlloc failed");
|
||||||
}
|
}
|
||||||
|
|
||||||
return pointer;
|
return reinterpret_cast<int8_t*>(pointer);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -950,13 +950,13 @@ Nd4jPointer mallocHost(Nd4jLong memorySize, int flags) {
|
||||||
*/
|
*/
|
||||||
Nd4jPointer mallocDevice(Nd4jLong memorySize, int deviceId, int flags) {
|
Nd4jPointer mallocDevice(Nd4jLong memorySize, int deviceId, int flags) {
|
||||||
Nd4jPointer pointer;
|
Nd4jPointer pointer;
|
||||||
auto res = cudaMalloc(reinterpret_cast<void **>(&pointer), memorySize);
|
auto res = cudaMalloc(reinterpret_cast<void **>(&pointer), memorySize + 8);
|
||||||
if (res != 0) {
|
if (res != 0) {
|
||||||
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(res);
|
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(res);
|
||||||
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaMalloc failed");
|
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaMalloc failed");
|
||||||
}
|
}
|
||||||
|
|
||||||
return pointer;
|
return reinterpret_cast<int8_t*>(pointer);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
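Both allocators above now request memorySize + 8 bytes and return the pointer through a reinterpret_cast<int8_t*>. The diff does not say why the extra 8 bytes are added; alignment or guard headroom is an assumption. A minimal host-side sketch of the same padding pattern, with hypothetical names rather than the library's API:

#include <cuda_runtime.h>
#include <cstddef>

// Hypothetical helper mirroring the memorySize + 8 change: over-allocate by a
// small fixed pad and hand back nullptr on failure so the caller can report
// the error code, much as the diff does via errorReference().
static void* paddedDeviceAlloc(std::size_t bytes, cudaError_t& err) {
    constexpr std::size_t kPad = 8;       // matches the +8 in the hunk above
    void* ptr = nullptr;
    err = cudaMalloc(&ptr, bytes + kPad);
    return err == cudaSuccess ? ptr : nullptr;
}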
@ -96,13 +96,13 @@ namespace functions {
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
if(vx == vz) {
|
if(vx == vz) {
|
||||||
for (Nd4jLong i = tid; i < length; i+= gridDim.x * blockDim.x) {
|
for (Nd4jLong i = tid; i < length; i+= totalThreads) {
|
||||||
auto xOffset = shape::getIndexOffset(i, xShapeInfo, length);
|
auto xOffset = shape::getIndexOffset(i, xShapeInfo, length);
|
||||||
z[xOffset] = OpType::op(x[xOffset], params);
|
z[xOffset] = OpType::op(x[xOffset], params);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
for (Nd4jLong i = tid; i < length; i+= gridDim.x * blockDim.x) {
|
for (Nd4jLong i = tid; i < length; i+= totalThreads) {
|
||||||
auto xOffset = shape::getIndexOffset(i, xShapeInfo, length);
|
auto xOffset = shape::getIndexOffset(i, xShapeInfo, length);
|
||||||
auto zOffset = shape::getIndexOffset(i, zShapeInfo, length);
|
auto zOffset = shape::getIndexOffset(i, zShapeInfo, length);
|
||||||
z[zOffset] = OpType::op(x[xOffset], params);
|
z[zOffset] = OpType::op(x[xOffset], params);
|
||||||
|
|
|
@ -94,13 +94,13 @@ namespace functions {
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
if(vx == vz) {
|
if(vx == vz) {
|
||||||
for (Nd4jLong i = tid; i < length; i+= gridDim.x * blockDim.x) {
|
for (Nd4jLong i = tid; i < length; i+= totalThreads) {
|
||||||
auto xOffset = shape::getIndexOffset(i, xShapeInfo, length);
|
auto xOffset = shape::getIndexOffset(i, xShapeInfo, length);
|
||||||
z[xOffset] = OpType::op(x[xOffset], params);
|
z[xOffset] = OpType::op(x[xOffset], params);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
for (Nd4jLong i = tid; i < length; i+= gridDim.x * blockDim.x) {
|
for (Nd4jLong i = tid; i < length; i+= totalThreads) {
|
||||||
auto xOffset = shape::getIndexOffset(i, xShapeInfo, length);
|
auto xOffset = shape::getIndexOffset(i, xShapeInfo, length);
|
||||||
auto zOffset = shape::getIndexOffset(i, zShapeInfo, length);
|
auto zOffset = shape::getIndexOffset(i, zShapeInfo, length);
|
||||||
z[zOffset] = OpType::op(x[xOffset], params);
|
z[zOffset] = OpType::op(x[xOffset], params);
|
||||||
|
|
|
@ -96,13 +96,13 @@ namespace functions {
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
if(vx == vz) {
|
if(vx == vz) {
|
||||||
for (Nd4jLong i = tid; i < length; i+= gridDim.x * blockDim.x) {
|
for (Nd4jLong i = tid; i < length; i+= totalThreads) {
|
||||||
auto xOffset = shape::getIndexOffset(i, xShapeInfo, length);
|
auto xOffset = shape::getIndexOffset(i, xShapeInfo, length);
|
||||||
z[xOffset] = OpType::op(x[xOffset], params);
|
z[xOffset] = OpType::op(x[xOffset], params);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
for (Nd4jLong i = tid; i < length; i+= gridDim.x * blockDim.x) {
|
for (Nd4jLong i = tid; i < length; i+= totalThreads) {
|
||||||
auto xOffset = shape::getIndexOffset(i, xShapeInfo, length);
|
auto xOffset = shape::getIndexOffset(i, xShapeInfo, length);
|
||||||
auto zOffset = shape::getIndexOffset(i, zShapeInfo, length);
|
auto zOffset = shape::getIndexOffset(i, zShapeInfo, length);
|
||||||
z[zOffset] = OpType::op(x[xOffset], params);
|
z[zOffset] = OpType::op(x[xOffset], params);
|
||||||
|
|
|
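The three hunks above make the same change in what look like sibling transform kernels: the per-iteration stride gridDim.x * blockDim.x is replaced by a precomputed totalThreads value. Assuming totalThreads is exactly gridDim.x * blockDim.x computed once (the usual grid-stride-loop idiom), a standalone sketch of the pattern:

// Minimal grid-stride loop; 'totalThreads' is assumed to equal
// gridDim.x * blockDim.x, computed once instead of on every iteration.
__global__ void scaleKernel(const float* x, float* z, long long length, float factor) {
    const auto tid = blockIdx.x * blockDim.x + threadIdx.x;   // global thread id
    const auto totalThreads = gridDim.x * blockDim.x;         // total threads in the grid
    for (long long i = tid; i < length; i += totalThreads)
        z[i] = x[i] * factor;
}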
@ -116,7 +116,6 @@
|
||||||
|
|
||||||
|
|
||||||
#define TRANSFORM_STRICT_OPS \
|
#define TRANSFORM_STRICT_OPS \
|
||||||
(3, ELUDerivative), \
|
|
||||||
(4, TanhDerivative), \
|
(4, TanhDerivative), \
|
||||||
(5, HardTanhDerivative), \
|
(5, HardTanhDerivative), \
|
||||||
(6, SigmoidDerivative), \
|
(6, SigmoidDerivative), \
|
||||||
|
@ -148,7 +147,6 @@
|
||||||
(32, ATan), \
|
(32, ATan), \
|
||||||
(33, HardTanh), \
|
(33, HardTanh), \
|
||||||
(34, SoftSign), \
|
(34, SoftSign), \
|
||||||
(35, ELU), \
|
|
||||||
(36, HardSigmoid), \
|
(36, HardSigmoid), \
|
||||||
(37, RationalTanh) ,\
|
(37, RationalTanh) ,\
|
||||||
(38, RectifiedTanh) ,\
|
(38, RectifiedTanh) ,\
|
||||||
|
@ -211,6 +209,8 @@
|
||||||
(4, ReverseDivide),\
|
(4, ReverseDivide),\
|
||||||
(5, ReverseSubtract),\
|
(5, ReverseSubtract),\
|
||||||
(6, MaxPairwise),\
|
(6, MaxPairwise),\
|
||||||
|
(7, ELU), \
|
||||||
|
(8, ELUDerivative), \
|
||||||
(13, MinPairwise),\
|
(13, MinPairwise),\
|
||||||
(14, CopyPws),\
|
(14, CopyPws),\
|
||||||
(15, Mod),\
|
(15, Mod),\
|
||||||
|
|
|
@ -25,12 +25,14 @@
|
||||||
#include <ops/declarable/helpers/legacy_helpers.h>
|
#include <ops/declarable/helpers/legacy_helpers.h>
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
CONFIGURABLE_OP_IMPL(elu, 1, 1, true, 0, 0) {
|
CONFIGURABLE_OP_IMPL(elu, 1, 1, true, -2, 0) {
|
||||||
|
|
||||||
auto input = INPUT_VARIABLE(0);
|
auto input = INPUT_VARIABLE(0);
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
input->applyTransform(nd4j::transform::ELU, output, nullptr);
|
const auto alpha = block.numT() > 0 ? T_ARG(0) : 1.f;
|
||||||
STORE_RESULT(output);
|
|
||||||
|
input->applyScalar(nd4j::scalar::ELU, alpha, output);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -41,14 +43,18 @@ namespace nd4j {
|
||||||
->setAllowedOutputTypes(0, {ALL_FLOATS});
|
->setAllowedOutputTypes(0, {ALL_FLOATS});
|
||||||
}
|
}
|
||||||
|
|
||||||
CONFIGURABLE_OP_IMPL(elu_bp, 2, 1, true, 0, 0) {
|
CONFIGURABLE_OP_IMPL(elu_bp, 2, 1, true, -2, 0) {
|
||||||
|
|
||||||
auto input = INPUT_VARIABLE(0);
|
auto input = INPUT_VARIABLE(0);
|
||||||
auto epsilon = INPUT_VARIABLE(1);
|
auto epsilon = INPUT_VARIABLE(1);
|
||||||
|
|
||||||
auto z = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
const auto alpha = block.numT() > 0 ? T_ARG(0) : 1.f;
|
||||||
|
|
||||||
|
// input->applyPairwiseTransform(pairwise::ELUDerivativeE, epsilon, output);
|
||||||
|
helpers::eluDerivative(block.launchContext(), input, epsilon, output, alpha);
|
||||||
|
|
||||||
//input->applyPairwiseTransform(pairwise::ELUDerivativeE, epsilon, z, nullptr);
|
|
||||||
helpers::eluDerivative(block.launchContext(), input, epsilon, z);
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
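The elu and elu_bp ops above switch the T-args count from 0 to -2 (which appears to mark a variable or optional number of T arguments in this codebase), read an optional alpha that defaults to 1.0, and replace applyTransform(ELU) with applyScalar(ELU, alpha). A hedged standalone sketch of the same idea; the alpha scaling of the negative branch is the usual ELU definition and an assumption here, since the header comment only states the alpha = 1 case:

#include <cmath>
#include <vector>

// Hypothetical standalone versions of what the op now does via block.numT(),
// T_ARG(0) and scalar::ELU; alpha defaults to 1.0 when no T arg is supplied.
static float optionalAlpha(const std::vector<float>& tArgs) {
    return tArgs.empty() ? 1.0f : tArgs.front();
}

static float elu(float x, float alpha) {
    // x >= 0 ? x : alpha * (exp(x) - 1), per the comment in the op header
    return x >= 0.0f ? x : alpha * std::expm1(x);
}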
@ -25,13 +25,13 @@
|
||||||
#include <ops/declarable/helpers/legacy_helpers.h>
|
#include <ops/declarable/helpers/legacy_helpers.h>
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
CONFIGURABLE_OP_IMPL(lrelu, 1, 1, true, 1, 0) {
|
CONFIGURABLE_OP_IMPL(lrelu, 1, 1, true, -2, 0) {
|
||||||
auto input = INPUT_VARIABLE(0);
|
auto input = INPUT_VARIABLE(0);
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
float t = block.numT() > 0 ? T_ARG(0) : 0.0f;
|
float alpha = block.numT() > 0 ? T_ARG(0) : 0.01f;
|
||||||
|
|
||||||
input->applyScalar(nd4j::scalar::LeakyRELU, t, output);
|
input->applyScalar(nd4j::scalar::LeakyRELU, alpha, output);
|
||||||
STORE_RESULT(output);
|
STORE_RESULT(output);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
@ -43,14 +43,16 @@ namespace nd4j {
|
||||||
->setAllowedOutputTypes(0, {ALL_FLOATS});
|
->setAllowedOutputTypes(0, {ALL_FLOATS});
|
||||||
}
|
}
|
||||||
|
|
||||||
CONFIGURABLE_OP_IMPL(lrelu_bp, 2, 1, true, 0, 0) {
|
CONFIGURABLE_OP_IMPL(lrelu_bp, 2, 1, true, -2, 0) {
|
||||||
auto input = INPUT_VARIABLE(0);
|
auto input = INPUT_VARIABLE(0);
|
||||||
auto epsilon = INPUT_VARIABLE(1);
|
auto epsilon = INPUT_VARIABLE(1);
|
||||||
|
|
||||||
auto z = OUTPUT_VARIABLE(0);
|
auto z = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
float alpha = block.numT() > 0 ? T_ARG(0) : 0.01f;
|
||||||
|
|
||||||
//input->applyPairwiseTransform(pairwise::LRELUDerivativeE, epsilon, z, nullptr);
|
//input->applyPairwiseTransform(pairwise::LRELUDerivativeE, epsilon, z, nullptr);
|
||||||
helpers::leakyReluDerivative(block.launchContext(), input, epsilon, z);
|
helpers::leakyReluDerivative(block.launchContext(), input, epsilon, z, alpha);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
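The lrelu ops above change the default slope from 0.0 to 0.01 and forward it to the backprop helper, matching the header comment "x < 0 ? alpha * x : x". For reference, an illustrative plain-C++ sketch of the forward rule and the gradient it implies (not the library's NDArray API):

// Leaky ReLU forward and the backprop rule it implies, with the 0.01 default
// alpha the diff introduces.
static float leakyRelu(float x, float alpha = 0.01f) {
    return x < 0.0f ? alpha * x : x;
}

static float leakyReluBackprop(float x, float upstreamGrad, float alpha = 0.01f) {
    // derivative is alpha for x < 0 and 1 otherwise, so the upstream gradient is scaled accordingly
    return x < 0.0f ? alpha * upstreamGrad : upstreamGrad;
}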
@ -82,8 +82,8 @@ namespace nd4j {
|
||||||
* Math is: x < 0 ? alpha * x : x;
|
* Math is: x < 0 ? alpha * x : x;
|
||||||
*/
|
*/
|
||||||
#if NOT_EXCLUDED(OP_lrelu)
|
#if NOT_EXCLUDED(OP_lrelu)
|
||||||
DECLARE_CONFIGURABLE_OP(lrelu, 1, 1, true, 0, 0);
|
DECLARE_CONFIGURABLE_OP(lrelu, 1, 1, true, -2, 0);
|
||||||
DECLARE_CONFIGURABLE_OP(lrelu_bp, 2, 1, true, 0, 0);
|
DECLARE_CONFIGURABLE_OP(lrelu_bp, 2, 1, true, -2, 0);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -91,8 +91,8 @@ namespace nd4j {
|
||||||
* Math is: x >= 0 ? x : exp(x) - 1;
|
* Math is: x >= 0 ? x : exp(x) - 1;
|
||||||
*/
|
*/
|
||||||
#if NOT_EXCLUDED(OP_elu)
|
#if NOT_EXCLUDED(OP_elu)
|
||||||
DECLARE_CONFIGURABLE_OP(elu, 1, 1, true, 0, 0);
|
DECLARE_CONFIGURABLE_OP(elu, 1, 1, true, -2, 0);
|
||||||
DECLARE_CONFIGURABLE_OP(elu_bp, 2, 1, true, 0, 0);
|
DECLARE_CONFIGURABLE_OP(elu_bp, 2, 1, true, -2, 0);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -81,29 +81,35 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static void leakyReluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) {
|
static void leakyReluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output, const float alpha) {
|
||||||
auto functor = LAMBDA_TT(x, y){
|
|
||||||
return x >= (T)0.f? y : T(0.f);
|
const T alphaT = static_cast<T>(alpha);
|
||||||
|
|
||||||
|
auto functor = LAMBDA_TT(x, y, alphaT) {
|
||||||
|
return x < 0 ? alphaT * y : y;
|
||||||
};
|
};
|
||||||
|
|
||||||
input->applyPairwiseLambda<T>(epsilon, functor, output);
|
input->applyPairwiseLambda<T>(epsilon, functor, output);
|
||||||
}
|
}
|
||||||
|
|
||||||
void leakyReluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) {
|
void leakyReluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput, const float alpha) {
|
||||||
BUILD_SINGLE_SELECTOR(theFirst->dataType(), leakyReluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR(theFirst->dataType(), leakyReluDerivative_, (theFirst, theSecond, theOutput, alpha), FLOAT_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static void eluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) {
|
static void eluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output, const float alpha) {
|
||||||
auto functor = LAMBDA_TT(x, y){
|
|
||||||
return y * nd4j::math::nd4j_eluderivative<T,T>(x);
|
const T alphaT = static_cast<T>(alpha);
|
||||||
|
|
||||||
|
auto functor = LAMBDA_TT(x, y, alphaT){
|
||||||
|
return y * nd4j::math::nd4j_eluderivative<T,T>(x, alphaT);
|
||||||
};
|
};
|
||||||
|
|
||||||
input->applyPairwiseLambda<T>(epsilon, functor, output);
|
input->applyPairwiseLambda<T>(epsilon, functor, output);
|
||||||
}
|
}
|
||||||
|
|
||||||
void eluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) {
|
void eluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput, const float alpha) {
|
||||||
BUILD_SINGLE_SELECTOR(theFirst->dataType(), eluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR(theFirst->dataType(), eluDerivative_, (theFirst, theSecond, theOutput, alpha), FLOAT_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
|
|
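The CPU helper rewrites above change the pairwise lambdas so they capture the scalar alpha (LAMBDA_TT(x, y, alphaT)) and scale the incoming epsilon on the negative branch. In plain C++ the new functor is roughly equivalent to the sketch below; LAMBDA_TT is assumed to expand to a capturing lambda over element pairs:

#include <functional>

// Plain-C++ equivalent of the new leaky-ReLU backprop functor.
template <typename T>
std::function<T(T, T)> makeLeakyReluBpFunctor(float alpha) {
    const T alphaT = static_cast<T>(alpha);
    return [alphaT](T x, T eps) { return x < T(0) ? alphaT * eps : eps; };
}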
@ -30,7 +30,7 @@ namespace nd4j {
|
||||||
|
|
||||||
shared[threadIdx.x] = 0;
|
shared[threadIdx.x] = 0;
|
||||||
|
|
||||||
|
// each thread will compare 2 elements: E and E+1
|
||||||
for (int e = tid; e < length - 1; e += blockDim.x * gridDim.x) {
|
for (int e = tid; e < length - 1; e += blockDim.x * gridDim.x) {
|
||||||
auto val0 = x[shape::getIndexOffset(e, xShapeInfo, length)];
|
auto val0 = x[shape::getIndexOffset(e, xShapeInfo, length)];
|
||||||
auto val1 = x[shape::getIndexOffset(e+1, xShapeInfo, length)];
|
auto val1 = x[shape::getIndexOffset(e+1, xShapeInfo, length)];
|
||||||
|
@ -41,11 +41,12 @@ namespace nd4j {
|
||||||
else
|
else
|
||||||
v = val1 >= val0;
|
v = val1 >= val0;
|
||||||
|
|
||||||
|
// store comparison result in shared memory
|
||||||
shared[threadIdx.x] += v ? 0 : 1;
|
shared[threadIdx.x] += v ? 0 : 1;
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
// aggregate sum
|
// aggregate sums in shared memory
|
||||||
for (uint activeThreads = blockDim.x / 2; activeThreads > 0; activeThreads /= 2) {
|
for (uint activeThreads = blockDim.x / 2; activeThreads > 0; activeThreads /= 2) {
|
||||||
if (threadIdx.x < activeThreads)
|
if (threadIdx.x < activeThreads)
|
||||||
shared[threadIdx.x] += shared[threadIdx.x + activeThreads];
|
shared[threadIdx.x] += shared[threadIdx.x + activeThreads];
|
||||||
|
@ -53,7 +54,7 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// store over the grid
|
// store over the grid if we have more than 1 block
|
||||||
if (gridDim.x > 1) {
|
if (gridDim.x > 1) {
|
||||||
|
|
||||||
auto tc = reinterpret_cast<unsigned int *>(reductionBuffer);
|
auto tc = reinterpret_cast<unsigned int *>(reductionBuffer);
|
||||||
|
@ -96,7 +97,7 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
// if we have only 1 block, we just store results right away
|
||||||
if (threadIdx.x == 0) {
|
if (threadIdx.x == 0) {
|
||||||
auto tc = reinterpret_cast<unsigned int*>(reductionBuffer);
|
auto tc = reinterpret_cast<unsigned int*>(reductionBuffer);
|
||||||
tc[16384] = 0;
|
tc[16384] = 0;
|
||||||
|
|
|
@ -424,7 +424,7 @@ static __global__ void avgPooling2dCuda(const void *vx, const Nd4jLong *xShapeIn
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
int tid = blockIdx.x * gridDim.x + threadIdx.x;
|
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
for (int index = tid; index < length; index += blockDim.x * gridDim.x) {
|
for (int index = tid; index < length; index += blockDim.x * gridDim.x) {
|
||||||
|
|
||||||
|
@ -519,7 +519,7 @@ static __global__ void pnormPooling2dCuda(const void *vx, const Nd4jLong *xShape
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
int tid = blockIdx.x * gridDim.x + threadIdx.x;
|
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
for (int index = tid; index < length; index += blockDim.x * gridDim.x) {
|
for (int index = tid; index < length; index += blockDim.x * gridDim.x) {
|
||||||
|
|
||||||
|
@ -610,7 +610,7 @@ static __global__ void maxPooling2dCuda(const void *vx, const Nd4jLong *xShapeIn
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
int tid = blockIdx.x * gridDim.x + threadIdx.x;
|
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
for (int index = tid; index < length; index += blockDim.x * gridDim.x) {
|
for (int index = tid; index < length; index += blockDim.x * gridDim.x) {
|
||||||
|
|
||||||
|
@ -908,7 +908,7 @@ __global__ static void pooling2dBPCuda(const void* vx, const Nd4jLong* xShapeInf
|
||||||
/*** max ***/
|
/*** max ***/
|
||||||
case 0: {
|
case 0: {
|
||||||
coord2 = hstart;
|
coord2 = hstart;
|
||||||
coord3 = hend;
|
coord3 = wstart;
|
||||||
|
|
||||||
T max = -DataTypeUtils::max<T>();
|
T max = -DataTypeUtils::max<T>();
|
||||||
for (coords[2] = hstart; coords[2] < hend; coords[2] += dH) {
|
for (coords[2] = hstart; coords[2] < hend; coords[2] += dH) {
|
||||||
|
@ -923,8 +923,9 @@ __global__ static void pooling2dBPCuda(const void* vx, const Nd4jLong* xShapeInf
|
||||||
}
|
}
|
||||||
coords[2] = coord2;
|
coords[2] = coord2;
|
||||||
coords[3] = coord3;
|
coords[3] = coord3;
|
||||||
nd4j::math::atomics::nd4j_atomicAdd<T>(&z[shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank)], y[yOffset]);
|
auto zOffset = shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank);
|
||||||
|
nd4j::math::atomics::nd4j_atomicAdd<T>(&z[zOffset], y[yOffset]);
|
||||||
|
//z[zOffset] += y[yOffset];
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
|
||||||
|
@ -987,7 +988,7 @@ void ConvolutionUtils::pooling2dBP(nd4j::graph::Context& block, const NDArray& i
|
||||||
|
|
||||||
PointersManager manager(block.launchContext(), "pooling2dBP");
|
PointersManager manager(block.launchContext(), "pooling2dBP");
|
||||||
|
|
||||||
const int threadsPerBlock = MAX_NUM_THREADS / 2;
|
const int threadsPerBlock = 256;
|
||||||
const int blocksPerGrid = (gradO.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
const int blocksPerGrid = (gradO.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
||||||
const int sharedMem = gradO.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128;
|
const int sharedMem = gradO.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128;
|
||||||
|
|
||||||
|
|
|
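Two things happen in the pooling hunks above: the global thread index is corrected from blockIdx.x * gridDim.x to blockIdx.x * blockDim.x (the block offset must be scaled by the block size, not the grid size), and the max-pooling backprop now computes the output offset once before the atomicAdd instead of inlining it. A reduced sketch of the gradient scatter-add pattern, with the corrected indexing (illustrative, not the library's shape API):

// Several gradOut positions may map to the same input element, hence the
// atomic accumulation into the input-gradient buffer.
__global__ void scatterAddGrad(const float* gradOut, const long long* srcIndex,
                               float* gradIn, long long n) {
    const auto tid = blockIdx.x * blockDim.x + threadIdx.x;   // corrected global index
    const auto stride = gridDim.x * blockDim.x;
    for (long long i = tid; i < n; i += stride)
        atomicAdd(&gradIn[srcIndex[i]], gradOut[i]);
}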
@ -39,7 +39,7 @@ static __global__ void diagFunctorKernel(void* outputBuffer, Nd4jLong* outputSha
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
const auto tid = blockIdx.x * gridDim.x + threadIdx.x;
|
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
const auto step = gridDim.x * blockDim.x;
|
const auto step = gridDim.x * blockDim.x;
|
||||||
for (int t = tid; t < inputLength; t += step) {
|
for (int t = tid; t < inputLength; t += step) {
|
||||||
z[shape::getIndexOffset(t * (inputLength + 1), outputShape, outputLength)] = x[shape::getIndexOffset(t, inputShape, inputLength)]; //tX];
|
z[shape::getIndexOffset(t * (inputLength + 1), outputShape, outputLength)] = x[shape::getIndexOffset(t, inputShape, inputLength)]; //tX];
|
||||||
|
@ -59,7 +59,7 @@ static __global__ void diagFunctorKernel(void* outputBuffer, Nd4jLong* outputSha
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
const auto tid = blockIdx.x * gridDim.x + threadIdx.x;
|
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
const auto step = gridDim.x * blockDim.x;
|
const auto step = gridDim.x * blockDim.x;
|
||||||
Nd4jLong i = threadIdx.x * (outputLength + 1);
|
Nd4jLong i = threadIdx.x * (outputLength + 1);
|
||||||
for (int t = tid; t < outputLength && i < inputLength; t += step) {
|
for (int t = tid; t < outputLength && i < inputLength; t += step) {
|
||||||
|
|
|
@ -35,9 +35,11 @@ namespace helpers {
|
||||||
T const* input = reinterpret_cast<T const*>(inputBuf);
|
T const* input = reinterpret_cast<T const*>(inputBuf);
|
||||||
T* output = reinterpret_cast<T*>(outputBuf);
|
T* output = reinterpret_cast<T*>(outputBuf);
|
||||||
|
|
||||||
|
// trivial idea: loop through all elements, get independent probability for each element to be nullified
|
||||||
for (Nd4jLong e = 0; e < inLen; ++e) {
|
for (Nd4jLong e = 0; e < inLen; ++e) {
|
||||||
T val = nodeRng->relativeT(e, T(0.f), T(1.f));
|
T val = nodeRng->relativeT(e, T(0.f), T(1.f));
|
||||||
|
|
||||||
|
// if probability is ok - we're saving scaled value
|
||||||
if (double(val) < probVal)
|
if (double(val) < probVal)
|
||||||
output[shape::getIndexOffset(e, outputShape, inLen)] = T(input[shape::getIndexOffset(e, inputShape, inLen)] / probVal);
|
output[shape::getIndexOffset(e, outputShape, inLen)] = T(input[shape::getIndexOffset(e, inputShape, inLen)] / probVal);
|
||||||
}
|
}
|
||||||
|
@ -80,7 +82,7 @@ namespace helpers {
|
||||||
std::vector<Nd4jLong> dims(reduceShape->lengthOf());
|
std::vector<Nd4jLong> dims(reduceShape->lengthOf());
|
||||||
reduceShape->syncToHost(); // to ensure that follows are actual
|
reduceShape->syncToHost(); // to ensure that follows are actual
|
||||||
bool fit = true;
|
bool fit = true;
|
||||||
// PRAGMA_OMP_PARALLEL_FOR_ARGS(firstprivate(fit))
|
|
||||||
for( int i = 0; i < dims.size(); i++ ) {
|
for( int i = 0; i < dims.size(); i++ ) {
|
||||||
if (fit) {
|
if (fit) {
|
||||||
dims[i] = reduceShape->e<Nd4jLong>(i);
|
dims[i] = reduceShape->e<Nd4jLong>(i);
|
||||||
|
@ -96,8 +98,7 @@ namespace helpers {
|
||||||
REQUIRE_TRUE(fit, 0, "dropout: Noise shape should fit to input rank.");
|
REQUIRE_TRUE(fit, 0, "dropout: Noise shape should fit to input rank.");
|
||||||
std::unique_ptr<NDArray> chunk(new NDArray('c', dims, output->dataType(), context.launchContext()));
|
std::unique_ptr<NDArray> chunk(new NDArray('c', dims, output->dataType(), context.launchContext()));
|
||||||
chunk->assign(1.f);
|
chunk->assign(1.f);
|
||||||
//chunk->applyRandom<randomOps::DropOutInverted<T>>(rng, nullptr, chunk.get(), &probValue);
|
|
||||||
//NativeOpExecutioner::execRandom(random::DropOutInverted, rng, chunk->buffer(), chunk->shapeInfo(), chunk->buffer(), chunk->shapeInfo(), &prob);
|
|
||||||
dropoutSimple<T>(context.launchContext(), chunk.get(), chunk.get(), probValue, seed);
|
dropoutSimple<T>(context.launchContext(), chunk.get(), chunk.get(), probValue, seed);
|
||||||
// broadcast chunk to full matrix
|
// broadcast chunk to full matrix
|
||||||
std::unique_ptr<NDArray> dropOutMultiplier(new NDArray(*input));
|
std::unique_ptr<NDArray> dropOutMultiplier(new NDArray(*input));
|
||||||
|
@ -105,6 +106,7 @@ namespace helpers {
|
||||||
|
|
||||||
*dropOutMultiplier += *chunk;
|
*dropOutMultiplier += *chunk;
|
||||||
|
|
||||||
|
// FIXME: we could do this in one step, aren't we?
|
||||||
output->assign(*input * *dropOutMultiplier); //input->applyPairwiseTransform(pairwise::Multiply, dropOutMultiplier.get(), output, nullptr);
|
output->assign(*input * *dropOutMultiplier); //input->applyPairwiseTransform(pairwise::Multiply, dropOutMultiplier.get(), output, nullptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -113,8 +115,11 @@ namespace helpers {
|
||||||
|
|
||||||
int dropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue) {
|
int dropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue) {
|
||||||
auto xType = input->dataType();
|
auto xType = input->dataType();
|
||||||
|
NDArray::prepareSpecialUse({output}, {input});
|
||||||
|
|
||||||
BUILD_SINGLE_SELECTOR(xType, return _dropOutFunctor, (context, input, output, reduceShape, seed, probValue), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR(xType, return _dropOutFunctor, (context, input, output, reduceShape, seed, probValue), FLOAT_TYPES);
|
||||||
|
|
||||||
|
NDArray::registerSpecialUse({output}, {input});
|
||||||
}
|
}
|
||||||
|
|
||||||
/////////////////////////////////// backrpopagations ///////////////////////////////////////////////
|
/////////////////////////////////// backrpopagations ///////////////////////////////////////////////
|
||||||
|
@ -136,6 +141,8 @@ namespace helpers {
|
||||||
|
|
||||||
for (int e = tid; e < len; e += step) {
|
for (int e = tid; e < len; e += step) {
|
||||||
const auto zOffset = shape::getIndexOffset(e, outputShape, len);
|
const auto zOffset = shape::getIndexOffset(e, outputShape, len);
|
||||||
|
|
||||||
|
// if probability was non-zero on FF step, we'll scale grads back
|
||||||
if (output[zOffset] != T(0.))
|
if (output[zOffset] != T(0.))
|
||||||
output[zOffset] = T(input[shape::getIndexOffset(e, gradOutShape, len)] / probValue);
|
output[zOffset] = T(input[shape::getIndexOffset(e, gradOutShape, len)] / probValue);
|
||||||
|
|
||||||
|
@ -143,12 +150,17 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static int dropOutFunctorBP_(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue) {
|
static int dropOutFunctorBP_(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue) {
|
||||||
|
// we're making additional FF run to see how probabilities played out with given seeds
|
||||||
int res = dropOutFunctor(context, input, output, reduceShape, seed, probValue);
|
int res = dropOutFunctor(context, input, output, reduceShape, seed, probValue);
|
||||||
auto stream = context.launchContext()->getCudaStream();
|
auto stream = context.launchContext()->getCudaStream();
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({output}, {input, gradOut});
|
||||||
|
|
||||||
if (ND4J_STATUS_OK == res)
|
if (ND4J_STATUS_OK == res)
|
||||||
dropoutBPKernel<T><<<128, 256, 1024, *stream>>>(output->specialBuffer(), output->specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), probValue);
|
dropoutBPKernel<T><<<128, 256, 1024, *stream>>>(output->specialBuffer(), output->specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), probValue);
|
||||||
|
|
||||||
|
NDArray::registerSpecialUse({output}, {input, gradOut});
|
||||||
|
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -239,6 +251,7 @@ namespace helpers {
|
||||||
|
|
||||||
int res = alphaDropOutFunctor(context, input, output, reduceShape, seed, probValue, alpha, alpha1, beta);
|
int res = alphaDropOutFunctor(context, input, output, reduceShape, seed, probValue, alpha, alpha1, beta);
|
||||||
if (res == ND4J_STATUS_OK) {
|
if (res == ND4J_STATUS_OK) {
|
||||||
|
// FIXME: can we make it single-loop?
|
||||||
(*output) *= alpha;
|
(*output) *= alpha;
|
||||||
(*output) *= (*gradOut); //->applyPairwiseTransform<transform::Multiply>(gradOut, output, nullptr);
|
(*output) *= (*gradOut); //->applyPairwiseTransform<transform::Multiply>(gradOut, output, nullptr);
|
||||||
}
|
}
|
||||||
|
|
|
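Several of the dropout changes above wrap the kernel dispatch in NDArray::prepareSpecialUse({outputs}, {inputs}) and NDArray::registerSpecialUse({outputs}, {inputs}), the pair declared in the earlier header hunk. Judging from the later polyGamma hunk, which replaces manual syncToDevice / tickWriteDevice calls with this pair, the semantics are: sync inputs to the device before the launch, mark outputs as device-dirty afterwards. A self-contained analog with a hypothetical Tensor stand-in for nd4j::NDArray:

#include <initializer_list>

// 'Tensor' is a hypothetical stand-in; the real methods live on nd4j::NDArray.
struct Tensor {
    void syncToDevice() { /* copy host buffer to device if stale */ }
    void tickWriteDevice() { /* mark the device copy as the current one */ }
};

static void prepareUse(std::initializer_list<Tensor*> writes, std::initializer_list<Tensor*> reads) {
    for (auto r : reads) r->syncToDevice();   // inputs must be current on device
    (void) writes;
}

static void registerUse(std::initializer_list<Tensor*> writes, std::initializer_list<Tensor*> reads) {
    for (auto w : writes) w->tickWriteDevice();   // outputs were just written on device
    (void) reads;
}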
@ -43,7 +43,7 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
|
// we run things in blocks, 1 partition per block of threads
|
||||||
for (Nd4jLong o = blockIdx.x; o < numOutputs; o += gridDim.x) {
|
for (Nd4jLong o = blockIdx.x; o < numOutputs; o += gridDim.x) {
|
||||||
auto z = reinterpret_cast<X*>(vz[o]);
|
auto z = reinterpret_cast<X*>(vz[o]);
|
||||||
|
|
||||||
|
@ -89,9 +89,11 @@ namespace nd4j {
|
||||||
auto x = reinterpret_cast<X*>(vx);
|
auto x = reinterpret_cast<X*>(vx);
|
||||||
auto indices = reinterpret_cast<Y*>(vindices);
|
auto indices = reinterpret_cast<Y*>(vindices);
|
||||||
|
|
||||||
|
// we run things in blocks, 1 partition per block of threads
|
||||||
for (int i = blockIdx.x; i < numOutputs; i += gridDim.x) {
|
for (int i = blockIdx.x; i < numOutputs; i += gridDim.x) {
|
||||||
auto z = reinterpret_cast<X*>(vz[i]);
|
auto z = reinterpret_cast<X*>(vz[i]);
|
||||||
|
|
||||||
|
// each thread has own counter for partitions
|
||||||
int outCnt = 0;
|
int outCnt = 0;
|
||||||
|
|
||||||
for (Nd4jLong e = 0; e < iLength; e++) {
|
for (Nd4jLong e = 0; e < iLength; e++) {
|
||||||
|
@ -145,6 +147,7 @@ namespace nd4j {
|
||||||
tadOffsets[i] = packZ.platformOffsets();
|
tadOffsets[i] = packZ.platformOffsets();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// we copy pointers to device
|
||||||
auto dOutBuffers = reinterpret_cast<void **>(pm.replicatePointer(outBuffers.data(), outBuffers.size() * sizeof(void *)));
|
auto dOutBuffers = reinterpret_cast<void **>(pm.replicatePointer(outBuffers.data(), outBuffers.size() * sizeof(void *)));
|
||||||
auto dOutTadShapes = reinterpret_cast<Nd4jLong **>(pm.replicatePointer(tadShapes.data(), tadShapes.size() * sizeof(Nd4jLong *)));
|
auto dOutTadShapes = reinterpret_cast<Nd4jLong **>(pm.replicatePointer(tadShapes.data(), tadShapes.size() * sizeof(Nd4jLong *)));
|
||||||
auto dOutTadOffsets = reinterpret_cast<Nd4jLong **>(pm.replicatePointer(tadOffsets.data(), tadOffsets.size() * sizeof(Nd4jLong *)));
|
auto dOutTadOffsets = reinterpret_cast<Nd4jLong **>(pm.replicatePointer(tadOffsets.data(), tadOffsets.size() * sizeof(Nd4jLong *)));
|
||||||
|
@ -248,6 +251,7 @@ namespace nd4j {
|
||||||
indicesShapes[e] = indices.at(e)->getSpecialShapeInfo();
|
indicesShapes[e] = indices.at(e)->getSpecialShapeInfo();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// copying pointers to buffers to device
|
||||||
auto dInputBuffers = reinterpret_cast<void **>(pm.replicatePointer(inputBuffers.data(), inputSize * sizeof(void *)));
|
auto dInputBuffers = reinterpret_cast<void **>(pm.replicatePointer(inputBuffers.data(), inputSize * sizeof(void *)));
|
||||||
auto dIndicesBuffers = reinterpret_cast<void **>(pm.replicatePointer(indicesBuffers.data(), inputSize * sizeof(void *)));
|
auto dIndicesBuffers = reinterpret_cast<void **>(pm.replicatePointer(indicesBuffers.data(), inputSize * sizeof(void *)));
|
||||||
auto dInputShapes = reinterpret_cast<Nd4jLong **>(pm.replicatePointer(inputShapes.data(), inputSize * sizeof(Nd4jLong *)));
|
auto dInputShapes = reinterpret_cast<Nd4jLong **>(pm.replicatePointer(inputShapes.data(), inputSize * sizeof(Nd4jLong *)));
|
||||||
|
@ -283,6 +287,7 @@ namespace nd4j {
|
||||||
inputTadOffsets[e] = packX.platformOffsets();
|
inputTadOffsets[e] = packX.platformOffsets();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// copying pointers to buffers to device
|
||||||
auto dInputBuffers = reinterpret_cast<void **>(pm.replicatePointer(inputBuffers.data(), inputSize * sizeof(void *)));
|
auto dInputBuffers = reinterpret_cast<void **>(pm.replicatePointer(inputBuffers.data(), inputSize * sizeof(void *)));
|
||||||
auto dInputTadShapes = reinterpret_cast<Nd4jLong **>(pm.replicatePointer(inputTadShapes.data(), inputSize * sizeof(Nd4jLong *)));
|
auto dInputTadShapes = reinterpret_cast<Nd4jLong **>(pm.replicatePointer(inputTadShapes.data(), inputSize * sizeof(Nd4jLong *)));
|
||||||
auto dInputTadOffsets = reinterpret_cast<Nd4jLong **>(pm.replicatePointer(inputTadOffsets.data(), inputSize * sizeof(Nd4jLong *)));
|
auto dInputTadOffsets = reinterpret_cast<Nd4jLong **>(pm.replicatePointer(inputTadOffsets.data(), inputSize * sizeof(Nd4jLong *)));
|
||||||
|
@ -313,6 +318,7 @@ namespace nd4j {
|
||||||
|
|
||||||
NDArray::registerSpecialUse({}, {indices, input});
|
NDArray::registerSpecialUse({}, {indices, input});
|
||||||
|
|
||||||
|
// TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list
|
||||||
for (auto v:outputList) {
|
for (auto v:outputList) {
|
||||||
v->tickWriteDevice();
|
v->tickWriteDevice();
|
||||||
}
|
}
|
||||||
|
|
|
@ -29,6 +29,7 @@ namespace nd4j {
|
||||||
|
|
||||||
Nd4jLong xCoord[MAX_RANK];
|
Nd4jLong xCoord[MAX_RANK];
|
||||||
|
|
||||||
|
// each block of threads works on 1 input array
|
||||||
for (Nd4jLong e = blockIdx.x; e < numInputs; e += gridDim.x) {
|
for (Nd4jLong e = blockIdx.x; e < numInputs; e += gridDim.x) {
|
||||||
auto z = reinterpret_cast<T*>(zBuffer) + offsets[e];
|
auto z = reinterpret_cast<T*>(zBuffer) + offsets[e];
|
||||||
|
|
||||||
|
@ -39,6 +40,7 @@ namespace nd4j {
|
||||||
auto xRank = shape::rank(xShapeInfo);
|
auto xRank = shape::rank(xShapeInfo);
|
||||||
auto xLength = shape::length(xShapeInfo);
|
auto xLength = shape::length(xShapeInfo);
|
||||||
|
|
||||||
|
// each element of this input array has own place within common output array
|
||||||
for (uint i = threadIdx.x; i < xLength; i += blockDim.x) {
|
for (uint i = threadIdx.x; i < xLength; i += blockDim.x) {
|
||||||
shape::index2coords(xRank, xShape, i, xLength, xCoord, order);
|
shape::index2coords(xRank, xShape, i, xLength, xCoord, order);
|
||||||
auto xOffset = shape::getOffset(0, xShape, xStride, xCoord, xRank);
|
auto xOffset = shape::getOffset(0, xShape, xStride, xCoord, xRank);
|
||||||
|
@ -65,6 +67,7 @@ namespace nd4j {
|
||||||
hdShapes[e] = inputs[e]->specialShapeInfo();
|
hdShapes[e] = inputs[e]->specialShapeInfo();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// copying pointers to device
|
||||||
auto dBuffers = (void **) pm.replicatePointer(hdBuffers.data(), inputs.size() * sizeof(void*));
|
auto dBuffers = (void **) pm.replicatePointer(hdBuffers.data(), inputs.size() * sizeof(void*));
|
||||||
auto dShapes = (Nd4jLong **)pm.replicatePointer(hdShapes.data(), inputs.size() * sizeof(Nd4jLong*));
|
auto dShapes = (Nd4jLong **)pm.replicatePointer(hdShapes.data(), inputs.size() * sizeof(Nd4jLong*));
|
||||||
auto dOffsets = (Nd4jLong *) pm.replicatePointer(hOffsets.data(), inputs.size() * sizeof(Nd4jLong));
|
auto dOffsets = (Nd4jLong *) pm.replicatePointer(hOffsets.data(), inputs.size() * sizeof(Nd4jLong));
|
||||||
|
@ -76,6 +79,7 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
void flatten(nd4j::LaunchContext *context, std::vector<NDArray*> &inputs, NDArray *output, char order) {
|
void flatten(nd4j::LaunchContext *context, std::vector<NDArray*> &inputs, NDArray *output, char order) {
|
||||||
|
// FIXME: we want NDArrayFactory::prepareSpecialUse here eventually
|
||||||
for (auto v:inputs)
|
for (auto v:inputs)
|
||||||
v->syncToDevice();
|
v->syncToDevice();
|
||||||
|
|
||||||
|
|
|
@ -26,6 +26,7 @@ namespace ops {
|
||||||
namespace helpers {
|
namespace helpers {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void applyGradientDescent_(LaunchContext* context, NDArray* input, NDArray* step, double weight, NDArray* output) {
|
void applyGradientDescent_(LaunchContext* context, NDArray* input, NDArray* step, double weight, NDArray* output) {
|
||||||
|
// classic one
|
||||||
auto lambda = LAMBDA_TT(_x, _y, weight) {
|
auto lambda = LAMBDA_TT(_x, _y, weight) {
|
||||||
return _x - (_y * weight);
|
return _x - (_y * weight);
|
||||||
};
|
};
|
||||||
|
|
|
@ -44,6 +44,7 @@ namespace nd4j {
|
||||||
|
|
||||||
X binSize = X((*max_val - *min_val) / numBins);
|
X binSize = X((*max_val - *min_val) / numBins);
|
||||||
|
|
||||||
|
// nullify bins
|
||||||
for (int e = threadIdx.x; e < numBins; e += blockDim.x) {
|
for (int e = threadIdx.x; e < numBins; e += blockDim.x) {
|
||||||
bins[e] = (Z) 0;
|
bins[e] = (Z) 0;
|
||||||
}
|
}
|
||||||
|
@ -53,14 +54,12 @@ namespace nd4j {
|
||||||
int idx = int((dx[e] - *min_val) / binSize);
|
int idx = int((dx[e] - *min_val) / binSize);
|
||||||
idx = math::nd4j_max(idx, 0); //atomicMax(&idx, 0);//atomicMax(&idx, 0);
|
idx = math::nd4j_max(idx, 0); //atomicMax(&idx, 0);//atomicMax(&idx, 0);
|
||||||
idx = math::nd4j_min(idx, int(numBins - 1)); //atomicMin(&idx, int(numBins - 1));
|
idx = math::nd4j_min(idx, int(numBins - 1)); //atomicMin(&idx, int(numBins - 1));
|
||||||
nd4j::math::atomics::nd4j_atomicAdd(&bins[idx], (Z)1);
|
nd4j::math::atomics::nd4j_atomicAdd<Z>(&bins[idx], (Z)1);
|
||||||
// bins[idx]++;
|
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
// at this point all bins in shared memory are calculated, so we aggregate them now via threadfence trick
|
||||||
|
|
||||||
// transfer shared memory to reduction memory
|
// transfer shared memory to reduction memory
|
||||||
|
|
||||||
|
|
||||||
if (gridDim.x > 1) {
|
if (gridDim.x > 1) {
|
||||||
unsigned int *tc = (unsigned int *)reductionPointer;
|
unsigned int *tc = (unsigned int *)reductionPointer;
|
||||||
__shared__ bool amLast;
|
__shared__ bool amLast;
|
||||||
|
|
|
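The histogram kernel above now calls nd4j_atomicAdd<Z>(...) with an explicit template argument rather than relying on deduction. The diff does not state the motivation; the usual reason for being explicit is that mixed argument types make deduction ambiguous or pick an unintended overload. A miniature of that situation with a hypothetical wrapper:

// With a templated atomic add, mixed argument types fail deduction; spelling
// out <float> keeps the instantiation unambiguous. Hypothetical wrapper.
template <typename T>
__device__ void myAtomicAdd(T* address, T value) { atomicAdd(address, value); }

__global__ void countKernel(float* bins, const int* binIndex, long long n) {
    const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
    const auto stride = gridDim.x * blockDim.x;
    for (long long i = tid; i < n; i += stride)
        myAtomicAdd<float>(&bins[binIndex[i]], 1);   // the int literal converts once <float> is explicit
}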
@ -64,6 +64,7 @@ static void ismax_(nd4j::LaunchContext * context, const NDArray* input, NDArray*
|
||||||
|
|
||||||
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), copy.data(), copy.size());
|
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), copy.data(), copy.size());
|
||||||
|
|
||||||
|
// we launch legacy IndexMax op, to get indices of max values along dimension
|
||||||
auto indexMaxArr = input->applyIndexReduce(indexreduce::IndexMax, dimensions);
|
auto indexMaxArr = input->applyIndexReduce(indexreduce::IndexMax, dimensions);
|
||||||
|
|
||||||
dim3 launchDims(256, 256, 16384);
|
dim3 launchDims(256, 256, 16384);
|
||||||
|
|
|
@ -66,29 +66,35 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
linkage void leakyReluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) {
|
linkage void leakyReluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output, const float alpha) {
|
||||||
auto functor = LAMBDA_TT(x, y){
|
|
||||||
return x >= (T)0.f? y : T(0.f);
|
const T alphaT = static_cast<T>(alpha);
|
||||||
|
|
||||||
|
auto functor = LAMBDA_TT(x, y, alphaT) {
|
||||||
|
return x < 0 ? alphaT * y : y;
|
||||||
};
|
};
|
||||||
|
|
||||||
input->applyPairwiseLambda(epsilon, functor, output);
|
input->applyPairwiseLambda(epsilon, functor, output);
|
||||||
}
|
}
|
||||||
|
|
||||||
void leakyReluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) {
|
void leakyReluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput, const float alpha) {
|
||||||
BUILD_SINGLE_SELECTOR(theFirst->dataType(), leakyReluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR(theFirst->dataType(), leakyReluDerivative_, (theFirst, theSecond, theOutput, alpha), FLOAT_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
linkage void eluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) {
|
linkage void eluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output, const float alpha) {
|
||||||
auto functor = LAMBDA_TT(x, y){
|
|
||||||
return y * nd4j::math::nd4j_eluderivative<T,T>(x);
|
const T alphaT = static_cast<T>(alpha);
|
||||||
|
|
||||||
|
auto functor = LAMBDA_TT(x, y, alphaT){
|
||||||
|
return y * nd4j::math::nd4j_eluderivative<T,T>(x, alphaT);
|
||||||
};
|
};
|
||||||
|
|
||||||
input->applyPairwiseLambda(epsilon, functor, output);
|
input->applyPairwiseLambda(epsilon, functor, output);
|
||||||
}
|
}
|
||||||
|
|
||||||
void eluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) {
|
void eluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput, const float alpha) {
|
||||||
BUILD_SINGLE_SELECTOR(theFirst->dataType(), eluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR(theFirst->dataType(), eluDerivative_, (theFirst, theSecond, theOutput, alpha), FLOAT_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
|
|
@ -41,12 +41,12 @@ namespace helpers {
|
||||||
const T tbeta = static_cast<T>(beta);
|
const T tbeta = static_cast<T>(beta);
|
||||||
const T talpha = static_cast<T>(alpha);
|
const T talpha = static_cast<T>(alpha);
|
||||||
|
|
||||||
|
// one block of threads processes 1 example within batch
|
||||||
for (uint i = blockIdx.x; i < numTads; i += gridDim.x) {
|
for (uint i = blockIdx.x; i < numTads; i += gridDim.x) {
|
||||||
auto x = reinterpret_cast<T*>(vx) + xTadOffsets[i];
|
auto x = reinterpret_cast<T*>(vx) + xTadOffsets[i];
|
||||||
auto z = reinterpret_cast<T*>(vz) + zTadOffsets[i];
|
auto z = reinterpret_cast<T*>(vz) + zTadOffsets[i];
|
||||||
|
|
||||||
// load everything into shared memory
|
// load everything into shared memory, so we'll operate on shared memory from now on
|
||||||
shared[threadIdx.x] = x[threadIdx.x * xEws];
|
shared[threadIdx.x] = x[threadIdx.x * xEws];
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
|
@ -94,7 +94,7 @@ namespace helpers {
|
||||||
sharedY[threadIdx.x] = 0.f;
|
sharedY[threadIdx.x] = 0.f;
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
|
// we're operating in shared memory
|
||||||
for (int s = begin; s < end; s++)
|
for (int s = begin; s < end; s++)
|
||||||
sharedY[threadIdx.x] = sharedY[threadIdx.x] + sharedX[s] * sharedX[s];
|
sharedY[threadIdx.x] = sharedY[threadIdx.x] + sharedX[s] * sharedX[s];
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
|
@ -37,7 +37,7 @@ namespace nd4j {
|
||||||
static __global__ void global_mergeMaxIndex_(void **inArrs, void **inShapes, const int numArrays, void *voutput, Nd4jLong *outputShape, Nd4jLong length) {
|
static __global__ void global_mergeMaxIndex_(void **inArrs, void **inShapes, const int numArrays, void *voutput, Nd4jLong *outputShape, Nd4jLong length) {
|
||||||
auto output = reinterpret_cast<Z*>(voutput);
|
auto output = reinterpret_cast<Z*>(voutput);
|
||||||
|
|
||||||
const auto tid = blockIdx.x * gridDim.x + threadIdx.x;
|
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
const auto step = gridDim.x * blockDim.x;
|
const auto step = gridDim.x * blockDim.x;
|
||||||
|
|
||||||
for (Nd4jLong e = tid; e < length; e += step) {
|
for (Nd4jLong e = tid; e < length; e += step) {
|
||||||
|
@ -81,7 +81,13 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
void mergeMaxIndex(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
|
void mergeMaxIndex(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
|
||||||
|
NDArray::prepareSpecialUse({&output}, {});
|
||||||
|
for (auto v:inArrs)
|
||||||
|
v->syncToDevice();
|
||||||
|
|
||||||
BUILD_DOUBLE_SELECTOR(inArrs[0]->dataType(), output.dataType(), mergeMaxIndex_, (context, inArrs, output), LIBND4J_TYPES, INDEXING_TYPES);
|
BUILD_DOUBLE_SELECTOR(inArrs[0]->dataType(), output.dataType(), mergeMaxIndex_, (context, inArrs, output), LIBND4J_TYPES, INDEXING_TYPES);
|
||||||
|
|
||||||
|
NDArray::registerSpecialUse({&output}, {});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -90,7 +96,7 @@ namespace nd4j {
|
||||||
static __global__ void global_mergeMax_(void **inArrs, void **inShapes, const int numArrays, void *voutput, Nd4jLong *outputShape, Nd4jLong length) {
|
static __global__ void global_mergeMax_(void **inArrs, void **inShapes, const int numArrays, void *voutput, Nd4jLong *outputShape, Nd4jLong length) {
|
||||||
auto output = reinterpret_cast<T*>(voutput);
|
auto output = reinterpret_cast<T*>(voutput);
|
||||||
|
|
||||||
const auto tid = blockIdx.x * gridDim.x + threadIdx.x;
|
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
const auto step = gridDim.x * blockDim.x;
|
const auto step = gridDim.x * blockDim.x;
|
||||||
|
|
||||||
for (Nd4jLong e = tid; e < length; e += step) {
|
for (Nd4jLong e = tid; e < length; e += step) {
|
||||||
|
@ -131,7 +137,12 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
void mergeMax(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
|
void mergeMax(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
|
||||||
|
NDArray::prepareSpecialUse({&output}, {});
|
||||||
|
for (auto v:inArrs)
|
||||||
|
v->syncToDevice();
|
||||||
|
|
||||||
BUILD_SINGLE_SELECTOR(output.dataType(), mergeMax_, (context, inArrs, output), LIBND4J_TYPES);
|
BUILD_SINGLE_SELECTOR(output.dataType(), mergeMax_, (context, inArrs, output), LIBND4J_TYPES);
|
||||||
|
NDArray::registerSpecialUse({&output}, {});
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -139,7 +150,7 @@ namespace nd4j {
|
||||||
static __global__ void global_mergeAvg_(void **inArrs, void **inShapes, const int numArrays, void *voutput, Nd4jLong *outputShape, Nd4jLong length) {
|
static __global__ void global_mergeAvg_(void **inArrs, void **inShapes, const int numArrays, void *voutput, Nd4jLong *outputShape, Nd4jLong length) {
|
||||||
auto output = reinterpret_cast<T*>(voutput);
|
auto output = reinterpret_cast<T*>(voutput);
|
||||||
|
|
||||||
const auto tid = blockIdx.x * gridDim.x + threadIdx.x;
|
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
const auto step = gridDim.x * blockDim.x;
|
const auto step = gridDim.x * blockDim.x;
|
||||||
|
|
||||||
for (Nd4jLong e = tid; e < length; e += step) {
|
for (Nd4jLong e = tid; e < length; e += step) {
|
||||||
|
@ -178,7 +189,13 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
void mergeAvg(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
|
void mergeAvg(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
|
||||||
|
NDArray::prepareSpecialUse({&output}, {});
|
||||||
|
for (auto v:inArrs)
|
||||||
|
v->syncToDevice();
|
||||||
|
|
||||||
BUILD_SINGLE_SELECTOR(output.dataType(), mergeAvg_, (context, inArrs, output), FLOAT_TYPES);
|
BUILD_SINGLE_SELECTOR(output.dataType(), mergeAvg_, (context, inArrs, output), FLOAT_TYPES);
|
||||||
|
|
||||||
|
NDArray::registerSpecialUse({&output}, {});
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -186,7 +203,7 @@ namespace nd4j {
|
||||||
static __global__ void global_mergeAdd_(void **inArrs, void **inShapes, const int numArrays, void *voutput, Nd4jLong *outputShape, Nd4jLong length) {
|
static __global__ void global_mergeAdd_(void **inArrs, void **inShapes, const int numArrays, void *voutput, Nd4jLong *outputShape, Nd4jLong length) {
|
||||||
auto output = reinterpret_cast<T*>(voutput);
|
auto output = reinterpret_cast<T*>(voutput);
|
||||||
|
|
||||||
const auto tid = blockIdx.x * gridDim.x + threadIdx.x;
|
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
const auto step = gridDim.x * blockDim.x;
|
const auto step = gridDim.x * blockDim.x;
|
||||||
|
|
||||||
for (Nd4jLong e = tid; e < length; e += step) {
|
for (Nd4jLong e = tid; e < length; e += step) {
|
||||||
|
@ -226,7 +243,13 @@ namespace nd4j {
|
||||||
BUILD_SINGLE_TEMPLATE(template void mergeAdd_, (nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output), NUMERIC_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void mergeAdd_, (nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output), NUMERIC_TYPES);
|
||||||
|
|
||||||
void mergeAdd(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
|
void mergeAdd(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
|
||||||
|
NDArray::prepareSpecialUse({&output}, {});
|
||||||
|
for (auto v:inArrs)
|
||||||
|
v->syncToDevice();
|
||||||
|
|
||||||
BUILD_SINGLE_SELECTOR(output.dataType(), mergeAdd_, (context, inArrs, output), NUMERIC_TYPES);
|
BUILD_SINGLE_SELECTOR(output.dataType(), mergeAdd_, (context, inArrs, output), NUMERIC_TYPES);
|
||||||
|
|
||||||
|
NDArray::registerSpecialUse({&output}, {});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,18 +31,18 @@ namespace helpers {
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static __global__ void fillUpElementKernel(void* outputBuffer, Nd4jLong* outputShapeInfo, void* inputBuffer, Nd4jLong* inputShapeInfo, Nd4jLong* pTadShape, Nd4jLong* pTadOffsets, Nd4jLong n) {
|
static __global__ void fillUpElementKernel(void* outputBuffer, Nd4jLong* outputShapeInfo, void* inputBuffer, Nd4jLong* inputShapeInfo, Nd4jLong* pTadShape, Nd4jLong* pTadOffsets, Nd4jLong n) {
|
||||||
__shared__ T *z, *x;
|
|
||||||
__shared__ Nd4jLong bufferLength, arrLen;
|
__shared__ Nd4jLong bufferLength, arrLen;
|
||||||
|
|
||||||
|
auto z = reinterpret_cast<T*>(outputBuffer);
|
||||||
|
auto x = reinterpret_cast<T*>(inputBuffer);
|
||||||
|
|
||||||
if (threadIdx.x == 0) {
|
if (threadIdx.x == 0) {
|
||||||
z = reinterpret_cast<T*>(outputBuffer);
|
|
||||||
x = reinterpret_cast<T*>(inputBuffer);
|
|
||||||
arrLen = shape::length(pTadShape);
|
arrLen = shape::length(pTadShape);
|
||||||
bufferLength = shape::length(outputShapeInfo);
|
bufferLength = shape::length(outputShapeInfo);
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
const auto tid = blockIdx.x * gridDim.x + threadIdx.x;
|
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
const auto step = gridDim.x * blockDim.x;
|
const auto step = gridDim.x * blockDim.x;
|
||||||
for (int t = tid; t < bufferLength; t += step) {
|
for (int t = tid; t < bufferLength; t += step) {
|
||||||
auto tX = x + pTadOffsets[t];
|
auto tX = x + pTadOffsets[t];
|
||||||
|
@ -77,8 +77,6 @@ namespace helpers {
|
||||||
// manager.synchronize();
|
// manager.synchronize();
|
||||||
sortedVals.tickWriteDevice();
|
sortedVals.tickWriteDevice();
|
||||||
sortedVals.syncToHost();
|
sortedVals.syncToHost();
|
||||||
sortedVals.printIndexedBuffer("Hello");
|
|
||||||
sortedVals.printBuffer("Hello line");
|
|
||||||
auto stream = context->getCudaStream();
|
auto stream = context->getCudaStream();
|
||||||
fillUpElementKernel<T><<<32, 64, 1024, *stream>>>(output->specialBuffer(), output->specialShapeInfo(), sortedVals.specialBuffer(), sortedVals.specialShapeInfo(), pTadShape, pTadOffsets, n);
|
fillUpElementKernel<T><<<32, 64, 1024, *stream>>>(output->specialBuffer(), output->specialShapeInfo(), sortedVals.specialBuffer(), sortedVals.specialShapeInfo(), pTadShape, pTadOffsets, n);
|
||||||
}
|
}
|
||||||
|
|
|
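In fillUpElementKernel above, the base pointers z and x move from __shared__ variables initialized by thread 0 to ordinary per-thread locals, and the debug printIndexedBuffer/printBuffer calls are dropped. A __shared__ slot buys nothing for a value every thread can compute from the kernel arguments, and it forces a __syncthreads() before other threads may read it. A minimal contrast, illustrative only:

// Per-thread locals are simpler and race-free for values derivable from the
// kernel arguments; reserve __shared__ for data that is actually shared.
__global__ void copyKernel(const float* in, float* out, long long n) {
    const float* x = in;    // plain locals, no thread-0 initialization or barrier needed
    float* z = out;
    const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
    const auto stride = gridDim.x * blockDim.x;
    for (long long i = tid; i < n; i += stride)
        z[i] = x[i];
}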
@@ -74,17 +74,14 @@ static void polyGammaCudaLauncher(const int blocksPerGrid, const int threadsPerB
 ///////////////////////////////////////////////////////////////////
 void polyGamma(nd4j::LaunchContext * context, const NDArray& n, const NDArray& x, NDArray& z) {

-    if(!n.isActualOnDeviceSide()) n.syncToDevice();
-    if(!x.isActualOnDeviceSide()) x.syncToDevice();
+    NDArray::prepareSpecialUse({&z}, {&n, &x});

     int threadsPerBlock = MAX_NUM_THREADS;
     int blocksPerGrid = (z.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;

     BUILD_SINGLE_SELECTOR(n.dataType(), polyGammaCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), n.getSpecialBuffer(), n.getSpecialShapeInfo(), x.getSpecialBuffer(), x.getSpecialShapeInfo(), z.getSpecialBuffer(), z.getSpecialShapeInfo()), FLOAT_TYPES);

-    n.tickReadHost();
-    x.tickReadHost();
-    z.tickWriteDevice();
+    NDArray::registerSpecialUse({&z}, {&n, &x});
 }

 BUILD_SINGLE_TEMPLATE(template void polyGammaCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void *vn, const Nd4jLong *nShapeInfo, const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo), FLOAT_TYPES);
@@ -28,7 +28,7 @@ namespace helpers {
    template <typename T>
    static __global__ void global_range(void *output, Nd4jLong length, T start, T delta) {
        auto buff = reinterpret_cast<T*>(output);
-       const auto tid = blockIdx.x * gridDim.x + threadIdx.x;
+       const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
        const auto step = gridDim.x * blockDim.x;

        for(Nd4jLong i = tid; i < length; i += step)
@@ -43,10 +43,11 @@ namespace helpers {
    }

    void range(nd4j::LaunchContext * context, const NDArray& start, const NDArray& delta, NDArray& outVector) {
+       NDArray::prepareSpecialUse({&outVector}, {&start, &delta});
        BUILD_SINGLE_SELECTOR(outVector.dataType(), _range, (context, start, delta, outVector), LIBND4J_TYPES);
+       NDArray::registerSpecialUse({&outVector}, {&start, &delta});
    }

-   BUILD_SINGLE_TEMPLATE(template void _range, (nd4j::LaunchContext * context, const NDArray& start, const NDArray& delta, NDArray& outVector), NUMERIC_TYPES);
 }
 }
 }
@@ -26,13 +26,11 @@ namespace nd4j {
 namespace helpers {
    template<typename T>
    void toggle_bits__(NDArray &in, NDArray &out) {
-       NDArray::prepareSpecialUse({&out}, {&in});
        auto lambda = LAMBDA_T(_x) {
            return ~_x;//eUtils::flip_bits(_x);
        };

        in.applyLambda(lambda, &out);
-       NDArray::registerSpecialUse({&out}, {&in});
    }
    BUILD_SINGLE_TEMPLATE(template void toggle_bits__, (NDArray &in, NDArray &out), INTEGER_TYPES);

@@ -906,7 +906,7 @@ void clipByNormBP(nd4j::LaunchContext* context, const NDArray& input, const NDAr
        linearBuffers = shape::elementWiseStride(inputShape) == shape::elementWiseStride(outputShape) && shape::elementWiseStride(inputShape) == 1;
    }
    __syncthreads();
-   const auto tid = blockIdx.x * gridDim.x + threadIdx.x;
+   const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
    const auto step = gridDim.x * blockDim.x;

    for (Nd4jLong e = tid; e < length; e += step) {
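
Note: fillUpElementKernel, global_range and the clipByNormBP kernel above all switch the global thread index from blockIdx.x * gridDim.x to blockIdx.x * blockDim.x. A minimal sketch of the canonical grid-stride loop this produces, with a hypothetical kernel name and data (the 32x64 launch shape mirrors the hunks above; not part of the patch):

#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical kernel: scales every element, visiting the array with a grid-stride loop.
__global__ void scaleKernel(float* data, long length, float factor) {
    // Global thread index: block offset (blockIdx.x * blockDim.x) plus the thread's
    // position inside its block. Multiplying by gridDim.x, as the old code did, mixes
    // up "number of blocks" with "threads per block" and indexes the wrong elements.
    const auto tid  = blockIdx.x * blockDim.x + threadIdx.x;
    // Stride: the total number of threads launched across the whole grid.
    const auto step = gridDim.x * blockDim.x;

    for (long i = tid; i < length; i += step)
        data[i] *= factor;
}

int main() {
    const long n = 1 << 20;
    float* d = nullptr;
    cudaMallocManaged(&d, n * sizeof(float));
    for (long i = 0; i < n; ++i)
        d[i] = 1.0f;

    scaleKernel<<<32, 64>>>(d, n, 2.0f);   // 32 blocks of 64 threads
    cudaDeviceSynchronize();

    printf("d[0] = %f, d[%ld] = %f\n", d[0], n - 1, d[n - 1]);   // expect 2.0 for both
    cudaFree(d);
    return 0;
}
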
@@ -46,8 +46,8 @@ namespace helpers {
    void reluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond);
    void reluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput);
    void relu6Derivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput);
-   void leakyReluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput);
-   void eluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput);
+   void leakyReluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput, const float alpha);
+   void eluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput, const float alpha);
    void seluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput);
    void cubeDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput);
    void reduceNorm1(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput);
@@ -2271,26 +2271,26 @@ namespace simdOps {
        }
    };

-   template <typename X>
+   template <typename X, typename Y, typename Z>
    class ELU {
    public:
        no_op_exec_special_same
        no_op_exec_special_same_cuda

-       op_def static X op(X d1, X *params) {
-           return nd4j::math::nd4j_elu<X,X>(d1);
+       op_def static Z op(X d1, Y d2, Z *params) {
+           return nd4j::math::nd4j_elu<X,Z>(d1, static_cast<X>(d2));
        }
    };


-   template <typename X>
+   template <typename X, typename Y, typename Z>
    class ELUDerivative {
    public:
        no_op_exec_special_same
        no_op_exec_special_same_cuda

-       op_def static X op(X d1, X *params) {
-           return nd4j::math::nd4j_eluderivative<X,X>(d1);
+       op_def static Z op(X d1, Y d2, Z *params) {
+           return nd4j::math::nd4j_eluderivative<X,Z>(d1, static_cast<X>(d2));
        }
    };

@@ -130,13 +130,12 @@ namespace nd4j {
    }

    template<typename T, typename Z>
-   math_def inline Z nd4j_elu(T val) {
-       if (val >= (T) 0.f) return val;
-       else return nd4j_exp<T, Z>(val) - (Z) 1.0f;
-       //return val >= 0.0 ? val : (nd4j_exp<T>(val) - 1.0);
+   math_def inline Z nd4j_elu(T val, T alpha) {
+       if (val >= (T) 0.f)
+           return val;
+       return static_cast<Z>(alpha) * (nd4j_exp<T, Z>(val) - static_cast<Z>(1.0f));
    }


    template<typename T, typename Z>
    math_def inline Z nd4j_leakyrelu(T val,T alpha) {
        if (val < (T) 0.0f)
@@ -145,13 +144,14 @@ namespace nd4j {
        return val;
    }


    template<typename T, typename Z>
-   math_def inline Z nd4j_eluderivative(T val) {
-       if (val >= (T) 0.0f) return (Z) 1.0f;
-       else return nd4j_exp<T, Z>(val);
+   math_def inline Z nd4j_eluderivative(T val, T alpha) {
+       if (val >= static_cast<T>(0.0f))
+           return static_cast<Z>(1.0f);
+       return static_cast<Z>(alpha) * nd4j_exp<T, Z>(val);
        //return val >= 0.0 ? 1.0 : nd4j_exp(val);
    }

    template<typename T, typename Z>
    math_def inline Z nd4j_sin(T val);

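
Note: the two hunks above make nd4j_elu and nd4j_eluderivative take an alpha argument, so elu(x) = x for x >= 0 and alpha * (exp(x) - 1) otherwise, with derivative 1 or alpha * exp(x). A standalone sketch of the same math, using plain template functions in place of the math_def helpers; the alpha of 0.5 mirrors the updated elu tests later in this diff:

#include <cmath>
#include <cstdio>

// Plain stand-ins for the math_def templates above.
template<typename T>
T elu(T val, T alpha) {
    // Positive inputs pass through; negative inputs follow alpha * (e^x - 1).
    return val >= T(0) ? val : alpha * (std::exp(val) - T(1));
}

template<typename T>
T eluDerivative(T val, T alpha) {
    // Slope is 1 on the positive side and alpha * e^x on the negative side.
    return val >= T(0) ? T(1) : alpha * std::exp(val);
}

int main() {
    // x = -0.4 with alpha = 0.5, the values used by elu_test1 / elu_bp_test1 below:
    printf("elu (-0.4, 0.5) = %f\n", elu(-0.4, 0.5));           // ~ 0.5 * -0.32968
    printf("elu'(-0.4, 0.5) = %f\n", eluDerivative(-0.4, 0.5)); // ~ 0.5 *  0.67032
    return 0;
}
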
@@ -904,13 +904,13 @@ inline __device__ int16_t nd4j_atomicMax<int16_t>(int16_t* address, int16_t val)

 template <>
 inline __device__ float16 nd4j_atomicMax<float16>(float16* address, float16 val) {
-  int* address_as_ull = (int*) address;
+  auto address_as_ull = (int*) address;

   long addr = (long) address;
   bool misaligned = addr & 0x3;

   if (misaligned)
-    address_as_ull = (int *) (addr - 2);
+    address_as_ull = (int *) (address - 1);

   PAIR old, assumed, fresh;

@@ -937,13 +937,13 @@ inline __device__ float16 nd4j_atomicMax<float16>(float16* address, float16 val)

 template <>
 inline __device__ bfloat16 nd4j_atomicMax<bfloat16>(bfloat16* address, bfloat16 val) {
-  int* address_as_ull = (int*) address;
+  auto address_as_ull = (int*) address;

   long addr = (long)(address);
   bool misaligned = addr & 0x3;

   if (misaligned)
-    address_as_ull = (int *) (addr - 2);
+    address_as_ull = (int *) (address - 1);

   BPAIR old, assumed, fresh;

@@ -1060,13 +1060,13 @@ inline __device__ float16 nd4j_atomicAdd<float16>(float16* address, float16 val)
 #if __CUDA_ARCH__ >= 700
   atomicAdd(reinterpret_cast<__half*>(address), val.data);
 #else
-  int* address_as_ull = (int*) address;
+  auto address_as_ull = (int*) address;

   long addr = (long) address;
   bool misaligned = addr & 0x3;

   if (misaligned)
-    address_as_ull = (int *) (addr - 2);
+    address_as_ull = (int *) (address - 1);

   PAIR old, assumed, fresh;

@@ -1094,13 +1094,13 @@ inline __device__ float16 nd4j_atomicAdd<float16>(float16* address, float16 val)

 template <>
 inline __device__ bfloat16 nd4j_atomicAdd<bfloat16>(bfloat16* address, bfloat16 val) {
-  int* address_as_ull = (int*) address;
+  auto address_as_ull = (int*) address;

-  long addr = (long)(address);
+  auto addr = (long)(address);
   bool misaligned = addr & 0x3;

   if (misaligned)
-    address_as_ull = (int *) (addr - 2);
+    address_as_ull = (int *) (address - 1);

   BPAIR old, assumed, fresh;

@@ -1367,13 +1367,13 @@ inline __device__ Nd4jLong nd4j_atomicMul<Nd4jLong>(Nd4jLong* address, Nd4jLong

 template <>
 inline __device__ bfloat16 nd4j_atomicMul<bfloat16>(bfloat16* address, bfloat16 val) {
-  int* address_as_ull = (int*) address;
+  auto address_as_ull = (int*) address;

   long addr = (long)(address);
   bool misaligned = addr & 0x3;

   if (misaligned)
-    address_as_ull = (int *) (addr - 2);
+    address_as_ull = (int *) (address - 1);

   BPAIR old, assumed, fresh;

@@ -1400,13 +1400,13 @@ inline __device__ bfloat16 nd4j_atomicMul<bfloat16>(bfloat16* address, bfloat16

 template <>
 inline __device__ float16 nd4j_atomicMul<float16>(float16* address, float16 val) {
-  int* address_as_ull = (int*) address;
+  auto address_as_ull = (int*) address;

   long addr = (long)(address);
   bool misaligned = addr & 0x3;

   if (misaligned)
-    address_as_ull = (int *) (addr - 2);
+    address_as_ull = (int *) (address - 1);

   BPAIR old, assumed, fresh;

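
Note: in the six half-precision atomic helpers above, the aligned 32-bit word the helper operates on is now computed as (int *) (address - 1) instead of (int *) (addr - 2); the two are equivalent because address points at a 2-byte element, so stepping back one element steps back two bytes. A small host-side check of that pointer arithmetic, with uint16_t standing in for float16/bfloat16:

#include <cassert>
#include <cstdint>

int main() {
    // uint16_t stands in for float16 / bfloat16; alignas(4) guarantees that
    // element 1 really is misaligned with respect to a 4-byte word.
    alignas(4) uint16_t buffer[4] = {0, 0, 0, 0};
    uint16_t* address = &buffer[1];
    long addr = (long) address;
    bool misaligned = addr & 0x3;          // same test as the helpers above
    assert(misaligned);

    int* viaByteArithmetic    = (int *) (addr - 2);      // old spelling: raw byte address minus 2
    int* viaElementArithmetic = (int *) (address - 1);   // new spelling: one 2-byte element back
    assert(viaByteArithmetic == viaElementArithmetic);   // both name the same aligned 32-bit word
    return 0;
}
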
@ -905,6 +905,25 @@ TEST_F(DeclarableOpsTests12, softmax_9) {
|
||||||
delete arrF;
|
delete arrF;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests12, maxpool_bp_half_1) {
|
||||||
|
auto x = NDArrayFactory::create<bfloat16>('c', {2, 3, 10, 1}, {0.2019043f, 0.6464844f, 0.9116211f, 0.60058594f, 0.34033203f, 0.7036133f, 0.6772461f, 0.3815918f, 0.87353516f, 0.04650879f, 0.67822266f, 0.8618164f, 0.88378906f, 0.7573242f, 0.66796875f, 0.63427734f, 0.33764648f, 0.46923828f, 0.62939453f, 0.76464844f, -0.8618164f, -0.94873047f, -0.9902344f, -0.88916016f, -0.86572266f, -0.92089844f, -0.90722656f, -0.96533203f, -0.97509766f, -0.4975586f, -0.84814453f, -0.984375f, -0.98828125f, -0.95458984f, -0.9472656f, -0.91064453f, -0.80859375f, -0.83496094f, -0.9140625f, -0.82470703f, 0.4802246f, 0.45361328f, 0.28125f, 0.28320312f, 0.79345703f, 0.44604492f, -0.30273438f, 0.11730957f, 0.56396484f, 0.73583984f, 0.1418457f, -0.44848633f, 0.6923828f, -0.40234375f, 0.40185547f, 0.48632812f, 0.14538574f, 0.4638672f, 0.13000488f, 0.5058594f});
|
||||||
|
auto y = NDArrayFactory::create<bfloat16>('c', {2, 3, 10, 1}, {0.0f, -0.13391113f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, -0.1751709f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.51904297f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.5107422f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f});
|
||||||
|
auto z = NDArrayFactory::create<bfloat16>('c', {2, 3, 10, 1});
|
||||||
|
|
||||||
|
nd4j::ops::maxpool2d_bp op;
|
||||||
|
Context ctx(1);
|
||||||
|
Nd4jLong iArgs[] = {5,1,1, 2,2,0, 1,1,1, 0,0};
|
||||||
|
ctx.setIArguments(iArgs, 11);
|
||||||
|
ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo());
|
||||||
|
ctx.setInputArray(1, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo());
|
||||||
|
ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());
|
||||||
|
|
||||||
|
|
||||||
|
auto status = op.execute(&ctx);
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests12, lrn_bp_1) {
|
TEST_F(DeclarableOpsTests12, lrn_bp_1) {
|
||||||
|
|
||||||
|
|
|
@@ -2794,42 +2794,33 @@ TEST_F(DeclarableOpsTests3, svd_test11) {
 TEST_F(DeclarableOpsTests3, elu_test1) {

     auto x = NDArrayFactory::create<double>('c', {3,3}, {0.1, .2, .3, -.4,-.5,-.6, .7, .8, .9});
-    // auto expS = NDArrayFactory::create<double>('c', {3});
-    // auto expU = NDArrayFactory::create<double>('c', {3,3});
-    auto exp = NDArrayFactory::create<double>('c', {3,3}, {.1, .2, .3, -0.32968, -0.393469, -0.451188, .7, .8, .9});
+    auto exp = NDArrayFactory::create<double>('c', {3,3}, {.1, .2, .3, 0.5*-0.32968, 0.5*-0.393469, 0.5*-0.451188, .7, .8, .9});

     nd4j::ops::elu op;
-    auto results = op.execute({&x}, {}, {});
+    auto results = op.execute({&x}, {0.5}, {});

     ASSERT_EQ(ND4J_STATUS_OK, results->status());

     auto s = results->at(0);
-    // auto u = results->at(1);
-    // auto v = results->at(2);
-    // s->printIndexedBuffer("ELU");
     ASSERT_TRUE(exp.equalsTo(s));

     delete results;
 }

 ///////////////////////////////////////////////////////////////////
-TEST_F(DeclarableOpsTests3, elu_test2) {
+TEST_F(DeclarableOpsTests3, elu_bp_test1) {

     auto x = NDArrayFactory::create<double>('c', {3, 3}, {0.1, .2, .3, -.4, -.5, -.6, .7, .8, .9});
     auto eps = NDArrayFactory::create<double>('c', {3,3});
     eps.assign(2.);
-    // auto expU = NDArrayFactory::create<double>('c', {3,3});
-    auto exp = NDArrayFactory::create<double>('c', {3, 3}, {2, 2, 2, 1.34064, 1.213061, 1.097623, 2, 2, 2});
+    auto exp = NDArrayFactory::create<double>('c', {3, 3}, {2, 2, 2, 0.5*1.34064, 0.5*1.213061, 0.5*1.097623, 2, 2, 2});

     nd4j::ops::elu_bp op;
-    auto results = op.execute({ &x, &eps }, {}, {});
+    auto results = op.execute({ &x, &eps }, {0.5}, {});

     ASSERT_EQ(ND4J_STATUS_OK, results->status());

     auto s = results->at(0);
-    // auto u = results->at(1);
-    // auto v = results->at(2);
-    // s->printIndexedBuffer("ELU_BP");
     ASSERT_TRUE(exp.equalsTo(s));

     delete results;
@@ -2839,8 +2830,6 @@ TEST_F(DeclarableOpsTests3, elu_test2) {
 TEST_F(DeclarableOpsTests3, lrelu_test1) {

     auto x = NDArrayFactory::create<double>('c', {3,3}, {1, 2, 3, -4,-5,-6, 7, 8, 9});
-    // auto expS = NDArrayFactory::create<double>('c', {3});
-    // auto expU = NDArrayFactory::create<double>('c', {3,3});
     auto exp = NDArrayFactory::create<double>('c', {3,3}, {1, 2, 3, -0.8, -1., -1.2, 7, 8, 9});

     nd4j::ops::lrelu op;
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
auto s = results->at(0);
|
auto s = results->at(0);
|
||||||
// auto u = results->at(1);
|
|
||||||
// auto v = results->at(2);
|
|
||||||
// s->printIndexedBuffer("LRELU");
|
|
||||||
ASSERT_TRUE(exp.equalsTo(s));
|
ASSERT_TRUE(exp.equalsTo(s));
|
||||||
|
|
||||||
delete results;
|
delete results;
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests3, lrelu_test2) {
|
TEST_F(DeclarableOpsTests3, lrelu_bp_test1) {
|
||||||
|
|
||||||
auto x = NDArrayFactory::create<double>('c', {3,3}, {1, 2, 3, -4,-5,-6, 7, 8, 9});
|
auto x = NDArrayFactory::create<double>('c', {3,3}, {1, 2, 3, -4,-5,-6, 7, 8, 9});
|
||||||
// auto expS = NDArrayFactory::create<double>('c', {3});
|
|
||||||
auto eps = NDArrayFactory::create<double>('c', {3,3}, {2,2,2,2,2,2,2, 2,2});
|
auto eps = NDArrayFactory::create<double>('c', {3,3}, {2,2,2,2,2,2,2, 2,2});
|
||||||
auto exp = NDArrayFactory::create<double>('c', {3,3}, {2, 2, 2, 0, 0, 0, 2, 2, 2});
|
auto exp = NDArrayFactory::create<double>('c', {3,3}, {2, 2, 2, 0.4, 0.4, 0.4, 2, 2, 2});
|
||||||
|
|
||||||
nd4j::ops::lrelu_bp op;
|
nd4j::ops::lrelu_bp op;
|
||||||
auto results = op.execute({&x, &eps}, {0.2}, {});
|
auto results = op.execute({&x, &eps}, {0.2}, {});
|
||||||
|
@@ -2870,9 +2855,6 @@ TEST_F(DeclarableOpsTests3, lrelu_test2) {
     ASSERT_EQ(ND4J_STATUS_OK, results->status());

     auto s = results->at(0);
-    // auto u = results->at(1);
-    // auto v = results->at(2);
-    // s->printIndexedBuffer("LRELU_BP");
     ASSERT_TRUE(exp.equalsTo(s));

     delete results;
@@ -2882,8 +2864,6 @@ TEST_F(DeclarableOpsTests3, lrelu_test2) {
 TEST_F(DeclarableOpsTests3, selu_test1) {

     auto x = NDArrayFactory::create<double>('c', {3,3}, {1, 2, 3, -4,-5,-6, 7, 8, 9});
-    // auto expS = NDArrayFactory::create<double>('c', {3});
-    // auto expU = NDArrayFactory::create<double>('c', {3,3});
     auto exp = NDArrayFactory::create<double>('c', {3,3}, {1.050701, 2.101402, 3.152103, -1.725899, -1.746253, -1.753742, 7.354907, 8.405608, 9.456309});

     nd4j::ops::selu op;
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
auto s = results->at(0);
|
auto s = results->at(0);
|
||||||
// s->printIndexedBuffer("SELU");
|
|
||||||
ASSERT_TRUE(exp.equalsTo(s));
|
ASSERT_TRUE(exp.equalsTo(s));
|
||||||
|
|
||||||
delete results;
|
delete results;
|
||||||
|
|
|
@@ -2761,7 +2761,7 @@ TEST_F(DeclarableOpsTests5, ELU_1) {
     auto exp = NDArrayFactory::create<double>('c', {2, 2, 2}, { -0.63212055, 2. , 1.5, -0.753403, 1., 2., 2., 1.});
     auto res = NDArrayFactory::create<double>('c', {2, 2, 2});

-    input.applyTransform(transform::ELU, &res);
+    input.applyScalar(nd4j::scalar::ELU, 1.f, &res);

     ASSERT_TRUE(res.equalsTo(&exp));
 }

@ -204,20 +204,29 @@ import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentSum;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.dtype.Cast;
|
import org.nd4j.linalg.api.ops.impl.transforms.dtype.Cast;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt;
|
import org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt;
|
import org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative;
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.DynamicPartitionBp;
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.DynamicPartitionBp;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.ELUDerivative;
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhBp;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative;
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative;
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LogSoftMaxDerivative;
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LogSoftMaxDerivative;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhBp;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhDerivative;
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhDerivative;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhBp;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhDerivative;
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhDerivative;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.Relu6Derivative;
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.Relu6Derivative;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SELUDerivative;
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SELUDerivative;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative;
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative;
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp;
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.ThresholdReluBp;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.*;
|
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.*;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.*;
|
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.*;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.And;
|
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.And;
|
||||||
|
@ -1126,10 +1135,26 @@ public class DifferentialFunctionFactory {
|
||||||
return new org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative(sameDiff(), iX, wrt).outputVariable();
|
return new org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative(sameDiff(), iX, wrt).outputVariable();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public SDVariable tanhRationalBp(SDVariable in, SDVariable epsilon) {
|
||||||
|
return new RationalTanhBp(sameDiff(), in, epsilon).outputVariable();
|
||||||
|
}
|
||||||
|
|
||||||
|
public SDVariable tanhRectifiedBp(SDVariable in, SDVariable epsilon) {
|
||||||
|
return new RectifiedTanhBp(sameDiff(), in, epsilon).outputVariable();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Use {@link #tanhRationalBp(SDVariable, SDVariable)}
|
||||||
|
*/
|
||||||
|
@Deprecated
|
||||||
public SDVariable tanhRationalDerivative(SDVariable in) {
|
public SDVariable tanhRationalDerivative(SDVariable in) {
|
||||||
return new RationalTanhDerivative(sameDiff(), in, false).outputVariable();
|
return new RationalTanhDerivative(sameDiff(), in, false).outputVariable();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Use {@link #tanhRectifiedBp(SDVariable, SDVariable)}
|
||||||
|
*/
|
||||||
|
@Deprecated
|
||||||
public SDVariable tanhRectifiedDerivative(SDVariable in) {
|
public SDVariable tanhRectifiedDerivative(SDVariable in) {
|
||||||
return new RectifiedTanhDerivative(sameDiff(), in, false).outputVariable();
|
return new RectifiedTanhDerivative(sameDiff(), in, false).outputVariable();
|
||||||
}
|
}
|
||||||
|
@ -1280,6 +1305,14 @@ public class DifferentialFunctionFactory {
|
||||||
return new Cube(sameDiff(), iX, false).outputVariable();
|
return new Cube(sameDiff(), iX, false).outputVariable();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public SDVariable cubeBp(SDVariable in, SDVariable epsilon) {
|
||||||
|
return new CubeBp(sameDiff(), in, epsilon).outputVariable();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @deprecated Use {@link #cubeBp(SDVariable, SDVariable)}
|
||||||
|
*/
|
||||||
|
@Deprecated
|
||||||
public SDVariable cubeDerivative(SDVariable iX) {
|
public SDVariable cubeDerivative(SDVariable iX) {
|
||||||
return new CubeDerivative(sameDiff(), iX, false).outputVariable();
|
return new CubeDerivative(sameDiff(), iX, false).outputVariable();
|
||||||
}
|
}
|
||||||
|
@ -1329,6 +1362,14 @@ public class DifferentialFunctionFactory {
|
||||||
return new RectifiedLinearDerivative(sameDiff(), input, grad).outputVariable();
|
return new RectifiedLinearDerivative(sameDiff(), input, grad).outputVariable();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public SDVariable thresholdRelu(SDVariable in, SDVariable epsilon, double cutoff){
|
||||||
|
return new ThresholdRelu(sameDiff(), in, cutoff).outputVariable();
|
||||||
|
}
|
||||||
|
|
||||||
|
public SDVariable thresholdReluBp(SDVariable in, SDVariable epsilon, double cutoff){
|
||||||
|
return new ThresholdReluBp(sameDiff(), in, epsilon, cutoff).outputVariable();
|
||||||
|
}
|
||||||
|
|
||||||
public SDVariable relu6(SDVariable iX, double cutoff) {
|
public SDVariable relu6(SDVariable iX, double cutoff) {
|
||||||
return new Relu6(sameDiff(), iX, false, cutoff).outputVariable();
|
return new Relu6(sameDiff(), iX, false, cutoff).outputVariable();
|
||||||
}
|
}
|
||||||
|
@ -1350,6 +1391,14 @@ public class DifferentialFunctionFactory {
|
||||||
return new HardTanh(sameDiff(), iX, false).outputVariable();
|
return new HardTanh(sameDiff(), iX, false).outputVariable();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public SDVariable hardTanhBp(SDVariable in, SDVariable epsilon) {
|
||||||
|
return new HardTanhBp(sameDiff(), in, epsilon).outputVariable();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @deprecated Use {@link #hardTanhBp(SDVariable, SDVariable)}
|
||||||
|
*/
|
||||||
|
@Deprecated
|
||||||
public SDVariable hardTanhDerivative(SDVariable iX) {
|
public SDVariable hardTanhDerivative(SDVariable iX) {
|
||||||
return new HardTanhDerivative(sameDiff(), iX, false).outputVariable();
|
return new HardTanhDerivative(sameDiff(), iX, false).outputVariable();
|
||||||
}
|
}
|
||||||
|
@ -1358,6 +1407,9 @@ public class DifferentialFunctionFactory {
|
||||||
return new HardSigmoid(sameDiff(), in, false).outputVariable();
|
return new HardSigmoid(sameDiff(), in, false).outputVariable();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public SDVariable hardSigmoidBp(SDVariable in, SDVariable epsilon){
|
||||||
|
return new HardSigmoidBp(sameDiff(), in, epsilon).outputVariable();
|
||||||
|
}
|
||||||
|
|
||||||
public SDVariable sigmoid(SDVariable iX) {
|
public SDVariable sigmoid(SDVariable iX) {
|
||||||
return new Sigmoid(sameDiff(), iX, false).outputVariable();
|
return new Sigmoid(sameDiff(), iX, false).outputVariable();
|
||||||
|
@ -1486,10 +1538,16 @@ public class DifferentialFunctionFactory {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public SDVariable softsignBp(SDVariable in, SDVariable epsilon) {
|
||||||
|
return new SoftSignBp(sameDiff(), in, epsilon).outputVariable();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @deprecated Use {@link #softsignBp(SDVariable, SDVariable)}
|
||||||
|
*/
|
||||||
|
@Deprecated
|
||||||
public SDVariable softsignDerivative(SDVariable iX) {
|
public SDVariable softsignDerivative(SDVariable iX) {
|
||||||
return new SoftSignDerivative(sameDiff(), iX, false).outputVariable();
|
return new SoftSignDerivative(sameDiff(), iX, false).outputVariable();
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1500,14 +1558,12 @@ public class DifferentialFunctionFactory {
|
||||||
|
|
||||||
|
|
||||||
public SDVariable elu(SDVariable iX) {
|
public SDVariable elu(SDVariable iX) {
|
||||||
return new ELU(sameDiff(), iX, false).outputVariable();
|
return new ELU(sameDiff(), iX).outputVariable();
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public SDVariable eluBp(SDVariable in, SDVariable epsilon) {
|
||||||
public SDVariable eluDerivative(SDVariable iX) {
|
return new EluBp(sameDiff(), in, epsilon).outputVariable();
|
||||||
return new ELUDerivative(sameDiff(), iX, false).outputVariable();
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -1516,6 +1572,14 @@ public class DifferentialFunctionFactory {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public SDVariable leakyReluBp(SDVariable in, SDVariable epsilon, double cutoff) {
|
||||||
|
return new LeakyReLUBp(sameDiff(), in, epsilon, cutoff).outputVariable();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @deprecated Use {@link #leakyReluBp(SDVariable, SDVariable, double)}
|
||||||
|
*/
|
||||||
|
@Deprecated
|
||||||
public SDVariable leakyReluDerivative(SDVariable iX, double cutoff) {
|
public SDVariable leakyReluDerivative(SDVariable iX, double cutoff) {
|
||||||
return new LeakyReLUDerivative(sameDiff(), iX, false, cutoff).outputVariable();
|
return new LeakyReLUDerivative(sameDiff(), iX, false, cutoff).outputVariable();
|
||||||
}
|
}
|
||||||
|
@ -1832,7 +1896,15 @@ public class DifferentialFunctionFactory {
|
||||||
return new SELU(sameDiff(), arg, false).outputVariable();
|
return new SELU(sameDiff(), arg, false).outputVariable();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public SDVariable seluBp(SDVariable in, SDVariable epsilon) {
|
||||||
|
validateDifferentialFunctionsameDiff(in);
|
||||||
|
return new SeluBp(sameDiff(), in, epsilon).outputVariable();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @deprecated Use {@link #seluBp(SDVariable, SDVariable)}
|
||||||
|
*/
|
||||||
|
@Deprecated
|
||||||
public SDVariable seluDerivative(SDVariable arg) {
|
public SDVariable seluDerivative(SDVariable arg) {
|
||||||
validateDifferentialFunctionsameDiff(arg);
|
validateDifferentialFunctionsameDiff(arg);
|
||||||
return new SELUDerivative(sameDiff(), arg, false).outputVariable();
|
return new SELUDerivative(sameDiff(), arg, false).outputVariable();
|
||||||
|
|
|
@ -163,31 +163,6 @@ public class SDNN extends SDOps {
|
||||||
return updateVariableNameAndReference(result, name);
|
return updateVariableNameAndReference(result, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Element-wise derivative exponential linear unit (ELU) function, dOut/dIn given input.
|
|
||||||
* {@link #elu(SDVariable)}
|
|
||||||
*
|
|
||||||
* @param x Input variable
|
|
||||||
* @return Output variable
|
|
||||||
*/
|
|
||||||
public SDVariable eluDerivative(SDVariable x) {
|
|
||||||
return eluDerivative(null, x);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Element-wise derivative exponential linear unit (ELU) function, dOut/dIn given input.
|
|
||||||
* {@link #elu(SDVariable)}
|
|
||||||
*
|
|
||||||
* @param name Output variable name
|
|
||||||
* @param x Input variable
|
|
||||||
* @return Output variable
|
|
||||||
*/
|
|
||||||
public SDVariable eluDerivative(String name, SDVariable x) {
|
|
||||||
validateFloatingPoint("eluDerivative", x);
|
|
||||||
SDVariable result = f().eluDerivative(x);
|
|
||||||
return updateVariableNameAndReference(result, name);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* GELU activation function - Gaussian Error Linear Units<br>
|
* GELU activation function - Gaussian Error Linear Units<br>
|
||||||
* For more details, see <i>Gaussian Error Linear Units (GELUs)</i> - <a href="https://arxiv.org/abs/1606.08415">https://arxiv.org/abs/1606.08415</a>
|
* For more details, see <i>Gaussian Error Linear Units (GELUs)</i> - <a href="https://arxiv.org/abs/1606.08415">https://arxiv.org/abs/1606.08415</a>
|
||||||
|
|
|
@ -255,8 +255,6 @@ public class LegacyOpMapper {
|
||||||
return Abs.class;
|
return Abs.class;
|
||||||
case 2:
|
case 2:
|
||||||
return LogSoftMax.class;
|
return LogSoftMax.class;
|
||||||
case 3:
|
|
||||||
return ELUDerivative.class;
|
|
||||||
case 4:
|
case 4:
|
||||||
return org.nd4j.linalg.api.ops.impl.transforms.strict.TanhDerivative.class;
|
return org.nd4j.linalg.api.ops.impl.transforms.strict.TanhDerivative.class;
|
||||||
case 5:
|
case 5:
|
||||||
|
|
|
@ -881,7 +881,6 @@ public class OpValidation {
|
||||||
SoftmaxBp.class,
|
SoftmaxBp.class,
|
||||||
|
|
||||||
CubeDerivative.class,
|
CubeDerivative.class,
|
||||||
ELUDerivative.class,
|
|
||||||
GELUDerivative.class,
|
GELUDerivative.class,
|
||||||
PreciseGELUDerivative.class,
|
PreciseGELUDerivative.class,
|
||||||
HardSigmoidDerivative.class,
|
HardSigmoidDerivative.class,
|
||||||
|
@ -901,6 +900,17 @@ public class OpValidation {
|
||||||
TanhDerivative.class,
|
TanhDerivative.class,
|
||||||
org.nd4j.linalg.api.ops.impl.transforms.strict.TanhDerivative.class,
|
org.nd4j.linalg.api.ops.impl.transforms.strict.TanhDerivative.class,
|
||||||
PowDerivative.class,
|
PowDerivative.class,
|
||||||
|
org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp.class,
|
||||||
|
org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp.class,
|
||||||
|
org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp.class,
|
||||||
|
org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhBp.class,
|
||||||
|
org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp.class,
|
||||||
|
org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhBp.class,
|
||||||
|
org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhBp.class,
|
||||||
|
org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp.class,
|
||||||
|
org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftPlusBp.class,
|
||||||
|
org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp.class,
|
||||||
|
org.nd4j.linalg.api.ops.impl.transforms.gradient.ThresholdReluBp.class,
|
||||||
|
|
||||||
BiasAddGrad.class,
|
BiasAddGrad.class,
|
||||||
ConcatBp.class,
|
ConcatBp.class,
|
||||||
|
|
|
@ -229,6 +229,7 @@ public class ImportClassMapping {
|
||||||
org.nd4j.linalg.api.ops.impl.scalar.PowDerivative.class,
|
org.nd4j.linalg.api.ops.impl.scalar.PowDerivative.class,
|
||||||
org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear.class,
|
org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear.class,
|
||||||
org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinearDerivative.class,
|
org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinearDerivative.class,
|
||||||
|
org.nd4j.linalg.api.ops.impl.transforms.custom.ThresholdRelu.class,
|
||||||
org.nd4j.linalg.api.ops.impl.scalar.Relu6.class,
|
org.nd4j.linalg.api.ops.impl.scalar.Relu6.class,
|
||||||
org.nd4j.linalg.api.ops.impl.scalar.ReplaceNans.class,
|
org.nd4j.linalg.api.ops.impl.scalar.ReplaceNans.class,
|
||||||
org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd.class,
|
org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd.class,
|
||||||
|
@ -421,7 +422,6 @@ public class ImportClassMapping {
|
||||||
org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt.class,
|
org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt.class,
|
||||||
org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative.class,
|
org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative.class,
|
||||||
org.nd4j.linalg.api.ops.impl.transforms.gradient.DynamicPartitionBp.class,
|
org.nd4j.linalg.api.ops.impl.transforms.gradient.DynamicPartitionBp.class,
|
||||||
org.nd4j.linalg.api.ops.impl.transforms.gradient.ELUDerivative.class,
|
|
||||||
org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker.class,
|
org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker.class,
|
||||||
org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidDerivative.class,
|
org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidDerivative.class,
|
||||||
org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative.class,
|
org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative.class,
|
||||||
|
@ -433,6 +433,17 @@ public class ImportClassMapping {
|
||||||
org.nd4j.linalg.api.ops.impl.transforms.gradient.SELUDerivative.class,
|
org.nd4j.linalg.api.ops.impl.transforms.gradient.SELUDerivative.class,
|
||||||
org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative.class,
|
org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative.class,
|
||||||
org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative.class,
|
org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative.class,
|
||||||
|
org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp.class,
|
||||||
|
org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp.class,
|
||||||
|
org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp.class,
|
||||||
|
org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhBp.class,
|
||||||
|
org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp.class,
|
||||||
|
org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhBp.class,
|
||||||
|
org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhBp.class,
|
||||||
|
org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp.class,
|
||||||
|
org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftPlusBp.class,
|
||||||
|
org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp.class,
|
||||||
|
org.nd4j.linalg.api.ops.impl.transforms.gradient.ThresholdReluBp.class,
|
||||||
org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp.class,
|
org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp.class,
|
||||||
org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative.class,
|
org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative.class,
|
||||||
org.nd4j.linalg.api.ops.impl.transforms.pairwise.BinaryMinimalRelativeError.class,
|
org.nd4j.linalg.api.ops.impl.transforms.pairwise.BinaryMinimalRelativeError.class,
|
||||||
|
|
|
@ -21,6 +21,8 @@ import lombok.Getter;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import org.nd4j.linalg.activations.BaseActivationFunction;
|
import org.nd4j.linalg.activations.BaseActivationFunction;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.same.Cube;
|
import org.nd4j.linalg.api.ops.impl.transforms.same.Cube;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative;
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
@ -42,9 +44,9 @@ public class ActivationCube extends BaseActivationFunction {
|
||||||
@Override
|
@Override
|
||||||
public Pair<INDArray, INDArray> backprop(@NonNull INDArray in, @NonNull INDArray epsilon) {
|
public Pair<INDArray, INDArray> backprop(@NonNull INDArray in, @NonNull INDArray epsilon) {
|
||||||
assertShape(in, epsilon);
|
assertShape(in, epsilon);
|
||||||
INDArray dLdz = Nd4j.getExecutioner().exec(new CubeDerivative(in));
|
Nd4j.getExecutioner().execAndReturn(new CubeBp(in, epsilon, in));
|
||||||
dLdz.muli(epsilon);
|
|
||||||
return new Pair<>(dLdz, null);
|
return new Pair<>(in, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -18,11 +18,11 @@ package org.nd4j.linalg.activations.impl;
|
||||||
|
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp;
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
import org.nd4j.linalg.activations.BaseActivationFunction;
|
import org.nd4j.linalg.activations.BaseActivationFunction;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.strict.ELU;
|
import org.nd4j.linalg.api.ops.impl.transforms.strict.ELU;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.ELUDerivative;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.indexing.BooleanIndexing;
|
import org.nd4j.linalg.indexing.BooleanIndexing;
|
||||||
import org.nd4j.linalg.indexing.conditions.Conditions;
|
import org.nd4j.linalg.indexing.conditions.Conditions;
|
||||||
|
@ -57,7 +57,7 @@ public class ActivationELU extends BaseActivationFunction {
|
||||||
public INDArray getActivation(INDArray in, boolean training) {
|
public INDArray getActivation(INDArray in, boolean training) {
|
||||||
// no support in ELU native to override alpha
|
// no support in ELU native to override alpha
|
||||||
if (this.alpha != 1.00) {
|
if (this.alpha != 1.00) {
|
||||||
INDArray alphaMultiple = Nd4j.getExecutioner().exec(new ELU(in.dup()));
|
INDArray alphaMultiple = Nd4j.getExecutioner().exec(new ELU(in.dup()))[0];
|
||||||
alphaMultiple.muli(alpha);
|
alphaMultiple.muli(alpha);
|
||||||
BooleanIndexing.replaceWhere(in, alphaMultiple, Conditions.lessThan(0));
|
BooleanIndexing.replaceWhere(in, alphaMultiple, Conditions.lessThan(0));
|
||||||
} else {
|
} else {
|
||||||
|
@@ -74,21 +74,8 @@ public class ActivationELU extends BaseActivationFunction {
     @Override
     public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
         assertShape(in, epsilon);
-        // no support in ELU native to override alpha
-        if (alpha != 1.00) {
-            INDArray dLdz = Nd4j.getExecutioner().exec(new ELUDerivative(in.dup()));
-            dLdz.muli(alpha);
-            BooleanIndexing.replaceWhere(dLdz, 1, Conditions.equals(alpha));
-
-            dLdz.muli(epsilon);
-            return new Pair<>(dLdz, null);
-        }
-
-        else {
-            INDArray dLdz = Nd4j.getExecutioner().exec(new ELUDerivative(in));
-            dLdz.muli(epsilon);
-            return new Pair<>(dLdz, null);
-        }
+        Nd4j.getExecutioner().execAndReturn(new EluBp(in, epsilon, in));
+        return new Pair<>(in, null);
     }

     @Override

@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl;
|
||||||
|
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp;
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
import org.nd4j.linalg.activations.BaseActivationFunction;
|
import org.nd4j.linalg.activations.BaseActivationFunction;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -41,9 +43,9 @@ public class ActivationHardSigmoid extends BaseActivationFunction {
|
||||||
@Override
|
@Override
|
||||||
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
|
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
|
||||||
assertShape(in, epsilon);
|
assertShape(in, epsilon);
|
||||||
INDArray dLdz = Nd4j.getExecutioner().exec(new HardSigmoidDerivative(in));
|
Nd4j.getExecutioner().execAndReturn(new HardSigmoidBp(in, epsilon, in));
|
||||||
dLdz.muli(epsilon);
|
|
||||||
return new Pair<>(dLdz, null);
|
return new Pair<>(in, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl;
|
||||||
|
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhBp;
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
import org.nd4j.linalg.activations.BaseActivationFunction;
|
import org.nd4j.linalg.activations.BaseActivationFunction;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -43,9 +45,10 @@ public class ActivationHardTanH extends BaseActivationFunction {
|
||||||
@Override
|
@Override
|
||||||
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
|
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
|
||||||
assertShape(in, epsilon);
|
assertShape(in, epsilon);
|
||||||
INDArray dLdz = Nd4j.getExecutioner().exec(new HardTanhDerivative(in));
|
|
||||||
dLdz.muli(epsilon);
|
Nd4j.getExecutioner().execAndReturn(new HardTanhBp(in, epsilon, in));
|
||||||
return new Pair<>(dLdz, null);
|
|
||||||
|
return new Pair<>(in, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl;
|
||||||
|
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative;
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative;
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
import org.nd4j.linalg.activations.BaseActivationFunction;
|
import org.nd4j.linalg.activations.BaseActivationFunction;
|
||||||
|
@ -54,9 +56,10 @@ public class ActivationLReLU extends BaseActivationFunction {
|
||||||
@Override
|
@Override
|
||||||
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
|
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
|
||||||
assertShape(in, epsilon);
|
assertShape(in, epsilon);
|
||||||
INDArray dLdz = Nd4j.getExecutioner().exec(new LeakyReLUDerivative(in, alpha));
|
|
||||||
dLdz.muli(epsilon);
|
Nd4j.getExecutioner().execAndReturn(new LeakyReLUBp(in, epsilon, in, alpha));
|
||||||
return new Pair<>(dLdz, null);
|
|
||||||
|
return new Pair<>(in, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -63,7 +63,7 @@ public class ActivationRReLU extends BaseActivationFunction {
|
||||||
public INDArray getActivation(INDArray in, boolean training) {
|
public INDArray getActivation(INDArray in, boolean training) {
|
||||||
if (training) {
|
if (training) {
|
||||||
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
|
||||||
this.alpha = Nd4j.rand(in.shape(), l, u, Nd4j.getRandom());
|
this.alpha = Nd4j.rand(l, u, Nd4j.getRandom(), in.shape());
|
||||||
}
|
}
|
||||||
INDArray inTimesAlpha = in.mul(alpha);
|
INDArray inTimesAlpha = in.mul(alpha);
|
||||||
BooleanIndexing.replaceWhere(in, inTimesAlpha, Conditions.lessThan(0));
|
BooleanIndexing.replaceWhere(in, inTimesAlpha, Conditions.lessThan(0));
|
||||||
|
|
|
@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl;
|
||||||
|
|
||||||
import lombok.EqualsAndHashCode;
|
import lombok.EqualsAndHashCode;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhBp;
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
import org.nd4j.linalg.activations.BaseActivationFunction;
|
import org.nd4j.linalg.activations.BaseActivationFunction;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -48,9 +50,10 @@ public class ActivationRationalTanh extends BaseActivationFunction {
|
||||||
@Override
|
@Override
|
||||||
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
|
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
|
||||||
assertShape(in, epsilon);
|
assertShape(in, epsilon);
|
||||||
INDArray dLdz = Nd4j.getExecutioner().exec(new RationalTanhDerivative(in));
|
|
||||||
dLdz.muli(epsilon);
|
Nd4j.getExecutioner().execAndReturn(new RationalTanhBp(in, epsilon, in));
|
||||||
return new Pair<>(dLdz, null);
|
|
||||||
|
return new Pair<>(in, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -20,8 +20,10 @@ import lombok.EqualsAndHashCode;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import org.nd4j.linalg.activations.BaseActivationFunction;
|
import org.nd4j.linalg.activations.BaseActivationFunction;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.api.ops.impl.scalar.Relu6;
|
import org.nd4j.linalg.api.ops.impl.scalar.Relu6;
|
||||||
import org.nd4j.linalg.api.ops.impl.scalar.Step;
|
import org.nd4j.linalg.api.ops.impl.scalar.Step;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.transforms.gradient.Relu6Derivative;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.primitives.Pair;
|
import org.nd4j.linalg.primitives.Pair;
|
||||||
|
|
||||||
|
@ -41,9 +43,10 @@ public class ActivationReLU6 extends BaseActivationFunction {
|
||||||
@Override
|
@Override
|
||||||
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
|
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
|
||||||
assertShape(in, epsilon);
|
assertShape(in, epsilon);
|
||||||
INDArray dLdz = Nd4j.getExecutioner().exec(new Step(in));
|
|
||||||
dLdz.muli(epsilon);
|
Nd4j.getExecutioner().execAndReturn(new Relu6Derivative(in, epsilon, in));
|
||||||
return new Pair<>(dLdz, null);
|
|
||||||
|
return new Pair<>(in, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl;

 import lombok.EqualsAndHashCode;
 import lombok.Getter;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhBp;
 import org.nd4j.linalg.primitives.Pair;
 import org.nd4j.linalg.activations.BaseActivationFunction;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -45,9 +47,10 @@ public class ActivationRectifiedTanh extends BaseActivationFunction {
     @Override
     public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
         assertShape(in, epsilon);
-        INDArray dLdz = Nd4j.getExecutioner().exec(new RectifiedTanhDerivative(in));
-        dLdz.muli(epsilon);
-        return new Pair<>(dLdz, null);
+        Nd4j.getExecutioner().execAndReturn(new RectifiedTanhBp(in, epsilon, in));
+        return new Pair<>(in, null);
     }

     @Override

@@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl;

 import lombok.EqualsAndHashCode;
 import lombok.Getter;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp;
 import org.nd4j.linalg.primitives.Pair;
 import org.nd4j.linalg.activations.BaseActivationFunction;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -41,9 +43,10 @@ public class ActivationSELU extends BaseActivationFunction {
     @Override
     public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
         assertShape(in, epsilon);
-        INDArray dLdz = Nd4j.getExecutioner().exec(new SELUDerivative(in));
-        dLdz.muli(epsilon);
-        return new Pair<>(dLdz, null);
+        Nd4j.getExecutioner().execAndReturn(new SeluBp(in, epsilon, in));
+        return new Pair<>(in, null);
     }

     @Override

@@ -18,7 +18,8 @@ package org.nd4j.linalg.activations.impl;

 import lombok.EqualsAndHashCode;
 import lombok.Getter;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.SigmoidDerivative;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative;
 import org.nd4j.linalg.primitives.Pair;
 import org.nd4j.linalg.activations.BaseActivationFunction;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -41,9 +42,10 @@ public class ActivationSigmoid extends BaseActivationFunction {
     @Override
     public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
         assertShape(in, epsilon);
-        INDArray dLdz = Nd4j.getExecutioner().exec(new SigmoidDerivative(in));
-        dLdz.muli(epsilon);
-        return new Pair<>(dLdz, null);
+        Nd4j.getExecutioner().execAndReturn(new SigmoidDerivative(in, epsilon, in));
+        return new Pair<>(in, null);
     }

     @Override

@@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl;

 import lombok.EqualsAndHashCode;
 import lombok.Getter;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftPlusBp;
 import org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid;
 import org.nd4j.linalg.primitives.Pair;
 import org.nd4j.linalg.activations.BaseActivationFunction;
@@ -41,9 +43,10 @@ public class ActivationSoftPlus extends BaseActivationFunction {
     @Override
     public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
         assertShape(in, epsilon);
-        INDArray dLdz = Nd4j.getExecutioner().exec(new Sigmoid(in));
-        dLdz.muli(epsilon);
-        return new Pair<>(dLdz, null);
+        Nd4j.getExecutioner().execAndReturn(new SoftPlusBp(in, epsilon, in));
+        return new Pair<>(in, null);
     }

     @Override

@@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl;

 import lombok.EqualsAndHashCode;
 import lombok.Getter;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp;
 import org.nd4j.linalg.primitives.Pair;
 import org.nd4j.linalg.activations.BaseActivationFunction;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -41,9 +43,10 @@ public class ActivationSoftSign extends BaseActivationFunction {
     @Override
     public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
         assertShape(in, epsilon);
-        INDArray dLdz = Nd4j.getExecutioner().exec(new SoftSignDerivative(in));
-        dLdz.muli(epsilon);
-        return new Pair<>(dLdz, null);
+        Nd4j.getExecutioner().execAndReturn(new SoftSignBp(in, epsilon, in));
+        return new Pair<>(in, null);
     }

     @Override

@@ -21,7 +21,9 @@ import lombok.Getter;
 import org.nd4j.linalg.activations.BaseActivationFunction;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.CustomOp;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
+import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.primitives.Pair;

@@ -42,10 +44,10 @@ public class ActivationSoftmax extends BaseActivationFunction {
     @Override
     public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
         assertShape(in, epsilon);
-        INDArray out = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(in, in.ulike()))[0];
-        INDArray x = out.mul(epsilon).sum(1);
-        INDArray dLdz = out.mul(epsilon.subColumnVector(x));
-        return new Pair<>(dLdz, null);
+        Nd4j.getExecutioner().execAndReturn(new SoftmaxBp(in, epsilon, in, -1));
+        return new Pair<>(in, null);
     }

     @Override

@@ -18,11 +18,11 @@ package org.nd4j.linalg.activations.impl;

 import lombok.EqualsAndHashCode;
 import lombok.Getter;
+import org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative;
 import org.nd4j.linalg.primitives.Pair;
 import org.nd4j.linalg.activations.BaseActivationFunction;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh;
-import org.nd4j.linalg.api.ops.impl.transforms.strict.TanhDerivative;
 import org.nd4j.linalg.factory.Nd4j;

 /**
@@ -41,9 +41,10 @@ public class ActivationTanH extends BaseActivationFunction {
     @Override
     public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
         assertShape(in, epsilon);
-        INDArray dLdz = Nd4j.getExecutioner().exec(new TanhDerivative(in));
-        dLdz.muli(epsilon);
-        return new Pair<>(dLdz, null);
+        Nd4j.getExecutioner().execAndReturn(new TanhDerivative(in, epsilon, in));
+        return new Pair<>(in, null);
     }

     @Override

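All of the activation classes above now follow the same backprop pattern: instead of materialising the derivative f'(in) and multiplying by epsilon afterwards, a single fused backprop op receives (in, dL/dOut) and writes dL/dIn straight back into the in buffer. A minimal sketch of that pattern, using the tanh case shown above (the surrounding class, imports and variable names are assumed, not part of this diff):

    // epsilon is dL/dOut; the fused op overwrites `in` with dL/dIn = epsilon * (1 - tanh(in)^2)
    Nd4j.getExecutioner().execAndReturn(new TanhDerivative(in, epsilon, in));
    return new Pair<>(in, null);

Compared with the removed two-step version, this avoids allocating a separate dLdz array and saves one extra pass over the data.
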
@@ -2836,156 +2836,78 @@ public abstract class BaseNDArray implements INDArray, Iterable {
         return putScalar(i, element.getDouble(0));
     }

-    /**
-     * In place addition of a column vector
-     *
-     * @param columnVector the column vector to add
-     * @return the result of the addition
-     */
     @Override
     public INDArray diviColumnVector(INDArray columnVector) {
         validateNumericalArray("diviColumnVector", false);
         return doColumnWise(columnVector, 'd');
     }

-    /**
-     * In place addition of a column vector
-     *
-     * @param columnVector the column vector to add
-     * @return the result of the addition
-     */
     @Override
     public INDArray divColumnVector(INDArray columnVector) {
         validateNumericalArray("divColumnVector", false);
         return dup().diviColumnVector(columnVector);
     }

-    /**
-     * In place addition of a column vector
-     *
-     * @param rowVector the row vector to add
-     * @return the result of the addition
-     */
     @Override
     public INDArray diviRowVector(INDArray rowVector) {
         validateNumericalArray("diviRowVector", false);
         return doRowWise(rowVector, 'd');
     }

-    /**
-     * In place addition of a column vector
-     *
-     * @param rowVector the row vector to add
-     * @return the result of the addition
-     */
     @Override
     public INDArray divRowVector(INDArray rowVector) {
         validateNumericalArray("divRowVector", false);
         return dup().diviRowVector(rowVector);
     }

-    /**
-     * In place addition of a column vector
-     *
-     * @param columnVector the column vector to add
-     * @return the result of the addition
-     */
     @Override
     public INDArray muliColumnVector(INDArray columnVector) {
         validateNumericalArray("muliColumnVector", false);
         return doColumnWise(columnVector, 'm');
     }

-    /**
-     * In place addition of a column vector
-     *
-     * @param columnVector the column vector to add
-     * @return the result of the addition
-     */
     @Override
     public INDArray mulColumnVector(INDArray columnVector) {
         validateNumericalArray("mulColumnVector", false);
         return dup().muliColumnVector(columnVector);
     }

-    /**
-     * In place addition of a column vector
-     *
-     * @param rowVector the row vector to add
-     * @return the result of the addition
-     */
     @Override
     public INDArray muliRowVector(INDArray rowVector) {
         validateNumericalArray("muliRowVector", false);
         return doRowWise(rowVector, 'm');
     }

-    /**
-     * In place addition of a column vector
-     *
-     * @param rowVector the row vector to add
-     * @return the result of the addition
-     */
     @Override
     public INDArray mulRowVector(INDArray rowVector) {
         validateNumericalArray("mulRowVector", false);
         return dup().muliRowVector(rowVector);
     }

-    /**
-     * In place addition of a column vector
-     *
-     * @param columnVector the column vector to add
-     * @return the result of the addition
-     */
     @Override
     public INDArray subiColumnVector(INDArray columnVector) {
         validateNumericalArray("subiColumnVector", false);
         return doColumnWise(columnVector, 's');
     }

-    /**
-     * In place addition of a column vector
-     *
-     * @param columnVector the column vector to add
-     * @return the result of the addition
-     */
     @Override
     public INDArray subColumnVector(INDArray columnVector) {
         validateNumericalArray("subColumnVector", false);
         return dup().subiColumnVector(columnVector);
     }

-    /**
-     * In place addition of a column vector
-     *
-     * @param rowVector the row vector to add
-     * @return the result of the addition
-     */
     @Override
     public INDArray subiRowVector(INDArray rowVector) {
         validateNumericalArray("subiRowVector", false);
         return doRowWise(rowVector, 's');
     }

-    /**
-     * In place addition of a column vector
-     *
-     * @param rowVector the row vector to add
-     * @return the result of the addition
-     */
     @Override
     public INDArray subRowVector(INDArray rowVector) {
         validateNumericalArray("subRowVector", false);
         return dup().subiRowVector(rowVector);
     }

-    /**
-     * In place addition of a column vector
-     *
-     * @param columnVector the column vector to add
-     * @return the result of the addition
-     */
     @Override
     public INDArray addiColumnVector(INDArray columnVector) {
         validateNumericalArray("addiColumnVector", false);
@@ -2997,24 +2919,12 @@ public abstract class BaseNDArray implements INDArray, Iterable {
         return doColumnWise(columnVector, 'p');
     }

-    /**
-     * In place addition of a column vector
-     *
-     * @param columnVector the column vector to add
-     * @return the result of the addition
-     */
     @Override
     public INDArray addColumnVector(INDArray columnVector) {
         validateNumericalArray("addColumnVector", false);
         return dup().addiColumnVector(columnVector);
     }

-    /**
-     * In place addition of a column vector
-     *
-     * @param rowVector the row vector to add
-     * @return the result of the addition
-     */
     @Override
     public INDArray addiRowVector(INDArray rowVector) {
         validateNumericalArray("addiRowVector", false);
@@ -3027,47 +2937,22 @@ public abstract class BaseNDArray implements INDArray, Iterable {
         return doRowWise(rowVector, 'p');
     }

-    /**
-     * In place addition of a column vector
-     *
-     * @param rowVector the row vector to add
-     * @return the result of the addition
-     */
     @Override
     public INDArray addRowVector(INDArray rowVector) {
         validateNumericalArray("addRowVector", false);
         return dup().addiRowVector(rowVector);
     }


-    /**
-     * Perform a copy matrix multiplication
-     *
-     * @param other the other matrix to perform matrix multiply with
-     * @return the result of the matrix multiplication
-     */
     @Override
     public INDArray mmul(INDArray other, INDArray result, MMulTranspose mMulTranspose) {
         return mMulTranspose.exec(this, other, result);
     }

-    /**
-     * Perform a copy matrix multiplication
-     *
-     * @param other the other matrix to perform matrix multiply with
-     * @return the result of the matrix multiplication
-     */
     @Override
     public INDArray mmul(INDArray other, MMulTranspose mMulTranspose) {
         return mMulTranspose.exec(this, other, null);
     }

-    /**
-     * Perform a copy matrix multiplication
-     *
-     * @param other the other matrix to perform matrix multiply with
-     * @return the result of the matrix multiplication
-     */
     @Override
     public INDArray mmul(INDArray other) {
         Preconditions.checkState(this.dataType() == other.dataType(), "Matrix multiplication: arrays must have same dtype: %s vs. %s", this.dataType(), other.dataType());
@@ -3105,8 +2990,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
         if(!isVectorOrScalar()) {
             throw new ND4JIllegalStateException("Unable to create a 1d array from a non vector! Shape: " + Shape.shapeToStringShort(this));
         }
-
-
         return dup().data().asDouble();
     }

@@ -3115,7 +2998,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
         if(!isVectorOrScalar()) {
             throw new ND4JIllegalStateException("Unable to create a 1d array from a non vector! Shape: " + Shape.shapeToStringShort(this));
         }
-
         return dup().data().asFloat();
     }

@@ -1123,14 +1123,6 @@ public class BaseSparseNDArrayCOO extends BaseSparseNDArray {
         return null;
     }

-    /**
-     * Perform an copy matrix multiplication
-     *
-     * @param other the other matrix to perform matrix multiply with
-     * @param result the result ndarray
-     * @param mMulTranspose the transpose status of each array
-     * @return the result of the matrix multiplication
-     */
     @Override
     public INDArray mmul(INDArray other, INDArray result, MMulTranspose mMulTranspose) {
         return null;

@@ -1,4 +1,4 @@
-/*******************************************************************************
+/* *****************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
  *
  * This program and the accompanying materials are made available under the
@@ -16,9 +16,6 @@

 package org.nd4j.linalg.api.ndarray;

-import static org.nd4j.linalg.factory.Nd4j.compressDebug;
-import static org.nd4j.linalg.factory.Nd4j.preventUnpack;
-
 import com.google.flatbuffers.FlatBufferBuilder;
 import lombok.NonNull;
 import org.nd4j.linalg.api.blas.params.MMulTranspose;
@@ -52,6 +49,7 @@ public interface INDArray extends Serializable, AutoCloseable {
      */
     DataBuffer shapeInfoDataBuffer();

+    // TODO: Unused untested method.
     /**
      * Sparse info
      * @return Sparse info.
@@ -110,12 +108,13 @@ public interface INDArray extends Serializable, AutoCloseable {
      */
     int elementWiseStride();

+    // TODO: Unused untested method.
     /**
      * Get a double at the given linear offset unsafe, without checks.
      * @param offset the offset to get at
      * @return double value at offset
      */
-    double getDoubleUnsafe(long offset); //TODO: consider deleting.
+    double getDoubleUnsafe(long offset);

     /**
      * Get string value at given index.
@@ -124,13 +123,14 @@ public interface INDArray extends Serializable, AutoCloseable {
      */
     String getString(long index);

+    // TODO: Unused untested method.
     /**
      * Insert a scalar at the given linear offset
      * @param offset the offset to insert at
      * @param value the value to insert
      * @return this
      */
-    INDArray putScalarUnsafe(long offset, double value); //TODO: consider deleting.
+    INDArray putScalarUnsafe(long offset, double value);

     /**
      * Returns the number of possible vectors for a given dimension
@@ -190,6 +190,7 @@ public interface INDArray extends Serializable, AutoCloseable {
      */
     INDArray assign(INDArray arr);

+    // TODO: Unused untested method.
     /**
      * Assign all elements from given ndarray that are matching given condition,
      * ndarray to this ndarray
@@ -553,7 +554,7 @@ public interface INDArray extends Serializable, AutoCloseable {
      *
      * @param n the number to subtract by
      * @param result the result ndarray
-     * @return
+     * @return the result ndarray
      */
     INDArray rsub(Number n, INDArray result);

@@ -1041,7 +1042,6 @@ public interface INDArray extends Serializable, AutoCloseable {
      */
     INDArray divRowVector(INDArray rowVector);

-
     /**
      * In place reverse divison of a column vector
      *
@@ -1066,6 +1066,7 @@ public interface INDArray extends Serializable, AutoCloseable {
      */
     INDArray rdiviRowVector(INDArray rowVector);

+    //TODO: unused / untested method.
     /**
      * Reverse division of a column vector (copy)
      *
@@ -1074,7 +1075,6 @@ public interface INDArray extends Serializable, AutoCloseable {
      */
     INDArray rdivRowVector(INDArray rowVector);

-
     /**
      * In place multiplication of a column vector
      *
@@ -1107,7 +1107,6 @@ public interface INDArray extends Serializable, AutoCloseable {
      */
     INDArray mulRowVector(INDArray rowVector);

-
     /**
      * In place reverse subtraction of a column vector
      *
@@ -1132,6 +1131,7 @@ public interface INDArray extends Serializable, AutoCloseable {
      */
     INDArray rsubiRowVector(INDArray rowVector);

+    //TODO: unused / untested method.
     /**
      * Reverse subtraction of a row vector (copy)
      *
@@ -1180,7 +1180,6 @@ public interface INDArray extends Serializable, AutoCloseable {
      */
     INDArray addiColumnVector(INDArray columnVector);

-
     /**
      * In place assignment of a column vector
      *
@@ -1221,6 +1220,12 @@ public interface INDArray extends Serializable, AutoCloseable {
      */
     INDArray addRowVector(INDArray rowVector);

+    /**
+     * Perform a copy matrix multiplication
+     *
+     * @param other the other matrix to perform matrix multiply with
+     * @return the result of the matrix multiplication
+     */
     INDArray mmul(INDArray other, MMulTranspose mMulTranspose);

     /**
@@ -1231,8 +1236,6 @@ public interface INDArray extends Serializable, AutoCloseable {
      */
     INDArray mmul(INDArray other);

-
-
     /**
      * Convert this ndarray to a 2d double matrix.
      * Note that THIS SHOULD NOT BE USED FOR SPEED.
@@ -1283,6 +1286,14 @@ public interface INDArray extends Serializable, AutoCloseable {
      */
     int[] toIntVector();

+    /**
+     * Convert this ndarray to a 1d long matrix.
+     * Note that THIS SHOULD NOT BE USED FOR SPEED.
+     * This is mainly used for integrations with other libraries.
+     * Due to nd4j's off heap nature, moving data on heap is very expensive
+     * and should not be used if possible.
+     * @return a copy of this array as a 1d long array
+     */
     long[] toLongVector();

     /**

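The newly documented toLongVector() mirrors toIntVector(): it copies a 1d array's off-heap buffer into a Java long[]. A hedged usage sketch (the array values are illustrative, and Nd4j.createFromArray is assumed to be available in this version):

    INDArray v = Nd4j.createFromArray(2L, 3L, 4L);   // 1d long vector
    long[] onHeap = v.toLongVector();                // heap copy - avoid in performance-critical code
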
@@ -16,17 +16,13 @@

 package org.nd4j.linalg.api.ops.impl.scalar;

-import org.nd4j.autodiff.samediff.SDVariable;
-import org.nd4j.autodiff.samediff.SameDiff;
-import org.nd4j.graph.DataType;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.BaseScalarOp;
-import org.nd4j.linalg.api.ops.BaseTransformOp;
-
-import java.util.Arrays;
-import java.util.LinkedHashMap;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.BaseScalarOp;
 import org.nd4j.linalg.factory.Nd4j;
 import org.tensorflow.framework.AttrValue;
 import org.tensorflow.framework.GraphDef;
@@ -108,8 +104,7 @@ public class LeakyReLU extends BaseScalarOp {

     @Override
     public List<SDVariable> doDiff(List<SDVariable> i_v) {
-        SDVariable ret = f().leakyReluDerivative(arg(), alpha).mul(i_v.get(0));
-        return Arrays.asList(ret);
+        return Collections.singletonList(f().leakyReluBp(arg(), i_v.get(0), alpha));
     }

     @Override

@@ -75,15 +75,8 @@ public class RectifiedLinear extends BaseScalarOp {
         return "Relu";
     }

-
     @Override
     public List<SDVariable> doDiff(List<SDVariable> i_v) {
-        if(scalarValue.getDouble(0) == 0.0){
-            return Collections.singletonList(f().reluDerivative(arg(), i_v.get(0)));
-        } else {
-            SDVariable step = new Step(sameDiff, arg(), false, scalarValue.getDouble(0)).outputVariables()[0];
-            SDVariable ret = step.mul(i_v.get(0));
-            return Collections.singletonList(ret);
-        }
+        return Collections.singletonList(f().thresholdReluBp(arg(), i_v.get(0), scalarValue.getDouble(0)));
     }
 }

@@ -3,10 +3,10 @@ package org.nd4j.linalg.api.ops.impl.scalar;
 import lombok.NonNull;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
-import org.nd4j.shade.guava.base.Preconditions;

 import java.util.Collections;
 import java.util.List;
@@ -30,7 +30,8 @@ public class RectifiedLinearDerivative extends DynamicCustomOp {

     @Override
     public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
-        Preconditions.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
+        Preconditions
+            .checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
         Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);

         return Collections.singletonList(dataTypes.get(0));

@@ -0,0 +1,77 @@
+/*
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.nd4j.linalg.api.ops.impl.transforms.custom;
+
+import java.util.Collections;
+import java.util.List;
+import lombok.Getter;
+import lombok.NonNull;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear;
+
+/**
+ * Threshold ReLU op. The genral case of {@link RectifiedLinear}.
+ */
+public class ThresholdRelu extends DynamicCustomOp {
+
+    @Getter
+    private double cutoff = 0.0;
+
+    public ThresholdRelu(){ }
+
+    public ThresholdRelu(SameDiff sd, SDVariable input, boolean inPlace, double cutoff){
+        super(sd, new SDVariable[]{input}, inPlace);
+        this.cutoff = cutoff;
+        addTArgument(cutoff);
+    }
+
+    public ThresholdRelu(SameDiff sd, SDVariable input, double cutoff){
+        super(sd, new SDVariable[]{input});
+        this.cutoff = cutoff;
+        addTArgument(cutoff);
+    }
+
+    public ThresholdRelu(@NonNull INDArray input, INDArray output, double cutoff){
+        super(new INDArray[]{input}, wrapOrNull(output));
+        this.cutoff = cutoff;
+        addTArgument(cutoff);
+    }
+
+    @Override
+    public String opName(){
+        return "thresholdedrelu";
+    }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
+        Preconditions
+            .checkArgument(dataTypes != null && dataTypes.size() == 1, "Expected exactly 1 input datatypes, got %s", dataTypes);
+        Preconditions.checkArgument(dataTypes.get(0).isFPType(), "Input datatype must be floating point, got %s", dataTypes);
+
+        return Collections.singletonList(dataTypes.get(0));
+    }
+
+    @Override
+    public List<SDVariable> doDiff(List<SDVariable> f1) {
+        return Collections.singletonList(f().thresholdReluBp(arg(), f1.get(0), cutoff));
+    }
+}

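For reference, the INDArray constructor added above allows the op to be run eagerly; a rough sketch (the input/output arrays and the Nd4j.exec call are assumptions based on the surrounding code, not part of this change):

    INDArray x = Nd4j.rand(3, 4).subi(0.5);
    INDArray out = x.ulike();
    Nd4j.exec(new ThresholdRelu(x, out, 0.3));   // out = x where x > 0.3, 0 elsewhere
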
@@ -0,0 +1,62 @@
+/*
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.nd4j.linalg.api.ops.impl.transforms.gradient;
+
+import java.util.Collections;
+import java.util.List;
+import lombok.NonNull;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+
+/**
+ * Cube backpropagation op - dL/dIn from in and dL/dOut
+ */
+public class CubeBp extends DynamicCustomOp {
+
+    public CubeBp(){ }
+
+    public CubeBp(SameDiff sd, SDVariable input, SDVariable gradient){
+        super(sd, new SDVariable[]{input, gradient});
+    }
+
+    public CubeBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
+        super(new INDArray[]{input, gradient}, wrapOrNull(output));
+    }
+
+    @Override
+    public String opName(){
+        return "cube_bp";
+    }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
+        Preconditions
+            .checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
+        Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
+
+        return Collections.singletonList(dataTypes.get(0));
+    }
+
+    @Override
+    public List<SDVariable> doDiff(List<SDVariable> f1) {
+        throw new UnsupportedOperationException("Not supported");
+    }
+}

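Per the Javadoc convention above (dL/dIn from in and dL/dOut), cube_bp fuses the chain rule dL/dIn = 3 * in^2 * dL/dOut into one native call, replacing the deprecated CubeDerivative-plus-multiply sequence. A hedged sketch of eager use (the arrays and the Nd4j.exec call are assumptions, not part of the diff):

    INDArray in = Nd4j.rand(2, 2);
    INDArray dLdOut = Nd4j.rand(2, 2);
    INDArray dLdIn = in.ulike();
    Nd4j.exec(new CubeBp(in, dLdOut, dLdIn));    // dLdIn = 3 * in^2 * dLdOut
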
@@ -27,7 +27,11 @@ import java.util.List;

 /**
  * Cube derivative, e.g. 3x^2
+ *
+ * @deprecated Use {@link CubeBp}
+ *
  */
+@Deprecated
 public class CubeDerivative extends BaseTransformStrictOp {
     public CubeDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
         super(sameDiff, i_v, inPlace);

@@ -1,84 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.nd4j.linalg.api.ops.impl.transforms.gradient;
-
-import org.nd4j.autodiff.samediff.SDVariable;
-import org.nd4j.autodiff.samediff.SameDiff;
-import org.nd4j.imports.NoOpNameFoundException;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.BaseTransformOp;
-import org.nd4j.linalg.api.ops.BaseTransformStrictOp;
-
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.List;
-
-/**
- *
- * Derivative of ELU: Exponential Linear Unit (alpha=1.0)<br>
- * Introduced in paper:<br>
- * Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs)<br>
- * Djork-Arné Clevert, Thomas Unterthiner, Sepp Hochreiter (2015)<br>
- * <a href="http://arxiv.org/abs/1511.07289">http://arxiv.org/abs/1511.07289</a>
- *
- * @author Alex Black
- */
-public class ELUDerivative extends BaseTransformStrictOp {
-    public ELUDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
-        super(sameDiff, i_v, inPlace);
-    }
-
-    public ELUDerivative() {
-
-    }
-
-    public ELUDerivative(INDArray x, INDArray z) {
-        super(x, z);
-    }
-
-    public ELUDerivative(INDArray x) {
-        super(x);
-    }
-
-    @Override
-    public int opNum() {
-        return 3;
-    }
-
-    @Override
-    public String opName() {
-        return "eluderivative";
-    }
-
-    @Override
-    public String onnxName() {
-        throw new NoOpNameFoundException("No onnx op opName found for " + opName());
-    }
-
-    @Override
-    public String tensorflowName() {
-        throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
-    }
-
-
-    @Override
-    public List<SDVariable> doDiff(List<SDVariable> i_v) {
-        SDVariable ret = sameDiff.zerosLike(arg());
-        return Collections.singletonList(ret);
-    }
-}

@@ -0,0 +1,67 @@
+/*
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.nd4j.linalg.api.ops.impl.transforms.gradient;
+
+import java.util.Collections;
+import java.util.List;
+import lombok.NonNull;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+
+/**
+ * ELU backpropagation op - dL/dIn from in and dL/dOut
+ */
+public class EluBp extends DynamicCustomOp {
+
+    public EluBp(){ }
+
+    public EluBp(SameDiff sd, SDVariable input, SDVariable gradient){
+        super(sd, new SDVariable[]{input, gradient});
+    }
+
+    public EluBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output) {
+        this(input, gradient, output, 1.0);
+    }
+
+    public EluBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output, double alpha){
+        super(new INDArray[]{input, gradient}, wrapOrNull(output));
+        addTArgument(alpha);
+    }
+
+    @Override
+    public String opName(){
+        return "elu_bp";
+    }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
+        Preconditions
+            .checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
+        Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
+
+        return Collections.singletonList(dataTypes.get(0));
+    }
+
+    @Override
+    public List<SDVariable> doDiff(List<SDVariable> f1) {
+        throw new UnsupportedOperationException("Not supported");
+    }
+}

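The two INDArray constructors above make the alpha handling explicit: the three-argument form defaults to alpha = 1.0, matching the removed ELUDerivative. A hedged usage sketch (the arrays and the Nd4j.exec call are assumptions, not part of the diff):

    INDArray dLdIn = in.ulike();
    Nd4j.exec(new EluBp(in, dLdOut, dLdIn));        // alpha = 1.0
    Nd4j.exec(new EluBp(in, dLdOut, dLdIn, 0.5));   // explicit alpha
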
@@ -0,0 +1,62 @@
+/*
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.nd4j.linalg.api.ops.impl.transforms.gradient;
+
+import java.util.Collections;
+import java.util.List;
+import lombok.NonNull;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+
+/**
+ * Hard Sigmoid backpropagation op - dL/dIn from in and dL/dOut
+ */
+public class HardSigmoidBp extends DynamicCustomOp {
+
+    public HardSigmoidBp(){ }
+
+    public HardSigmoidBp(SameDiff sd, SDVariable input, SDVariable gradient){
+        super(sd, new SDVariable[]{input, gradient});
+    }
+
+    public HardSigmoidBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
+        super(new INDArray[]{input, gradient}, wrapOrNull(output));
+    }
+
+    @Override
+    public String opName(){
+        return "hardsigmoid_bp";
+    }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
+        Preconditions
+            .checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
+        Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
+
+        return Collections.singletonList(dataTypes.get(0));
+    }
+
+    @Override
+    public List<SDVariable> doDiff(List<SDVariable> f1) {
+        throw new UnsupportedOperationException("Not supported");
+    }
+}

@@ -29,8 +29,11 @@ import java.util.List;
 /**
  * HardSigmoid derivative
  *
+ * @deprecated Use {@link HardSigmoidBp}
+ *
  * @author raver119@gmail.com
  */
+@Deprecated
 public class HardSigmoidDerivative extends BaseTransformStrictOp {
     public HardSigmoidDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
         super(sameDiff, i_v, inPlace);

@@ -0,0 +1,62 @@
+/*
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.nd4j.linalg.api.ops.impl.transforms.gradient;
+
+import java.util.Collections;
+import java.util.List;
+import lombok.NonNull;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+
+/**
+ * Hard Tanh backpropagation op - dL/dIn from in and dL/dOut
+ */
+public class HardTanhBp extends DynamicCustomOp {
+
+    public HardTanhBp(){ }
+
+    public HardTanhBp(SameDiff sd, SDVariable input, SDVariable gradient){
+        super(sd, new SDVariable[]{input, gradient});
+    }
+
+    public HardTanhBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
+        super(new INDArray[]{input, gradient}, wrapOrNull(output));
+    }
+
+    @Override
+    public String opName(){
+        return "hardtanh_bp";
+    }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
+        Preconditions
+            .checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
+        Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
+
+        return Collections.singletonList(dataTypes.get(0));
+    }
+
+    @Override
+    public List<SDVariable> doDiff(List<SDVariable> f1) {
+        throw new UnsupportedOperationException("Not supported");
+    }
+}

@@ -31,8 +31,11 @@ import java.util.List;
 /**
  * Hard tanh elementwise derivative function
  *
+ * @deprecated Use {@link HardTanhBp}
+ *
  * @author Adam Gibson
 */
+@Deprecated
 public class HardTanhDerivative extends BaseTransformStrictOp {
     public HardTanhDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
         super(sameDiff, i_v, inPlace);

@@ -0,0 +1,68 @@
/*
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 */

package org.nd4j.linalg.api.ops.impl.transforms.gradient;

import java.util.Collections;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;

/**
 * LReLU backpropagation op - dL/dIn from in and dL/dOut
 */
public class LeakyReLUBp extends DynamicCustomOp {
    public static final double DEFAULT_ALPHA = 0.01;
    private double alpha = DEFAULT_ALPHA;

    public LeakyReLUBp(){ }

    public LeakyReLUBp(SameDiff sd, SDVariable input, SDVariable gradient, double alpha){
        super(sd, new SDVariable[]{input, gradient});
        this.alpha = alpha;
        addTArgument(alpha);
    }

    public LeakyReLUBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output, double alpha){
        super(new INDArray[]{input, gradient}, wrapOrNull(output));
        this.alpha = alpha;
        addTArgument(alpha);
    }

    @Override
    public String opName(){
        return "lrelu_bp";
    }

    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
        Preconditions.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
        Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);

        return Collections.singletonList(dataTypes.get(0));
    }

    @Override
    public List<SDVariable> doDiff(List<SDVariable> f1) {
        throw new UnsupportedOperationException("Not supported");
    }
}
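A rough usage sketch, assuming the op is run through the executioner in the same way as the other DynamicCustomOps in this change; the array shapes and variable names are made up:

// Illustrative only: element-wise this computes dL/dIn = dL/dOut * (x > 0 ? 1 : alpha).
INDArray in = Nd4j.rand(2, 3);        // forward-pass input
INDArray dLdOut = Nd4j.rand(2, 3);    // gradient w.r.t. the LReLU output
INDArray dLdIn = in.ulike();          // pre-allocated output buffer
Nd4j.exec(new LeakyReLUBp(in, dLdOut, dLdIn, LeakyReLUBp.DEFAULT_ALPHA));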
@@ -0,0 +1,62 @@
/*
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 */

package org.nd4j.linalg.api.ops.impl.transforms.gradient;

import java.util.Collections;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;

/**
 * Rational Tanh backpropagation op - dL/dIn from in and dL/dOut
 */
public class RationalTanhBp extends DynamicCustomOp {

    public RationalTanhBp(){ }

    public RationalTanhBp(SameDiff sd, SDVariable input, SDVariable gradient){
        super(sd, new SDVariable[]{input, gradient});
    }

    public RationalTanhBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
        super(new INDArray[]{input, gradient}, wrapOrNull(output));
    }

    @Override
    public String opName(){
        return "rationaltanh_bp";
    }

    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
        Preconditions.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
        Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);

        return Collections.singletonList(dataTypes.get(0));
    }

    @Override
    public List<SDVariable> doDiff(List<SDVariable> f1) {
        throw new UnsupportedOperationException("Not supported");
    }
}
@@ -31,9 +31,12 @@ import java.util.List;
  * Rational Tanh Derivative, as described at https://github.com/deeplearning4j/libnd4j/issues/351
  * Calculates dOut/dIn given input, not dL/dIn given dL/dOut and input
  *
+ * @deprecated Use {@link RationalTanhBp}
+ *
  * @author raver119@gmail.com
  * @author AlexDBlack
  */
+@Deprecated
 public class RationalTanhDerivative extends BaseTransformStrictOp {
     public RationalTanhDerivative(SameDiff sameDiff, SDVariable in, boolean inPlace) {
         super(sameDiff, in, inPlace);
@@ -0,0 +1,62 @@
/*
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 */

package org.nd4j.linalg.api.ops.impl.transforms.gradient;

import java.util.Collections;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;

/**
 * Rectified Tanh backpropagation op - dL/dIn from in and dL/dOut
 */
public class RectifiedTanhBp extends DynamicCustomOp {

    public RectifiedTanhBp(){ }

    public RectifiedTanhBp(SameDiff sd, SDVariable input, SDVariable gradient){
        super(sd, new SDVariable[]{input, gradient});
    }

    public RectifiedTanhBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
        super(new INDArray[]{input, gradient}, wrapOrNull(output));
    }

    @Override
    public String opName(){
        return "rectifiedtanh_bp";
    }

    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
        Preconditions.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
        Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);

        return Collections.singletonList(dataTypes.get(0));
    }

    @Override
    public List<SDVariable> doDiff(List<SDVariable> f1) {
        throw new UnsupportedOperationException("Not supported");
    }
}
@@ -30,9 +30,12 @@ import java.util.List;
 /**
  * Rectified Tanh Derivative
  *
+ * @deprecated Use {@link RectifiedTanhBp}
+ *
  * @author raver119@gmail.com
  * @author AlexDBlack
  */
+@Deprecated
 public class RectifiedTanhDerivative extends BaseTransformStrictOp {
     public RectifiedTanhDerivative(SameDiff sameDiff, SDVariable in, boolean inPlace) {
         super(sameDiff, in, inPlace);
@@ -16,15 +16,18 @@
 
 package org.nd4j.linalg.api.ops.impl.transforms.gradient;
 
+import lombok.NonNull;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
 import org.nd4j.imports.NoOpNameFoundException;
 import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
 
 import java.util.Collections;
 import java.util.List;
+import org.nd4j.linalg.api.ops.impl.transforms.same.Identity;
 
 /**
  * Derivative of Rectified linear unit 6, i.e. min(max(input, cutoff), 6), where cutoff can be chosen.
@@ -33,7 +36,9 @@ import java.util.List;
  */
 public class Relu6Derivative extends DynamicCustomOp {
 
-    private double cutoff = 0.0;
+    private static final double DEFAULT_CUTOFF = 0.0;
+
+    private double cutoff = DEFAULT_CUTOFF;
 
     public Relu6Derivative(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, double cutoff) {
         super("relu6_bp", sameDiff, new SDVariable[]{i_v1, i_v2});
@@ -45,6 +50,16 @@ public class Relu6Derivative extends DynamicCustomOp {
         this.extraArgs = new Object[]{cutoff};
     }
 
+    public Relu6Derivative(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
+        this(input, gradient, output, DEFAULT_CUTOFF);
+    }
+
+    public Relu6Derivative(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output, double cutoff){
+        super(new INDArray[]{input, gradient}, wrapOrNull(output));
+        this.cutoff = cutoff;
+        this.extraArgs = new Object[]{cutoff};
+    }
+
     @Override
     public int opNum() {
         return 0;
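As a plain-Java sketch of what relu6_bp computes element-wise (illustrative names, not the native implementation):

// dL/dIn = dL/dOut where cutoff < x < 6, and 0 elsewhere.
class Relu6BpSketch {
    static double[] backprop(double[] in, double[] dLdOut, double cutoff) {
        double[] dLdIn = new double[in.length];
        for (int i = 0; i < in.length; i++) {
            dLdIn[i] = (in[i] > cutoff && in[i] < 6.0) ? dLdOut[i] : 0.0;
        }
        return dLdIn;
    }
}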
@@ -31,8 +31,11 @@ import java.util.List;
  *
  * https://arxiv.org/pdf/1706.02515.pdf
  *
+ * @deprecated Use {@link SeluBp}
+ *
  * @author raver119@gmail.com
  */
+@Deprecated
 public class SELUDerivative extends BaseTransformStrictOp {
 
     private static final double SELU_ALPHA = 1.6732632423543772848170429916717;
@@ -0,0 +1,62 @@
/*
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 */

package org.nd4j.linalg.api.ops.impl.transforms.gradient;

import java.util.Collections;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;

/**
 * SELU backpropagation op - dL/dIn from in and dL/dOut
 */
public class SeluBp extends DynamicCustomOp {

    public SeluBp(){ }

    public SeluBp(SameDiff sd, SDVariable input, SDVariable gradient){
        super(sd, new SDVariable[]{input, gradient});
    }

    public SeluBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
        super(new INDArray[]{input, gradient}, wrapOrNull(output));
    }

    @Override
    public String opName(){
        return "selu_bp";
    }

    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
        Preconditions.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
        Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);

        return Collections.singletonList(dataTypes.get(0));
    }

    @Override
    public List<SDVariable> doDiff(List<SDVariable> f1) {
        throw new UnsupportedOperationException("Not supported");
    }
}
@@ -0,0 +1,62 @@
/*
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 */

package org.nd4j.linalg.api.ops.impl.transforms.gradient;

import java.util.Collections;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;

/**
 * SoftPlus backpropagation op - dL/dIn from in and dL/dOut
 */
public class SoftPlusBp extends DynamicCustomOp {

    public SoftPlusBp(){ }

    public SoftPlusBp(SameDiff sd, SDVariable input, SDVariable gradient){
        super(sd, new SDVariable[]{input, gradient});
    }

    public SoftPlusBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
        super(new INDArray[]{input, gradient}, wrapOrNull(output));
    }

    @Override
    public String opName(){
        return "softplus_bp";
    }

    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
        Preconditions.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
        Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);

        return Collections.singletonList(dataTypes.get(0));
    }

    @Override
    public List<SDVariable> doDiff(List<SDVariable> f1) {
        throw new UnsupportedOperationException("Not supported");
    }
}
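Since softplus(x) = ln(1 + e^x) has derivative sigmoid(x), softplus_bp boils down to the rule sketched below (plain-Java illustration, names are made up):

// dL/dIn = dL/dOut * sigmoid(x)
class SoftPlusBpSketch {
    static double[] backprop(double[] in, double[] dLdOut) {
        double[] dLdIn = new double[in.length];
        for (int i = 0; i < in.length; i++) {
            double sig = 1.0 / (1.0 + Math.exp(-in[i]));
            dLdIn[i] = dLdOut[i] * sig;
        }
        return dLdIn;
    }
}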
@@ -0,0 +1,63 @@
/*
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 */

package org.nd4j.linalg.api.ops.impl.transforms.gradient;

import java.util.Collections;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;

/**
 * SoftSign backpropagation op - dL/dIn from in and dL/dOut
 */
public class SoftSignBp extends DynamicCustomOp {

    public SoftSignBp(){ }

    public SoftSignBp(SameDiff sd, SDVariable input, SDVariable gradient){
        super(sd, new SDVariable[]{input, gradient});
    }

    public SoftSignBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
        super(new INDArray[]{input, gradient}, wrapOrNull(output));
    }

    @Override
    public String opName(){
        return "softsign_bp";
    }

    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
        Preconditions.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
        Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);

        return Collections.singletonList(dataTypes.get(0));
    }

    @Override
    public List<SDVariable> doDiff(List<SDVariable> f1) {
        throw new UnsupportedOperationException("Not supported");
    }
}
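With softsign(x) = x / (1 + |x|), the derivative is 1 / (1 + |x|)^2, so softsign_bp amounts to the following (illustrative plain-Java sketch, not the native kernel):

// dL/dIn = dL/dOut / (1 + |x|)^2
class SoftSignBpSketch {
    static double[] backprop(double[] in, double[] dLdOut) {
        double[] dLdIn = new double[in.length];
        for (int i = 0; i < in.length; i++) {
            double d = 1.0 + Math.abs(in[i]);
            dLdIn[i] = dLdOut[i] / (d * d);
        }
        return dLdIn;
    }
}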
@@ -29,7 +29,10 @@ import java.util.List;
 
 /**
  * SoftSign derivative.
+ *
+ * @deprecated Use {@link SoftSignBp}
  */
+@Deprecated
 public class SoftSignDerivative extends BaseTransformStrictOp {
     public SoftSignDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
         super(sameDiff, i_v, inPlace);
@@ -16,10 +16,12 @@
 
 package org.nd4j.linalg.api.ops.impl.transforms.gradient;
 
+import lombok.NonNull;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
 import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
 
 import java.util.Collections;
@@ -40,6 +42,12 @@ public class SoftmaxBp extends DynamicCustomOp {
         addIArgument(dimension);
     }
 
+    public SoftmaxBp(@NonNull INDArray input, @NonNull INDArray grad, INDArray output, Integer dimension){
+        super(new INDArray[]{input, grad}, wrapOrNull(output));
+        if(dimension != null)
+            addIArgument(dimension);
+    }
+
     @Override
     public String opName() {
         return "softmax_bp";
@@ -35,15 +35,15 @@ public class TanhDerivative extends DynamicCustomOp {
         super(sameDiff, new SDVariable[]{i_v1, i_v2});
     }
 
-    public TanhDerivative(INDArray x, INDArray z) {
-        super(null, x, z, null, null);
+    public TanhDerivative(INDArray x, INDArray y, INDArray z) {
+        super(null, new INDArray[]{x, y}, new INDArray[]{z});
     }
 
     public TanhDerivative() {
     }
 
-    public TanhDerivative(INDArray x) {
-        this(x, null);
+    public TanhDerivative(INDArray x, INDArray y) {
+        this(x, y, null);
     }
 
     @Override
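The two-input form above matches the usual tanh backprop, sketched here in plain Java for reference (illustrative names, not the native op):

// dL/dIn = dL/dOut * (1 - tanh(x)^2)
class TanhBpSketch {
    static double[] backprop(double[] in, double[] dLdOut) {
        double[] dLdIn = new double[in.length];
        for (int i = 0; i < in.length; i++) {
            double t = Math.tanh(in[i]);
            dLdIn[i] = dLdOut[i] * (1.0 - t * t);
        }
        return dLdIn;
    }
}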
@@ -0,0 +1,74 @@
/*
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 */

package org.nd4j.linalg.api.ops.impl.transforms.gradient;

import java.util.Collections;
import java.util.List;
import lombok.Getter;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear;
import org.nd4j.linalg.api.ops.impl.transforms.custom.ThresholdRelu;

/**
 * Threshold ReLU Backprop op - dL/dIn from in and dL/dOut
 *
 * For {@link RectifiedLinear} as well as {@link ThresholdRelu}.
 */
public class ThresholdReluBp extends DynamicCustomOp {

    @Getter
    private double cutoff = 0;

    public ThresholdReluBp(){ }

    public ThresholdReluBp(SameDiff sd, SDVariable input, SDVariable gradient, double cutoff){
        super(sd, new SDVariable[]{input, gradient});
        this.cutoff = cutoff;
        addTArgument(cutoff);
    }

    public ThresholdReluBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output, double cutoff){
        super(new INDArray[]{input, gradient}, wrapOrNull(output));
        this.cutoff = cutoff;
        addTArgument(cutoff);
    }

    @Override
    public String opName(){
        return "thresholdedrelu_bp";
    }

    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
        Preconditions.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
        Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);

        return Collections.singletonList(dataTypes.get(0));
    }

    @Override
    public List<SDVariable> doDiff(List<SDVariable> f1) {
        throw new UnsupportedOperationException("Not supported");
    }
}
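Element-wise, thresholdedrelu_bp passes the gradient through only where the input exceeds the cutoff; a minimal plain-Java sketch (illustrative names, not the native kernel):

// dL/dIn = dL/dOut where x > cutoff, and 0 elsewhere (cutoff = 0 gives plain ReLU backprop).
class ThresholdReluBpSketch {
    static double[] backprop(double[] in, double[] dLdOut, double cutoff) {
        double[] dLdIn = new double[in.length];
        for (int i = 0; i < in.length; i++) {
            dLdIn[i] = (in[i] > cutoff) ? dLdOut[i] : 0.0;
        }
        return dLdIn;
    }
}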
@@ -16,6 +16,7 @@
 
 package org.nd4j.linalg.api.ops.impl.transforms.same;
 
+import java.util.Collections;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.imports.NoOpNameFoundException;
@@ -70,7 +71,6 @@ public class Cube extends BaseTransformSameOp {
 
     @Override
     public List<SDVariable> doDiff(List<SDVariable> f1) {
-        SDVariable g = f().mul(f().cubeDerivative(arg()),f1.get(0));
-        return Arrays.asList(g);
+        return Collections.singletonList(f().cubeBp(arg(), f1.get(0)));
     }
 }
@@ -18,14 +18,18 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict;
 
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
 import org.nd4j.imports.NoOpNameFoundException;
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.BaseTransformFloatOp;
-import org.nd4j.linalg.api.ops.BaseTransformOp;
-import org.nd4j.linalg.api.ops.BaseTransformStrictOp;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+import org.tensorflow.framework.AttrValue;
+import org.tensorflow.framework.GraphDef;
+import org.tensorflow.framework.NodeDef;
 
-import java.util.Arrays;
+import java.util.Collections;
 import java.util.List;
+import java.util.Map;
 
 /**
  * ELU: Exponential Linear Unit (alpha=1.0)<br>
@@ -36,25 +40,20 @@ import java.util.List;
  *
  * @author Alex Black
  */
-public class ELU extends BaseTransformStrictOp {
-    public ELU(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
-        super(sameDiff, i_v, inPlace);
+public class ELU extends DynamicCustomOp {
+    public ELU(SameDiff sameDiff, SDVariable i_v) {
+        super(sameDiff, new SDVariable[]{i_v});
     }
 
     public ELU() {
     }
 
     public ELU(INDArray x, INDArray z) {
-        super(x, z);
+        super(null, wrapOrNull(x), wrapOrNull(z));
     }
 
     public ELU(INDArray x) {
-        super(x);
-    }
-
-    @Override
-    public int opNum() {
-        return 35;
+        this(x, null);
     }
 
     @Override
@@ -76,8 +75,14 @@ public class ELU extends BaseTransformStrictOp {
     public List<SDVariable> doDiff(List<SDVariable> i_v) {
         //ELU: e^x-1 if x<0, x otherwise
         //dL/dIn = dL/Out * dOut/dIn
-        SDVariable ret = f().eluDerivative(arg()).mul(i_v.get(0));
-        return Arrays.asList(ret);
+        return Collections.singletonList(f().eluBp(arg(), i_v.get(0)));
     }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
+        Preconditions.checkState(dataTypes != null && dataTypes.size() == 1, "Expected exactly 1 datatype for ELU, got %s", dataTypes);
+        Preconditions.checkState(dataTypes.get(0).isFPType(), "Expected floating point input type for ELU, got %s", dataTypes);
+
+        return dataTypes;
+    }
 }
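For the alpha = 1.0 case named in the javadoc, eluBp reduces to the rule below; this is a plain-Java sketch of the math, not the native kernel:

// ELU(x) = x for x > 0 and e^x - 1 otherwise, so dL/dIn = dL/dOut * (x > 0 ? 1 : e^x).
class EluBpSketch {
    static double[] backprop(double[] in, double[] dLdOut) {
        double[] dLdIn = new double[in.length];
        for (int i = 0; i < in.length; i++) {
            dLdIn[i] = dLdOut[i] * (in[i] > 0.0 ? 1.0 : Math.exp(in[i]));
        }
        return dLdIn;
    }
}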
@@ -69,9 +69,7 @@ public class HardSigmoid extends BaseTransformStrictOp {
 
     @Override
     public List<SDVariable> doDiff(List<SDVariable> f1) {
-        SDVariable in = arg();
-        SDVariable dOutdIn = new HardSigmoidDerivative(sameDiff, in, false).outputVariables()[0];
-        return Collections.singletonList(dOutdIn.mul(f1.get(0)));
+        return Collections.singletonList(f().hardSigmoidBp(arg(), f1.get(0)));
     }
 
 
@@ -16,6 +16,7 @@
 
 package org.nd4j.linalg.api.ops.impl.transforms.strict;
 
+import java.util.Collections;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -70,7 +71,6 @@ public class HardTanh extends BaseTransformStrictOp {
 
     @Override
     public List<SDVariable> doDiff(List<SDVariable> i_v) {
-        SDVariable ret = f().hardTanhDerivative(arg()).mul(i_v.get(0));
-        return Arrays.asList(ret);
+        return Collections.singletonList(f().hardTanhBp(arg(), i_v.get(0)));
     }
 }
@@ -16,13 +16,10 @@
 
 package org.nd4j.linalg.api.ops.impl.transforms.strict;
 
-import org.nd4j.autodiff.functions.DifferentialFunction;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.imports.NoOpNameFoundException;
 import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.BaseTransformFloatOp;
-import org.nd4j.linalg.api.ops.BaseTransformOp;
 import org.nd4j.linalg.api.ops.BaseTransformStrictOp;
 
 import java.util.Collections;
@@ -71,6 +68,6 @@ public class RationalTanh extends BaseTransformStrictOp {
 
     @Override
     public List<SDVariable> doDiff(List<SDVariable> f1) {
-        return Collections.singletonList(f().tanhRationalDerivative(arg()).mul(f1.get(0)));
+        return Collections.singletonList(f().tanhRationalBp(arg(), f1.get(0)));
     }
 }
@@ -17,13 +17,10 @@
 package org.nd4j.linalg.api.ops.impl.transforms.strict;
 
 import onnx.Onnx;
-import org.nd4j.autodiff.functions.DifferentialFunction;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.imports.NoOpNameFoundException;
 import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.BaseTransformFloatOp;
-import org.nd4j.linalg.api.ops.BaseTransformOp;
 import org.nd4j.linalg.api.ops.BaseTransformStrictOp;
 import org.tensorflow.framework.AttrValue;
 import org.tensorflow.framework.GraphDef;
@@ -88,6 +85,6 @@ public class RectifiedTanh extends BaseTransformStrictOp {
 
     @Override
     public List<SDVariable> doDiff(List<SDVariable> f1) {
-        return Collections.singletonList(f().tanhRectifiedDerivative(arg()).mul(f1.get(0)));
+        return Collections.singletonList(f().tanhRectifiedBp(arg(), f1.get(0)));
     }
 }
@@ -16,6 +16,7 @@
 
 package org.nd4j.linalg.api.ops.impl.transforms.strict;
 
+import java.util.Collections;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -76,8 +77,7 @@ public class SELU extends BaseTransformStrictOp {
 
     @Override
     public List<SDVariable> doDiff(List<SDVariable> i_v) {
-        SDVariable ret = f().seluDerivative(arg()).mul(i_v.get(0));
-        return Arrays.asList(ret);
+        return Collections.singletonList(f().seluBp(arg(), i_v.get(0)));
     }
 
 }
@@ -28,8 +28,11 @@ import java.util.List;
 /**
  * Sigmoid derivative
  *
+ * @deprecated Use {@link org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative}
+ *
  * @author Adam Gibson
  */
+@Deprecated
 public class SigmoidDerivative extends BaseTransformStrictOp {
     public SigmoidDerivative(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
         super(sameDiff, i_v1, i_v2);
@@ -16,6 +16,7 @@
 
 package org.nd4j.linalg.api.ops.impl.transforms.strict;
 
+import java.util.Collections;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -73,8 +74,7 @@ public class SoftSign extends BaseTransformStrictOp {
 
     @Override
     public List<SDVariable> doDiff(List<SDVariable> i_v) {
-        SDVariable ret = f().softsignDerivative(arg()).mul(i_v.get(0));
-        return Arrays.asList(ret);
+        return Collections.singletonList(f().softsignBp(arg(), i_v.get(0)));
     }
 
 }
@@ -27,7 +27,10 @@ import java.util.List;
 
 /**
  * Tanh derivative
+ *
+ * @deprecated Use {@link org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative}.
  */
+@Deprecated
 public class TanhDerivative extends BaseTransformStrictOp {
     public TanhDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
         super(sameDiff, i_v, inPlace);
@@ -37,7 +37,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse;
 import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
 import org.nd4j.linalg.api.ops.impl.transforms.floating.*;
 import org.nd4j.linalg.api.ops.impl.transforms.comparison.*;
-import org.nd4j.linalg.api.ops.impl.transforms.gradient.ELUDerivative;
+import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative;
 import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative;
@@ -438,16 +438,16 @@ public class Transforms {
 
 
     public static INDArray elu(INDArray in, boolean copy) {
-        return Nd4j.getExecutioner().exec(new ELU(in, (copy ? in.ulike() : in)));
+        return Nd4j.getExecutioner().exec(new ELU(in, (copy ? in.ulike() : in)))[0];
     }
 
-    public static INDArray eluDerivative(INDArray arr) {
-        return eluDerivative(arr, true);
+    public static INDArray eluDerivative(INDArray arr, INDArray grad) {
+        return eluDerivative(arr, grad, true);
     }
 
 
-    public static INDArray eluDerivative(INDArray in, boolean copy) {
-        return Nd4j.getExecutioner().exec(new ELUDerivative(in, (copy ? in.ulike() : in)));
+    public static INDArray eluDerivative(INDArray in, INDArray grad, boolean copy) {
+        return Nd4j.getExecutioner().exec(new EluBp(in, grad, (copy ? in.ulike() : in)))[0];
     }
 
 
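Call sites therefore change as sketched below; the arrays and shapes are made up, the signatures are the ones introduced above:

INDArray x = Nd4j.rand(2, 3);
INDArray dLdOut = Nd4j.rand(2, 3);
INDArray activated = Transforms.elu(x, true);          // forward pass, same signature as before
INDArray dLdIn = Transforms.eluDerivative(x, dLdOut);  // now also takes the incoming gradient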
@@ -42,7 +42,7 @@ public class LecunUniformInitScheme extends BaseWeightInitScheme {
     @Override
     public INDArray doCreate(DataType dataType, long[] shape, INDArray paramsView) {
         double b = 3.0 / Math.sqrt(fanIn);
-        return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-b, b));
+        return Nd4j.rand(Nd4j.getDistributions().createUniform(-b, b), shape);
     }
 
 
@@ -43,7 +43,7 @@ public class ReluUniformInitScheme extends BaseWeightInitScheme {
     @Override
     public INDArray doCreate(DataType dataType, long[] shape, INDArray paramsView) {
         double u = Math.sqrt(6.0 / fanIn);
-        return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-u, u)); //U(-sqrt(6/fanIn), sqrt(6/fanIn)
+        return Nd4j.rand(Nd4j.getDistributions().createUniform(-u, u), shape); //U(-sqrt(6/fanIn), sqrt(6/fanIn)
     }
 
 
@@ -46,7 +46,7 @@ public class SigmoidUniformInitScheme extends BaseWeightInitScheme {
     @Override
     public INDArray doCreate(DataType dataType, long[] shape, INDArray paramsView) {
         double r = 4.0 * Math.sqrt(6.0 / (fanIn + fanOut));
-        return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-r, r));
+        return Nd4j.rand(Nd4j.getDistributions().createUniform(-r, r), shape);
     }
 
 
@@ -43,7 +43,7 @@ public class UniformInitScheme extends BaseWeightInitScheme {
     @Override
     public INDArray doCreate(DataType dataType, long[] shape, INDArray paramsView) {
         double a = 1.0 / Math.sqrt(fanIn);
-        return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-a, a));
+        return Nd4j.rand(Nd4j.getDistributions().createUniform(-a, a), shape);
    }
 
 
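All four init schemes now pass the distribution first and the shape second; a hedged sketch of the new call form (the fan-in and shape values are made up):

long fanIn = 128;
double a = 1.0 / Math.sqrt(fanIn);
INDArray w = Nd4j.rand(Nd4j.getDistributions().createUniform(-a, a), new long[]{128, 64});  // samples U(-a, a)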
Some files were not shown because too many files have changed in this diff.