Merge remote-tracking branch 'fork/master'

master
AlexDBlack 2019-09-02 18:52:12 +10:00
commit 7ded4416cb
110 changed files with 1363 additions and 621 deletions

View File

@@ -248,7 +248,7 @@ public class ValidateCuDNN extends BaseDL4JTest {
         Nd4j.getRandom().setSeed(12345);
         INDArray features = Nd4j.rand(fShape);
         INDArray labels = Nd4j.rand(lShape);
-        labels = Nd4j.exec(new IsMax(labels, 1));
+        labels = Nd4j.exec(new IsMax(labels, 1))[0].castTo(features.dataType());
         List<CuDNNValidationUtil.TestCase> testCaseList = new ArrayList<>();
@@ -256,7 +256,7 @@ public class ValidateCuDNN extends BaseDL4JTest {
         for (int i = 0; i < 6; i++) {
             INDArray f = Nd4j.rand(fShape);
             INDArray l = Nd4j.rand(lShape);
-            Nd4j.exec(new IsMax(l, 1))[0];
+            l = Nd4j.exec(new IsMax(l, 1))[0].castTo(features.dataType());
             dataSets.add(new DataSet(f, l));
         }
         DataSetIterator iter = new ExistingDataSetIterator(dataSets);

View File

@@ -25,7 +25,6 @@ import org.deeplearning4j.models.word2vec.Huffman;
 import org.deeplearning4j.models.word2vec.VocabWord;
 import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
 import org.deeplearning4j.spark.models.embeddings.word2vec.FirstIterationFunction;
-import org.deeplearning4j.spark.models.embeddings.word2vec.FirstIterationFunctionAdapter;
 import org.deeplearning4j.spark.models.embeddings.word2vec.MapToPairFunction;
 import org.deeplearning4j.spark.models.embeddings.word2vec.Word2Vec;
 import org.deeplearning4j.spark.text.functions.CountCumSum;
@@ -470,11 +469,11 @@ public class TextPipelineTest extends BaseSparkTest {
         Iterator<Tuple2<List<VocabWord>, Long>> iterator = vocabWordListSentenceCumSumRDD.collect().iterator();
-        FirstIterationFunctionAdapter firstIterationFunction = new FirstIterationFunctionAdapter(
+        FirstIterationFunction firstIterationFunction = new FirstIterationFunction(
                 word2vecVarMapBroadcast, expTableBroadcast, pipeline.getBroadCastVocabCache());
-        Iterable<Map.Entry<VocabWord, INDArray>> ret = firstIterationFunction.call(iterator);
-        assertTrue(ret.iterator().hasNext());
+        Iterator<Map.Entry<VocabWord, INDArray>> ret = firstIterationFunction.call(iterator);
+        assertTrue(ret.hasNext());
     }
     @Test

View File

@@ -70,6 +70,13 @@
             <artifactId>deeplearning4j-play_2.11</artifactId>
             <version>${deeplearning4j.version}</version>
             <scope>test</scope>
+            <exclusions>
+                <!-- To avoid clashing net.jpountz.lz4:lz4:1.3.0 and org.lz4:lz4-java:jar:1.4.0 -->
+                <exclusion>
+                    <groupId>net.jpountz.lz4</groupId>
+                    <artifactId>lz4</artifactId>
+                </exclusion>
+            </exclusions>
         </dependency>
         <dependency>

View File

@@ -273,9 +273,11 @@ namespace nd4j {
          * @param writeList
          * @param readList
          */
+        // TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list
         static void registerSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList);
         static void prepareSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables = false);
+        // TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list
        static void registerPrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList);
        static void preparePrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables = false);
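Several CUDA helper changes later in this merge bracket their kernel launches with these calls so that device buffers are synchronized before the launch and marked as written afterwards. A minimal usage sketch, assuming the surrounding nd4j types; launchSomeOp and myKernel are hypothetical placeholders, not library functions:

// Sketch only: the prepare/register bracket used by the helper diffs below.
// `myKernel`, `launchSomeOp`, `in`, and `out` are illustrative names.
void launchSomeOp(nd4j::LaunchContext* context, NDArray* in, NDArray* out) {
    // sync read buffers to device; writables are marked dirty-on-device afterwards
    NDArray::prepareSpecialUse({out}, {in});

    myKernel<<<256, 256, 1024, *context->getCudaStream()>>>(
            in->getSpecialBuffer(), in->getSpecialShapeInfo(),
            out->specialBuffer(), out->specialShapeInfo());

    NDArray::registerSpecialUse({out}, {in});
}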

View File

@@ -931,13 +931,13 @@ void initializeFunctions(Nd4jPointer *functions) {
 Nd4jPointer mallocHost(Nd4jLong memorySize, int flags) {
     Nd4jPointer pointer;
     // cudaHostAllocMapped |cudaHostAllocPortable
-    auto res = cudaHostAlloc(reinterpret_cast<void **>(&pointer), memorySize, cudaHostAllocDefault);
+    auto res = cudaHostAlloc(reinterpret_cast<void **>(&pointer), memorySize + 8, cudaHostAllocDefault);
     if (res != 0) {
         nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(res);
         nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaHostAlloc failed");
     }
-    return pointer;
+    return reinterpret_cast<int8_t*>(pointer);
 }
 /**
@@ -950,13 +950,13 @@ Nd4jPointer mallocHost(Nd4jLong memorySize, int flags) {
  */
 Nd4jPointer mallocDevice(Nd4jLong memorySize, int deviceId, int flags) {
     Nd4jPointer pointer;
-    auto res = cudaMalloc(reinterpret_cast<void **>(&pointer), memorySize);
+    auto res = cudaMalloc(reinterpret_cast<void **>(&pointer), memorySize + 8);
     if (res != 0) {
         nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(res);
         nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage("cudaMalloc failed");
     }
-    return pointer;
+    return reinterpret_cast<int8_t*>(pointer);
 }
 /**

View File

@@ -96,13 +96,13 @@ namespace functions {
             }
             else {
                 if(vx == vz) {
-                    for (Nd4jLong i = tid; i < length; i += gridDim.x * blockDim.x) {
+                    for (Nd4jLong i = tid; i < length; i += totalThreads) {
                         auto xOffset = shape::getIndexOffset(i, xShapeInfo, length);
                         z[xOffset] = OpType::op(x[xOffset], params);
                     }
                 }
                 else {
-                    for (Nd4jLong i = tid; i < length; i += gridDim.x * blockDim.x) {
+                    for (Nd4jLong i = tid; i < length; i += totalThreads) {
                         auto xOffset = shape::getIndexOffset(i, xShapeInfo, length);
                         auto zOffset = shape::getIndexOffset(i, zShapeInfo, length);
                         z[zOffset] = OpType::op(x[xOffset], params);
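These transform kernels now step by a precomputed totalThreads instead of recomputing gridDim.x * blockDim.x each iteration, and later files in this merge also correct tid to blockIdx.x * blockDim.x + threadIdx.x. Both are pieces of the standard grid-stride loop; a minimal, self-contained sketch of that pattern (not library code, names are illustrative):

// Grid-stride loop sketch: every thread starts at its global index and
// strides by the total number of launched threads until the buffer is covered.
__global__ void scaleKernel(const float* x, float* z, long length, float factor) {
    const auto tid = blockIdx.x * blockDim.x + threadIdx.x;   // global thread index
    const auto totalThreads = gridDim.x * blockDim.x;         // stride across the whole grid

    for (long i = tid; i < length; i += totalThreads)
        z[i] = x[i] * factor;
}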

View File

@@ -94,13 +94,13 @@ namespace functions {
             }
             else {
                 if(vx == vz) {
-                    for (Nd4jLong i = tid; i < length; i += gridDim.x * blockDim.x) {
+                    for (Nd4jLong i = tid; i < length; i += totalThreads) {
                         auto xOffset = shape::getIndexOffset(i, xShapeInfo, length);
                         z[xOffset] = OpType::op(x[xOffset], params);
                     }
                 }
                 else {
-                    for (Nd4jLong i = tid; i < length; i += gridDim.x * blockDim.x) {
+                    for (Nd4jLong i = tid; i < length; i += totalThreads) {
                         auto xOffset = shape::getIndexOffset(i, xShapeInfo, length);
                         auto zOffset = shape::getIndexOffset(i, zShapeInfo, length);
                         z[zOffset] = OpType::op(x[xOffset], params);

View File

@@ -96,13 +96,13 @@ namespace functions {
             }
             else {
                 if(vx == vz) {
-                    for (Nd4jLong i = tid; i < length; i += gridDim.x * blockDim.x) {
+                    for (Nd4jLong i = tid; i < length; i += totalThreads) {
                         auto xOffset = shape::getIndexOffset(i, xShapeInfo, length);
                         z[xOffset] = OpType::op(x[xOffset], params);
                     }
                 }
                 else {
-                    for (Nd4jLong i = tid; i < length; i += gridDim.x * blockDim.x) {
+                    for (Nd4jLong i = tid; i < length; i += totalThreads) {
                         auto xOffset = shape::getIndexOffset(i, xShapeInfo, length);
                         auto zOffset = shape::getIndexOffset(i, zShapeInfo, length);
                         z[zOffset] = OpType::op(x[xOffset], params);

View File

@@ -116,7 +116,6 @@
 #define TRANSFORM_STRICT_OPS \
-        (3, ELUDerivative), \
         (4, TanhDerivative), \
         (5, HardTanhDerivative), \
         (6, SigmoidDerivative), \
@@ -148,7 +147,6 @@
         (32, ATan), \
         (33, HardTanh), \
         (34, SoftSign), \
-        (35, ELU), \
         (36, HardSigmoid), \
         (37, RationalTanh) ,\
         (38, RectifiedTanh) ,\
@@ -211,6 +209,8 @@
         (4, ReverseDivide),\
         (5, ReverseSubtract),\
         (6, MaxPairwise),\
+        (7, ELU), \
+        (8, ELUDerivative), \
         (13, MinPairwise),\
         (14, CopyPws),\
         (15, Mod),\

View File

@@ -25,12 +25,14 @@
 #include <ops/declarable/helpers/legacy_helpers.h>
 namespace nd4j {
     namespace ops {
-        CONFIGURABLE_OP_IMPL(elu, 1, 1, true, 0, 0) {
+        CONFIGURABLE_OP_IMPL(elu, 1, 1, true, -2, 0) {
             auto input = INPUT_VARIABLE(0);
             auto output = OUTPUT_VARIABLE(0);
-            input->applyTransform(nd4j::transform::ELU, output, nullptr);
-            STORE_RESULT(output);
+            const auto alpha = block.numT() > 0 ? T_ARG(0) : 1.f;
+            input->applyScalar(nd4j::scalar::ELU, alpha, output);
             return Status::OK();
         }
@@ -41,14 +43,18 @@ namespace nd4j {
                     ->setAllowedOutputTypes(0, {ALL_FLOATS});
         }
-        CONFIGURABLE_OP_IMPL(elu_bp, 2, 1, true, 0, 0) {
+        CONFIGURABLE_OP_IMPL(elu_bp, 2, 1, true, -2, 0) {
             auto input = INPUT_VARIABLE(0);
             auto epsilon = INPUT_VARIABLE(1);
-            auto z = OUTPUT_VARIABLE(0);
-            //input->applyPairwiseTransform(pairwise::ELUDerivativeE, epsilon, z, nullptr);
-            helpers::eluDerivative(block.launchContext(), input, epsilon, z);
+            auto output = OUTPUT_VARIABLE(0);
+            const auto alpha = block.numT() > 0 ? T_ARG(0) : 1.f;
+            // input->applyPairwiseTransform(pairwise::ELUDerivativeE, epsilon, output);
+            helpers::eluDerivative(block.launchContext(), input, epsilon, output, alpha);
             return Status::OK();
         }
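The op now reads an optional T-argument alpha (defaulting to 1.0), and the math helpers changed later in this merge generalize ELU accordingly: elu(x) = x for x >= 0 and alpha * (exp(x) - 1) otherwise, with derivative 1 for x >= 0 and alpha * exp(x) otherwise. A plain sketch of that math outside the library's template machinery, for reference only:

#include <cmath>

// Parameterized ELU and its derivative, mirroring the formulas above.
inline double elu(double x, double alpha) {
    return x >= 0.0 ? x : alpha * (std::exp(x) - 1.0);
}

inline double eluDerivative(double x, double alpha) {
    return x >= 0.0 ? 1.0 : alpha * std::exp(x);
}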

View File

@@ -25,15 +25,15 @@
 #include <ops/declarable/helpers/legacy_helpers.h>
 namespace nd4j {
     namespace ops {
-        CONFIGURABLE_OP_IMPL(lrelu, 1, 1, true, 1, 0) {
+        CONFIGURABLE_OP_IMPL(lrelu, 1, 1, true, -2, 0) {
             auto input = INPUT_VARIABLE(0);
             auto output = OUTPUT_VARIABLE(0);
-            float t = block.numT() > 0 ? T_ARG(0) : 0.0f;
-            input->applyScalar(nd4j::scalar::LeakyRELU, t, output);
+            float alpha = block.numT() > 0 ? T_ARG(0) : 0.01f;
+            input->applyScalar(nd4j::scalar::LeakyRELU, alpha, output);
             STORE_RESULT(output);
             return Status::OK();
         }
@@ -42,15 +42,17 @@ namespace nd4j {
                     ->setAllowedInputTypes(0, DataType::ANY)
                     ->setAllowedOutputTypes(0, {ALL_FLOATS});
         }
-        CONFIGURABLE_OP_IMPL(lrelu_bp, 2, 1, true, 0, 0) {
+        CONFIGURABLE_OP_IMPL(lrelu_bp, 2, 1, true, -2, 0) {
             auto input = INPUT_VARIABLE(0);
             auto epsilon = INPUT_VARIABLE(1);
             auto z = OUTPUT_VARIABLE(0);
+            float alpha = block.numT() > 0 ? T_ARG(0) : 0.01f;
             //input->applyPairwiseTransform(pairwise::LRELUDerivativeE, epsilon, z, nullptr);
-            helpers::leakyReluDerivative(block.launchContext(), input, epsilon, z);
+            helpers::leakyReluDerivative(block.launchContext(), input, epsilon, z, alpha);
             return Status::OK();
         }

View File

@@ -82,8 +82,8 @@ namespace nd4j {
          * Math is: x < 0 ? alpha * x : x;
          */
         #if NOT_EXCLUDED(OP_lrelu)
-        DECLARE_CONFIGURABLE_OP(lrelu, 1, 1, true, 0, 0);
-        DECLARE_CONFIGURABLE_OP(lrelu_bp, 2, 1, true, 0, 0);
+        DECLARE_CONFIGURABLE_OP(lrelu, 1, 1, true, -2, 0);
+        DECLARE_CONFIGURABLE_OP(lrelu_bp, 2, 1, true, -2, 0);
         #endif
         /**
@@ -91,8 +91,8 @@ namespace nd4j {
          * Math is: x >= 0 ? x : exp(x) - 1;
          */
         #if NOT_EXCLUDED(OP_elu)
-        DECLARE_CONFIGURABLE_OP(elu, 1, 1, true, 0, 0);
-        DECLARE_CONFIGURABLE_OP(elu_bp, 2, 1, true, 0, 0);
+        DECLARE_CONFIGURABLE_OP(elu, 1, 1, true, -2, 0);
+        DECLARE_CONFIGURABLE_OP(elu_bp, 2, 1, true, -2, 0);
         #endif
         /**
@@ -157,7 +157,7 @@ namespace nd4j {
         /**
          * This is Concatenated RELU implementation.
          * What happens inside: RELU(Concat((x, -x, {-1})))
          *
          * PLEASE NOTE: Concatenation will double amount of features available in input
          */
         #if NOT_EXCLUDED(OP_crelu)

View File

@@ -81,29 +81,35 @@ namespace helpers {
     }
     template <typename T>
-    static void leakyReluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) {
-        auto functor = LAMBDA_TT(x, y){
-            return x >= (T)0.f? y : T(0.f);
+    static void leakyReluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output, const float alpha) {
+        const T alphaT = static_cast<T>(alpha);
+        auto functor = LAMBDA_TT(x, y, alphaT) {
+            return x < 0 ? alphaT * y : y;
         };
         input->applyPairwiseLambda<T>(epsilon, functor, output);
     }
-    void leakyReluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) {
-        BUILD_SINGLE_SELECTOR(theFirst->dataType(), leakyReluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES);
+    void leakyReluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput, const float alpha) {
+        BUILD_SINGLE_SELECTOR(theFirst->dataType(), leakyReluDerivative_, (theFirst, theSecond, theOutput, alpha), FLOAT_TYPES);
     }
     template <typename T>
-    static void eluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) {
-        auto functor = LAMBDA_TT(x, y){
-            return y * nd4j::math::nd4j_eluderivative<T,T>(x);
+    static void eluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output, const float alpha) {
+        const T alphaT = static_cast<T>(alpha);
+        auto functor = LAMBDA_TT(x, y, alphaT){
+            return y * nd4j::math::nd4j_eluderivative<T,T>(x, alphaT);
         };
         input->applyPairwiseLambda<T>(epsilon, functor, output);
     }
-    void eluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) {
-        BUILD_SINGLE_SELECTOR(theFirst->dataType(), eluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES);
+    void eluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput, const float alpha) {
+        BUILD_SINGLE_SELECTOR(theFirst->dataType(), eluDerivative_, (theFirst, theSecond, theOutput, alpha), FLOAT_TYPES);
     }
     template <typename T>
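Both derivative helpers now capture the scalar alpha inside the LAMBDA_TT functor instead of hard-coding the threshold behaviour. As a rough standalone illustration of the same element-wise rule for the leaky-ReLU backward pass, under the assumption of plain contiguous buffers (hypothetical function, not the library API):

// Sketch: element-wise leaky-ReLU backward with a configurable alpha,
// mirroring the lambda bodies above (x < 0 ? alpha * grad : grad).
template <typename T>
void leakyReluBackward(const T* x, const T* grad, T* out, long length, float alpha) {
    const T alphaT = static_cast<T>(alpha);
    for (long i = 0; i < length; ++i)
        out[i] = x[i] < static_cast<T>(0) ? alphaT * grad[i] : grad[i];
}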

View File

@@ -30,7 +30,7 @@ namespace nd4j {
             shared[threadIdx.x] = 0;
+            // each thread will compare 2 elements: E and E+1
             for (int e = tid; e < length - 1; e += blockDim.x * gridDim.x) {
                 auto val0 = x[shape::getIndexOffset(e, xShapeInfo, length)];
                 auto val1 = x[shape::getIndexOffset(e+1, xShapeInfo, length)];
@@ -41,11 +41,12 @@
                 else
                     v = val1 >= val0;
+                // store comparison result in shared memory
                 shared[threadIdx.x] += v ? 0 : 1;
             }
             __syncthreads();
-            // aggregate sum
+            // aggregate sums in shared memory
             for (uint activeThreads = blockDim.x / 2; activeThreads > 0; activeThreads /= 2) {
                 if (threadIdx.x < activeThreads)
                     shared[threadIdx.x] += shared[threadIdx.x + activeThreads];
@@ -53,7 +54,7 @@
             }
-            // store over the grid
+            // store over the grid if we have more than 1 block
             if (gridDim.x > 1) {
                 auto tc = reinterpret_cast<unsigned int *>(reductionBuffer);
@@ -96,7 +97,7 @@
                 }
             }
             else {
+                // if we have only 1 block, we just store results right away
                 if (threadIdx.x == 0) {
                     auto tc = reinterpret_cast<unsigned int*>(reductionBuffer);
                     tc[16384] = 0;

View File

@@ -424,7 +424,7 @@ static __global__ void avgPooling2dCuda(const void *vx, const Nd4jLong *xShapeIn
     }
     __syncthreads();
-    int tid = blockIdx.x * gridDim.x + threadIdx.x;
+    int tid = blockIdx.x * blockDim.x + threadIdx.x;
     for (int index = tid; index < length; index += blockDim.x * gridDim.x) {
@@ -519,7 +519,7 @@ static __global__ void pnormPooling2dCuda(const void *vx, const Nd4jLong *xShape
     }
     __syncthreads();
-    int tid = blockIdx.x * gridDim.x + threadIdx.x;
+    int tid = blockIdx.x * blockDim.x + threadIdx.x;
     for (int index = tid; index < length; index += blockDim.x * gridDim.x) {
@@ -610,7 +610,7 @@ static __global__ void maxPooling2dCuda(const void *vx, const Nd4jLong *xShapeIn
     }
     __syncthreads();
-    int tid = blockIdx.x * gridDim.x + threadIdx.x;
+    int tid = blockIdx.x * blockDim.x + threadIdx.x;
     for (int index = tid; index < length; index += blockDim.x * gridDim.x) {
@@ -908,7 +908,7 @@ __global__ static void pooling2dBPCuda(const void* vx, const Nd4jLong* xShapeInf
             /*** max ***/
             case 0: {
                 coord2 = hstart;
-                coord3 = hend;
+                coord3 = wstart;
                 T max = -DataTypeUtils::max<T>();
                 for (coords[2] = hstart; coords[2] < hend; coords[2] += dH) {
@@ -923,8 +923,9 @@
                 }
                 coords[2] = coord2;
                 coords[3] = coord3;
-                nd4j::math::atomics::nd4j_atomicAdd<T>(&z[shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank)], y[yOffset]);
+                auto zOffset = shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank);
+                nd4j::math::atomics::nd4j_atomicAdd<T>(&z[zOffset], y[yOffset]);
+                //z[zOffset] += y[yOffset];
             }
             break;
@@ -987,7 +988,7 @@ void ConvolutionUtils::pooling2dBP(nd4j::graph::Context& block, const NDArray& i
     PointersManager manager(block.launchContext(), "pooling2dBP");
-    const int threadsPerBlock = MAX_NUM_THREADS / 2;
+    const int threadsPerBlock = 256;
     const int blocksPerGrid = (gradO.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
     const int sharedMem = gradO.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128;
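In the max-pooling backward kernel above, the output offset is computed once and the gradient is accumulated with an atomic add, since several gradient elements can map to the same input cell. A stripped-down sketch of that scatter-accumulate pattern, with hypothetical names (gradOut, targetOffsets, gradIn) and no TAD/shape machinery:

// Sketch: scatter-add of gradients into a destination buffer.
// Several threads may target the same offset, hence atomicAdd.
__global__ void scatterAddGrads(const float* gradOut, const int* targetOffsets,
                                float* gradIn, long length) {
    const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
    const auto step = gridDim.x * blockDim.x;

    for (long e = tid; e < length; e += step)
        atomicAdd(&gradIn[targetOffsets[e]], gradOut[e]);
}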

View File

@@ -39,7 +39,7 @@ static __global__ void diagFunctorKernel(void* outputBuffer, Nd4jLong* outputSha
     }
     __syncthreads();
-    const auto tid = blockIdx.x * gridDim.x + threadIdx.x;
+    const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
     const auto step = gridDim.x * blockDim.x;
     for (int t = tid; t < inputLength; t += step) {
         z[shape::getIndexOffset(t * (inputLength + 1), outputShape, outputLength)] = x[shape::getIndexOffset(t, inputShape, inputLength)]; //tX];
@@ -59,7 +59,7 @@ static __global__ void diagFunctorKernel(void* outputBuffer, Nd4jLong* outputSha
     }
     __syncthreads();
-    const auto tid = blockIdx.x * gridDim.x + threadIdx.x;
+    const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
     const auto step = gridDim.x * blockDim.x;
     Nd4jLong i = threadIdx.x * (outputLength + 1);
     for (int t = tid; t < outputLength && i < inputLength; t += step) {

View File

@@ -35,9 +35,11 @@ namespace helpers {
         T const* input = reinterpret_cast<T const*>(inputBuf);
         T* output = reinterpret_cast<T*>(outputBuf);
+        // trivial idea: loop through all elements, get independent probability for each element to be nullified
         for (Nd4jLong e = 0; e < inLen; ++e) {
             T val = nodeRng->relativeT(e, T(0.f), T(1.f));
+            // if probability is ok - we're saving scaled value
             if (double(val) < probVal)
                 output[shape::getIndexOffset(e, outputShape, inLen)] = T(input[shape::getIndexOffset(e, inputShape, inLen)] / probVal);
         }
@@ -80,7 +82,7 @@ namespace helpers {
         std::vector<Nd4jLong> dims(reduceShape->lengthOf());
         reduceShape->syncToHost(); // to ensure that follows are actual
         bool fit = true;
         // PRAGMA_OMP_PARALLEL_FOR_ARGS(firstprivate(fit))
         for( int i = 0; i < dims.size(); i++ ) {
             if (fit) {
                 dims[i] = reduceShape->e<Nd4jLong>(i);
@@ -96,8 +98,7 @@ namespace helpers {
         REQUIRE_TRUE(fit, 0, "dropout: Noise shape should fit to input rank.");
         std::unique_ptr<NDArray> chunk(new NDArray('c', dims, output->dataType(), context.launchContext()));
         chunk->assign(1.f);
         //chunk->applyRandom<randomOps::DropOutInverted<T>>(rng, nullptr, chunk.get(), &probValue);
         //NativeOpExecutioner::execRandom(random::DropOutInverted, rng, chunk->buffer(), chunk->shapeInfo(), chunk->buffer(), chunk->shapeInfo(), &prob);
         dropoutSimple<T>(context.launchContext(), chunk.get(), chunk.get(), probValue, seed);
         // broadcast chunk to full matrix
         std::unique_ptr<NDArray> dropOutMultiplier(new NDArray(*input));
@@ -105,6 +106,7 @@ namespace helpers {
         *dropOutMultiplier += *chunk;
+        // FIXME: we could do this in one step, aren't we?
         output->assign(*input * *dropOutMultiplier); //input->applyPairwiseTransform(pairwise::Multiply, dropOutMultiplier.get(), output, nullptr);
     }
@@ -113,8 +115,11 @@ namespace helpers {
     int dropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue) {
         auto xType = input->dataType();
+        NDArray::prepareSpecialUse({output}, {input});
         BUILD_SINGLE_SELECTOR(xType, return _dropOutFunctor, (context, input, output, reduceShape, seed, probValue), FLOAT_TYPES);
+        NDArray::registerSpecialUse({output}, {input});
     }
     /////////////////////////////////// backrpopagations ///////////////////////////////////////////////
@@ -136,6 +141,8 @@ namespace helpers {
         for (int e = tid; e < len; e += step) {
             const auto zOffset = shape::getIndexOffset(e, outputShape, len);
+            // if probability was non-zero on FF step, we'll scale grads back
             if (output[zOffset] != T(0.))
                 output[zOffset] = T(input[shape::getIndexOffset(e, gradOutShape, len)] / probValue);
@@ -143,12 +150,17 @@ namespace helpers {
     }
     template <typename T>
     static int dropOutFunctorBP_(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue) {
+        // we're making additional FF run to see how probabilities played out with given seeds
         int res = dropOutFunctor(context, input, output, reduceShape, seed, probValue);
         auto stream = context.launchContext()->getCudaStream();
+        NDArray::prepareSpecialUse({output}, {input, gradOut});
         if (ND4J_STATUS_OK == res)
             dropoutBPKernel<T><<<128, 256, 1024, *stream>>>(output->specialBuffer(), output->specialShapeInfo(), gradOut->specialBuffer(), gradOut->specialShapeInfo(), probValue);
+        NDArray::registerSpecialUse({output}, {input, gradOut});
         return res;
     }
@@ -239,6 +251,7 @@ namespace helpers {
         int res = alphaDropOutFunctor(context, input, output, reduceShape, seed, probValue, alpha, alpha1, beta);
         if (res == ND4J_STATUS_OK) {
+            // FIXME: can we make it single-loop?
             (*output) *= alpha;
             (*output) *= (*gradOut); //->applyPairwiseTransform<transform::Multiply>(gradOut, output, nullptr);
         }
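The forward path above implements inverted dropout: an element survives if its uniform random draw falls below probVal, and survivors are scaled by 1 / probVal so the expected activation is unchanged. A rough host-side sketch of the same rule, assuming plain std containers rather than NDArray (hypothetical helper, not the library's signature):

#include <random>
#include <vector>

// Sketch: inverted dropout. Elements are kept with probability keepProb and
// scaled by 1/keepProb so the expected sum of activations is preserved.
template <typename T>
void invertedDropout(const std::vector<T>& in, std::vector<T>& out, double keepProb, unsigned seed) {
    std::mt19937 rng(seed);
    std::uniform_real_distribution<double> uniform(0.0, 1.0);

    out.assign(in.size(), T(0));
    for (size_t e = 0; e < in.size(); ++e)
        if (uniform(rng) < keepProb)
            out[e] = T(in[e] / keepProb);
}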

View File

@@ -43,7 +43,7 @@ namespace nd4j {
             }
             __syncthreads();
+            // we run things in blocks, 1 partition per block of threads
             for (Nd4jLong o = blockIdx.x; o < numOutputs; o += gridDim.x) {
                 auto z = reinterpret_cast<X*>(vz[o]);
@@ -89,9 +89,11 @@
             auto x = reinterpret_cast<X*>(vx);
             auto indices = reinterpret_cast<Y*>(vindices);
+            // we run things in blocks, 1 partition per block of threads
             for (int i = blockIdx.x; i < numOutputs; i += gridDim.x) {
                 auto z = reinterpret_cast<X*>(vz[i]);
+                // each thread has own counter for partitions
                 int outCnt = 0;
                 for (Nd4jLong e = 0; e < iLength; e++) {
@@ -145,6 +147,7 @@
                 tadOffsets[i] = packZ.platformOffsets();
             }
+            // we copy pointers to device
             auto dOutBuffers = reinterpret_cast<void **>(pm.replicatePointer(outBuffers.data(), outBuffers.size() * sizeof(void *)));
             auto dOutTadShapes = reinterpret_cast<Nd4jLong **>(pm.replicatePointer(tadShapes.data(), tadShapes.size() * sizeof(Nd4jLong *)));
             auto dOutTadOffsets = reinterpret_cast<Nd4jLong **>(pm.replicatePointer(tadOffsets.data(), tadOffsets.size() * sizeof(Nd4jLong *)));
@@ -248,6 +251,7 @@
                 indicesShapes[e] = indices.at(e)->getSpecialShapeInfo();
             }
+            // copying pointers to buffers to device
             auto dInputBuffers = reinterpret_cast<void **>(pm.replicatePointer(inputBuffers.data(), inputSize * sizeof(void *)));
             auto dIndicesBuffers = reinterpret_cast<void **>(pm.replicatePointer(indicesBuffers.data(), inputSize * sizeof(void *)));
             auto dInputShapes = reinterpret_cast<Nd4jLong **>(pm.replicatePointer(inputShapes.data(), inputSize * sizeof(Nd4jLong *)));
@@ -283,6 +287,7 @@
                 inputTadOffsets[e] = packX.platformOffsets();
             }
+            // copying pointers to buffers to device
             auto dInputBuffers = reinterpret_cast<void **>(pm.replicatePointer(inputBuffers.data(), inputSize * sizeof(void *)));
             auto dInputTadShapes = reinterpret_cast<Nd4jLong **>(pm.replicatePointer(inputTadShapes.data(), inputSize * sizeof(Nd4jLong *)));
             auto dInputTadOffsets = reinterpret_cast<Nd4jLong **>(pm.replicatePointer(inputTadOffsets.data(), inputSize * sizeof(Nd4jLong *)));
@@ -313,6 +318,7 @@
             NDArray::registerSpecialUse({}, {indices, input});
+            // TODO: it would be nice to have NDArray::registerSpecialUse signature that accepts something else beyond initializer_list
             for (auto v:outputList) {
                 v->tickWriteDevice();
             }

View File

@@ -29,6 +29,7 @@ namespace nd4j {
             Nd4jLong xCoord[MAX_RANK];
+            // each block of threads works on 1 input array
             for (Nd4jLong e = blockIdx.x; e < numInputs; e += gridDim.x) {
                 auto z = reinterpret_cast<T*>(zBuffer) + offsets[e];
@@ -39,6 +40,7 @@
                 auto xRank = shape::rank(xShapeInfo);
                 auto xLength = shape::length(xShapeInfo);
+                // each element of this input array has own place within common output array
                 for (uint i = threadIdx.x; i < xLength; i += blockDim.x) {
                     shape::index2coords(xRank, xShape, i, xLength, xCoord, order);
                     auto xOffset = shape::getOffset(0, xShape, xStride, xCoord, xRank);
@@ -65,6 +67,7 @@
                 hdShapes[e] = inputs[e]->specialShapeInfo();
             }
+            // copying pointers to device
             auto dBuffers = (void **) pm.replicatePointer(hdBuffers.data(), inputs.size() * sizeof(void*));
             auto dShapes = (Nd4jLong **)pm.replicatePointer(hdShapes.data(), inputs.size() * sizeof(Nd4jLong*));
             auto dOffsets = (Nd4jLong *) pm.replicatePointer(hOffsets.data(), inputs.size() * sizeof(Nd4jLong));
@@ -76,6 +79,7 @@
         }
         void flatten(nd4j::LaunchContext *context, std::vector<NDArray*> &inputs, NDArray *output, char order) {
+            // FIXME: we want NDArrayFactory::prepareSpecialUse here eventually
             for (auto v:inputs)
                 v->syncToDevice();

View File

@@ -26,6 +26,7 @@ namespace ops {
 namespace helpers {
     template <typename T>
     void applyGradientDescent_(LaunchContext* context, NDArray* input, NDArray* step, double weight, NDArray* output) {
+        // classic one
         auto lambda = LAMBDA_TT(_x, _y, weight) {
             return _x - (_y * weight);
         };

View File

@@ -44,6 +44,7 @@ namespace nd4j {
             X binSize = X((*max_val - *min_val) / numBins);
+            // nullify bins
             for (int e = threadIdx.x; e < numBins; e += blockDim.x) {
                 bins[e] = (Z) 0;
             }
@@ -53,14 +54,12 @@
                 int idx = int((dx[e] - *min_val) / binSize);
                 idx = math::nd4j_max(idx, 0); //atomicMax(&idx, 0);//atomicMax(&idx, 0);
                 idx = math::nd4j_min(idx, int(numBins - 1)); //atomicMin(&idx, int(numBins - 1));
-                nd4j::math::atomics::nd4j_atomicAdd(&bins[idx], (Z)1);
-                // bins[idx]++;
+                nd4j::math::atomics::nd4j_atomicAdd<Z>(&bins[idx], (Z)1);
             }
             __syncthreads();
+            // at this point all bins in shared memory are calculated, so we aggregate them now via threadfence trick
             // transfer shared memory to reduction memory
             if (gridDim.x > 1) {
                 unsigned int *tc = (unsigned int *)reductionPointer;
                 __shared__ bool amLast;

View File

@@ -64,6 +64,7 @@ static void ismax_(nd4j::LaunchContext * context, const NDArray* input, NDArray*
     auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), copy.data(), copy.size());
+    // we launch legacy IndexMax op, to get indices of max values along dimension
     auto indexMaxArr = input->applyIndexReduce(indexreduce::IndexMax, dimensions);
     dim3 launchDims(256, 256, 16384);

View File

@@ -66,29 +66,35 @@ namespace nd4j {
     }
     template <typename T>
-    linkage void leakyReluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) {
-        auto functor = LAMBDA_TT(x, y){
-            return x >= (T)0.f? y : T(0.f);
+    linkage void leakyReluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output, const float alpha) {
+        const T alphaT = static_cast<T>(alpha);
+        auto functor = LAMBDA_TT(x, y, alphaT) {
+            return x < 0 ? alphaT * y : y;
         };
         input->applyPairwiseLambda(epsilon, functor, output);
     }
-    void leakyReluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) {
-        BUILD_SINGLE_SELECTOR(theFirst->dataType(), leakyReluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES);
+    void leakyReluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput, const float alpha) {
+        BUILD_SINGLE_SELECTOR(theFirst->dataType(), leakyReluDerivative_, (theFirst, theSecond, theOutput, alpha), FLOAT_TYPES);
     }
     template <typename T>
-    linkage void eluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output) {
-        auto functor = LAMBDA_TT(x, y){
-            return y * nd4j::math::nd4j_eluderivative<T,T>(x);
+    linkage void eluDerivative_(NDArray* input, NDArray* epsilon, NDArray* output, const float alpha) {
+        const T alphaT = static_cast<T>(alpha);
+        auto functor = LAMBDA_TT(x, y, alphaT){
+            return y * nd4j::math::nd4j_eluderivative<T,T>(x, alphaT);
         };
         input->applyPairwiseLambda(epsilon, functor, output);
     }
-    void eluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput) {
-        BUILD_SINGLE_SELECTOR(theFirst->dataType(), eluDerivative_, (theFirst, theSecond, theOutput), FLOAT_TYPES);
+    void eluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput, const float alpha) {
+        BUILD_SINGLE_SELECTOR(theFirst->dataType(), eluDerivative_, (theFirst, theSecond, theOutput, alpha), FLOAT_TYPES);
     }
     template <typename T>

View File

@@ -41,12 +41,12 @@ namespace helpers {
             const T tbeta = static_cast<T>(beta);
             const T talpha = static_cast<T>(alpha);
+            // one block of threads processes 1 example within batch
             for (uint i = blockIdx.x; i < numTads; i += gridDim.x) {
                 auto x = reinterpret_cast<T*>(vx) + xTadOffsets[i];
                 auto z = reinterpret_cast<T*>(vz) + zTadOffsets[i];
-                // load everything into shared memory
+                // load everything into shared memory, so we'll operate on shared memory from now on
                 shared[threadIdx.x] = x[threadIdx.x * xEws];
                 __syncthreads();
@@ -94,7 +94,7 @@ namespace helpers {
             sharedY[threadIdx.x] = 0.f;
             __syncthreads();
+            // we're operating in shared memory
             for (int s = begin; s < end; s++)
                 sharedY[threadIdx.x] = sharedY[threadIdx.x] + sharedX[s] * sharedX[s];
             __syncthreads();

View File

@@ -37,7 +37,7 @@ namespace nd4j {
     static __global__ void global_mergeMaxIndex_(void **inArrs, void **inShapes, const int numArrays, void *voutput, Nd4jLong *outputShape, Nd4jLong length) {
         auto output = reinterpret_cast<Z*>(voutput);
-        const auto tid = blockIdx.x * gridDim.x + threadIdx.x;
+        const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
         const auto step = gridDim.x * blockDim.x;
         for (Nd4jLong e = tid; e < length; e += step) {
@@ -81,7 +81,13 @@
     }
     void mergeMaxIndex(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
+        NDArray::prepareSpecialUse({&output}, {});
+        for (auto v:inArrs)
+            v->syncToDevice();
         BUILD_DOUBLE_SELECTOR(inArrs[0]->dataType(), output.dataType(), mergeMaxIndex_, (context, inArrs, output), LIBND4J_TYPES, INDEXING_TYPES);
+        NDArray::registerSpecialUse({&output}, {});
     }
@@ -90,7 +96,7 @@
     static __global__ void global_mergeMax_(void **inArrs, void **inShapes, const int numArrays, void *voutput, Nd4jLong *outputShape, Nd4jLong length) {
         auto output = reinterpret_cast<T*>(voutput);
-        const auto tid = blockIdx.x * gridDim.x + threadIdx.x;
+        const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
         const auto step = gridDim.x * blockDim.x;
         for (Nd4jLong e = tid; e < length; e += step) {
@@ -131,7 +137,12 @@
     }
     void mergeMax(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
+        NDArray::prepareSpecialUse({&output}, {});
+        for (auto v:inArrs)
+            v->syncToDevice();
         BUILD_SINGLE_SELECTOR(output.dataType(), mergeMax_, (context, inArrs, output), LIBND4J_TYPES);
+        NDArray::registerSpecialUse({&output}, {});
     }
     //////////////////////////////////////////////////////////////////////////
@@ -139,7 +150,7 @@
     static __global__ void global_mergeAvg_(void **inArrs, void **inShapes, const int numArrays, void *voutput, Nd4jLong *outputShape, Nd4jLong length) {
         auto output = reinterpret_cast<T*>(voutput);
-        const auto tid = blockIdx.x * gridDim.x + threadIdx.x;
+        const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
         const auto step = gridDim.x * blockDim.x;
         for (Nd4jLong e = tid; e < length; e += step) {
@@ -178,7 +189,13 @@
     }
     void mergeAvg(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
+        NDArray::prepareSpecialUse({&output}, {});
+        for (auto v:inArrs)
+            v->syncToDevice();
         BUILD_SINGLE_SELECTOR(output.dataType(), mergeAvg_, (context, inArrs, output), FLOAT_TYPES);
+        NDArray::registerSpecialUse({&output}, {});
     }
     //////////////////////////////////////////////////////////////////////////
@@ -186,7 +203,7 @@
     static __global__ void global_mergeAdd_(void **inArrs, void **inShapes, const int numArrays, void *voutput, Nd4jLong *outputShape, Nd4jLong length) {
         auto output = reinterpret_cast<T*>(voutput);
-        const auto tid = blockIdx.x * gridDim.x + threadIdx.x;
+        const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
         const auto step = gridDim.x * blockDim.x;
         for (Nd4jLong e = tid; e < length; e += step) {
@@ -226,7 +243,13 @@
     BUILD_SINGLE_TEMPLATE(template void mergeAdd_, (nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output), NUMERIC_TYPES);
     void mergeAdd(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output) {
+        NDArray::prepareSpecialUse({&output}, {});
+        for (auto v:inArrs)
+            v->syncToDevice();
         BUILD_SINGLE_SELECTOR(output.dataType(), mergeAdd_, (context, inArrs, output), NUMERIC_TYPES);
+        NDArray::registerSpecialUse({&output}, {});
     }
 }
 }

View File

@@ -31,18 +31,18 @@ namespace helpers {
     template <typename T>
     static __global__ void fillUpElementKernel(void* outputBuffer, Nd4jLong* outputShapeInfo, void* inputBuffer, Nd4jLong* inputShapeInfo, Nd4jLong* pTadShape, Nd4jLong* pTadOffsets, Nd4jLong n) {
+        __shared__ T *z, *x;
         __shared__ Nd4jLong bufferLength, arrLen;
-        auto z = reinterpret_cast<T*>(outputBuffer);
-        auto x = reinterpret_cast<T*>(inputBuffer);
         if (threadIdx.x == 0) {
+            z = reinterpret_cast<T*>(outputBuffer);
+            x = reinterpret_cast<T*>(inputBuffer);
             arrLen = shape::length(pTadShape);
             bufferLength = shape::length(outputShapeInfo);
         }
         __syncthreads();
-        const auto tid = blockIdx.x * gridDim.x + threadIdx.x;
+        const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
         const auto step = gridDim.x * blockDim.x;
         for (int t = tid; t < bufferLength; t += step) {
             auto tX = x + pTadOffsets[t];
@@ -77,8 +77,6 @@
         // manager.synchronize();
         sortedVals.tickWriteDevice();
         sortedVals.syncToHost();
-        sortedVals.printIndexedBuffer("Hello");
-        sortedVals.printBuffer("Hello line");
         auto stream = context->getCudaStream();
         fillUpElementKernel<T><<<32, 64, 1024, *stream>>>(output->specialBuffer(), output->specialShapeInfo(), sortedVals.specialBuffer(), sortedVals.specialShapeInfo(), pTadShape, pTadOffsets, n);
     }

View File

@@ -74,17 +74,14 @@ static void polyGammaCudaLauncher(const int blocksPerGrid, const int threadsPerB
 ///////////////////////////////////////////////////////////////////
 void polyGamma(nd4j::LaunchContext * context, const NDArray& n, const NDArray& x, NDArray& z) {
-    if(!n.isActualOnDeviceSide()) n.syncToDevice();
-    if(!x.isActualOnDeviceSide()) x.syncToDevice();
+    NDArray::prepareSpecialUse({&z}, {&n, &x});
     int threadsPerBlock = MAX_NUM_THREADS;
     int blocksPerGrid = (z.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
     BUILD_SINGLE_SELECTOR(n.dataType(), polyGammaCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), n.getSpecialBuffer(), n.getSpecialShapeInfo(), x.getSpecialBuffer(), x.getSpecialShapeInfo(), z.getSpecialBuffer(), z.getSpecialShapeInfo()), FLOAT_TYPES);
-    n.tickReadHost();
-    x.tickReadHost();
-    z.tickWriteDevice();
+    NDArray::registerSpecialUse({&z}, {&n, &x});
 }
 BUILD_SINGLE_TEMPLATE(template void polyGammaCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void *vn, const Nd4jLong *nShapeInfo, const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo), FLOAT_TYPES);

View File

@@ -28,7 +28,7 @@ namespace helpers {
     template <typename T>
     static __global__ void global_range(void *output, Nd4jLong length, T start, T delta) {
         auto buff = reinterpret_cast<T*>(output);
-        const auto tid = blockIdx.x * gridDim.x + threadIdx.x;
+        const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
         const auto step = gridDim.x * blockDim.x;
         for(Nd4jLong i = tid; i < length; i += step)
@@ -43,10 +43,11 @@
     }
     void range(nd4j::LaunchContext * context, const NDArray& start, const NDArray& delta, NDArray& outVector) {
+        NDArray::prepareSpecialUse({&outVector}, {&start, &delta});
         BUILD_SINGLE_SELECTOR(outVector.dataType(), _range, (context, start, delta, outVector), LIBND4J_TYPES);
+        NDArray::registerSpecialUse({&outVector}, {&start, &delta});
     }
-    BUILD_SINGLE_TEMPLATE(template void _range, (nd4j::LaunchContext * context, const NDArray& start, const NDArray& delta, NDArray& outVector), NUMERIC_TYPES);
 }
 }
 }

View File

@@ -26,13 +26,11 @@ namespace nd4j {
 namespace helpers {
     template<typename T>
     void toggle_bits__(NDArray &in, NDArray &out) {
-        NDArray::prepareSpecialUse({&out}, {&in});
         auto lambda = LAMBDA_T(_x) {
             return ~_x;//eUtils::flip_bits(_x);
         };
         in.applyLambda(lambda, &out);
-        NDArray::registerSpecialUse({&out}, {&in});
     }
     BUILD_SINGLE_TEMPLATE(template void toggle_bits__, (NDArray &in, NDArray &out), INTEGER_TYPES);

View File

@@ -906,7 +906,7 @@ void clipByNormBP(nd4j::LaunchContext* context, const NDArray& input, const NDAr
         linearBuffers = shape::elementWiseStride(inputShape) == shape::elementWiseStride(outputShape) && shape::elementWiseStride(inputShape) == 1;
     }
     __syncthreads();
-    const auto tid = blockIdx.x * gridDim.x + threadIdx.x;
+    const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
     const auto step = gridDim.x * blockDim.x;
     for (Nd4jLong e = tid; e < length; e += step) {

View File

@@ -46,8 +46,8 @@ namespace helpers {
     void reluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond);
     void reluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput);
     void relu6Derivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput);
-    void leakyReluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput);
-    void eluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput);
+    void leakyReluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput, const float alpha);
+    void eluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput, const float alpha);
     void seluDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput);
     void cubeDerivative(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput);
     void reduceNorm1(nd4j::LaunchContext * context, NDArray* theFirst, NDArray* theSecond, NDArray* theOutput);

View File

@@ -2271,26 +2271,26 @@ namespace simdOps {
         }
     };
-    template <typename X>
+    template <typename X, typename Y, typename Z>
     class ELU {
     public:
         no_op_exec_special_same
         no_op_exec_special_same_cuda
-        op_def static X op(X d1, X *params) {
-            return nd4j::math::nd4j_elu<X,X>(d1);
+        op_def static Z op(X d1, Y d2, Z *params) {
+            return nd4j::math::nd4j_elu<X,Z>(d1, static_cast<X>(d2));
         }
     };
-    template <typename X>
+    template <typename X, typename Y, typename Z>
     class ELUDerivative {
     public:
         no_op_exec_special_same
         no_op_exec_special_same_cuda
-        op_def static X op(X d1, X *params) {
-            return nd4j::math::nd4j_eluderivative<X,X>(d1);
+        op_def static Z op(X d1, Y d2, Z *params) {
+            return nd4j::math::nd4j_eluderivative<X,Z>(d1, static_cast<X>(d2));
         }
     };
@@ -3716,7 +3716,7 @@ namespace simdOps {
             return reduction;
         }
         op_def static Z op(X d1, X d2, Z *extraParamsRef) {
            double eps = nd4j::math::nd4j_abs<double>(extraParamsRef[2]);
            return static_cast<Z>(!nd4j::math::nd4j_eq<X>(d1, d2, eps));
        }
@@ -4540,4 +4540,4 @@ namespace simdOps {
    }
 #endif

View File

@ -130,13 +130,12 @@ namespace nd4j {
} }
template<typename T, typename Z> template<typename T, typename Z>
math_def inline Z nd4j_elu(T val) { math_def inline Z nd4j_elu(T val, T alpha) {
if (val >= (T) 0.f) return val; if (val >= (T) 0.f)
else return nd4j_exp<T, Z>(val) - (Z) 1.0f; return val;
//return val >= 0.0 ? val : (nd4j_exp<T>(val) - 1.0); return static_cast<Z>(alpha) * (nd4j_exp<T, Z>(val) - static_cast<Z>(1.0f));
} }
template<typename T, typename Z> template<typename T, typename Z>
math_def inline Z nd4j_leakyrelu(T val,T alpha) { math_def inline Z nd4j_leakyrelu(T val,T alpha) {
if (val < (T) 0.0f) if (val < (T) 0.0f)
@ -145,13 +144,14 @@ namespace nd4j {
return val; return val;
} }
template<typename T, typename Z> template<typename T, typename Z>
math_def inline Z nd4j_eluderivative(T val) { math_def inline Z nd4j_eluderivative(T val, T alpha) {
if (val >= (T) 0.0f) return (Z) 1.0f; if (val >= static_cast<T>(0.0f))
else return nd4j_exp<T, Z>(val); return static_cast<Z>(1.0f);
return static_cast<Z>(alpha) * nd4j_exp<T, Z>(val);
//return val >= 0.0 ? 1.0 : nd4j_exp(val); //return val >= 0.0 ? 1.0 : nd4j_exp(val);
} }
template<typename T, typename Z> template<typename T, typename Z>
math_def inline Z nd4j_sin(T val); math_def inline Z nd4j_sin(T val);
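Annotation: with the extra parameter, the helpers implement the parameterised ELU: elu(x) = x for x >= 0 and alpha * (exp(x) - 1) otherwise, with derivative 1 for x >= 0 and alpha * exp(x) otherwise. A standalone C++ sketch of the same math on plain doubles, not the templated math_def versions above:

```cpp
#include <cmath>

// Parameterised ELU and its derivative, mirroring nd4j_elu / nd4j_eluderivative above.
double elu(double x, double alpha) {
    return x >= 0.0 ? x : alpha * (std::exp(x) - 1.0);
}

double eluDerivative(double x, double alpha) {
    return x >= 0.0 ? 1.0 : alpha * std::exp(x);
}
```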
@ -283,7 +283,7 @@ namespace nd4j {
#ifdef NATIVE_HALFS #ifdef NATIVE_HALFS
if (value < (float16) 0.f) { if (value < (float16) 0.f) {
return float16(__hneg(value.data)); return float16(__hneg(value.data));
} else } else
return value; return value;
#else #else
return (float16) fabsf((float) value); return (float16) fabsf((float) value);
@ -904,13 +904,13 @@ inline __device__ int16_t nd4j_atomicMax<int16_t>(int16_t* address, int16_t val)
template <> template <>
inline __device__ float16 nd4j_atomicMax<float16>(float16* address, float16 val) { inline __device__ float16 nd4j_atomicMax<float16>(float16* address, float16 val) {
int* address_as_ull = (int*) address; auto address_as_ull = (int*) address;
long addr = (long) address; long addr = (long) address;
bool misaligned = addr & 0x3; bool misaligned = addr & 0x3;
if (misaligned) if (misaligned)
address_as_ull = (int *) (addr - 2); address_as_ull = (int *) (address - 1);
PAIR old, assumed, fresh; PAIR old, assumed, fresh;
@ -937,13 +937,13 @@ inline __device__ float16 nd4j_atomicMax<float16>(float16* address, float16 val)
template <> template <>
inline __device__ bfloat16 nd4j_atomicMax<bfloat16>(bfloat16* address, bfloat16 val) { inline __device__ bfloat16 nd4j_atomicMax<bfloat16>(bfloat16* address, bfloat16 val) {
int* address_as_ull = (int*) address; auto address_as_ull = (int*) address;
long addr = (long)(address); long addr = (long)(address);
bool misaligned = addr & 0x3; bool misaligned = addr & 0x3;
if (misaligned) if (misaligned)
address_as_ull = (int *) (addr - 2); address_as_ull = (int *) (address - 1);
BPAIR old, assumed, fresh; BPAIR old, assumed, fresh;
@ -1060,13 +1060,13 @@ inline __device__ float16 nd4j_atomicAdd<float16>(float16* address, float16 val)
#if __CUDA_ARCH__ >= 700 #if __CUDA_ARCH__ >= 700
atomicAdd(reinterpret_cast<__half*>(address), val.data); atomicAdd(reinterpret_cast<__half*>(address), val.data);
#else #else
int* address_as_ull = (int*) address; auto address_as_ull = (int*) address;
long addr = (long) address; long addr = (long) address;
bool misaligned = addr & 0x3; bool misaligned = addr & 0x3;
if (misaligned) if (misaligned)
address_as_ull = (int *) (addr - 2); address_as_ull = (int *) (address - 1);
PAIR old, assumed, fresh; PAIR old, assumed, fresh;
@ -1094,13 +1094,13 @@ inline __device__ float16 nd4j_atomicAdd<float16>(float16* address, float16 val)
template <> template <>
inline __device__ bfloat16 nd4j_atomicAdd<bfloat16>(bfloat16* address, bfloat16 val) { inline __device__ bfloat16 nd4j_atomicAdd<bfloat16>(bfloat16* address, bfloat16 val) {
int* address_as_ull = (int*) address; auto address_as_ull = (int*) address;
long addr = (long)(address); auto addr = (long)(address);
bool misaligned = addr & 0x3; bool misaligned = addr & 0x3;
if (misaligned) if (misaligned)
address_as_ull = (int *) (addr - 2); address_as_ull = (int *) (address - 1);
BPAIR old, assumed, fresh; BPAIR old, assumed, fresh;
@ -1367,13 +1367,13 @@ inline __device__ Nd4jLong nd4j_atomicMul<Nd4jLong>(Nd4jLong* address, Nd4jLong
template <> template <>
inline __device__ bfloat16 nd4j_atomicMul<bfloat16>(bfloat16* address, bfloat16 val) { inline __device__ bfloat16 nd4j_atomicMul<bfloat16>(bfloat16* address, bfloat16 val) {
int* address_as_ull = (int*) address; auto address_as_ull = (int*) address;
long addr = (long)(address); long addr = (long)(address);
bool misaligned = addr & 0x3; bool misaligned = addr & 0x3;
if (misaligned) if (misaligned)
address_as_ull = (int *) (addr - 2); address_as_ull = (int *) (address - 1);
BPAIR old, assumed, fresh; BPAIR old, assumed, fresh;
@ -1400,13 +1400,13 @@ inline __device__ bfloat16 nd4j_atomicMul<bfloat16>(bfloat16* address, bfloat16
template <> template <>
inline __device__ float16 nd4j_atomicMul<float16>(float16* address, float16 val) { inline __device__ float16 nd4j_atomicMul<float16>(float16* address, float16 val) {
int* address_as_ull = (int*) address; auto address_as_ull = (int*) address;
long addr = (long)(address); long addr = (long)(address);
bool misaligned = addr & 0x3; bool misaligned = addr & 0x3;
if (misaligned) if (misaligned)
address_as_ull = (int *) (addr - 2); address_as_ull = (int *) (address - 1);
BPAIR old, assumed, fresh; BPAIR old, assumed, fresh;
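Annotation: these hunks all follow the same pattern for 16-bit atomics. There is no native atomic max/add/mul for float16 or bfloat16, so the code locates the aligned 32-bit word that contains the 16-bit value (stepping back one element when the address sits in the upper half of the word) and, presumably via atomicCAS on that word in the part of the routine not shown here, swaps in an updated pair. A host-side C++ analogue of the word-level CAS idea, using std::atomic and treating the 16-bit payload as a plain unsigned integer purely to illustrate the pattern; the device code compares real float16/bfloat16 values:

```cpp
#include <algorithm>
#include <atomic>
#include <cstdint>

// Atomically apply max() to a 16-bit value stored in one half of a 32-bit word.
// 'lowHalf' selects which half of the word holds the target value.
void atomicMax16(std::atomic<uint32_t>& word, uint16_t val, bool lowHalf) {
    uint32_t oldWord = word.load();
    uint32_t newWord;
    do {
        const uint16_t cur = lowHalf ? uint16_t(oldWord & 0xFFFFu) : uint16_t(oldWord >> 16);
        const uint16_t mx  = std::max(cur, val);
        newWord = lowHalf ? ((oldWord & 0xFFFF0000u) | mx)
                          : ((oldWord & 0x0000FFFFu) | (uint32_t(mx) << 16));
    } while (!word.compare_exchange_weak(oldWord, newWord));   // retry if another thread raced us
}
```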

View File

@ -905,6 +905,25 @@ TEST_F(DeclarableOpsTests12, softmax_9) {
delete arrF; delete arrF;
} }
TEST_F(DeclarableOpsTests12, maxpool_bp_half_1) {
auto x = NDArrayFactory::create<bfloat16>('c', {2, 3, 10, 1}, {0.2019043f, 0.6464844f, 0.9116211f, 0.60058594f, 0.34033203f, 0.7036133f, 0.6772461f, 0.3815918f, 0.87353516f, 0.04650879f, 0.67822266f, 0.8618164f, 0.88378906f, 0.7573242f, 0.66796875f, 0.63427734f, 0.33764648f, 0.46923828f, 0.62939453f, 0.76464844f, -0.8618164f, -0.94873047f, -0.9902344f, -0.88916016f, -0.86572266f, -0.92089844f, -0.90722656f, -0.96533203f, -0.97509766f, -0.4975586f, -0.84814453f, -0.984375f, -0.98828125f, -0.95458984f, -0.9472656f, -0.91064453f, -0.80859375f, -0.83496094f, -0.9140625f, -0.82470703f, 0.4802246f, 0.45361328f, 0.28125f, 0.28320312f, 0.79345703f, 0.44604492f, -0.30273438f, 0.11730957f, 0.56396484f, 0.73583984f, 0.1418457f, -0.44848633f, 0.6923828f, -0.40234375f, 0.40185547f, 0.48632812f, 0.14538574f, 0.4638672f, 0.13000488f, 0.5058594f});
auto y = NDArrayFactory::create<bfloat16>('c', {2, 3, 10, 1}, {0.0f, -0.13391113f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, -0.1751709f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.51904297f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.5107422f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f});
auto z = NDArrayFactory::create<bfloat16>('c', {2, 3, 10, 1});
nd4j::ops::maxpool2d_bp op;
Context ctx(1);
Nd4jLong iArgs[] = {5,1,1, 2,2,0, 1,1,1, 0,0};
ctx.setIArguments(iArgs, 11);
ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo());
ctx.setInputArray(1, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo());
ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());
auto status = op.execute(&ctx);
ASSERT_EQ(Status::OK(), status);
}
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests12, lrn_bp_1) { TEST_F(DeclarableOpsTests12, lrn_bp_1) {

View File

@ -2794,53 +2794,42 @@ TEST_F(DeclarableOpsTests3, svd_test11) {
TEST_F(DeclarableOpsTests3, elu_test1) { TEST_F(DeclarableOpsTests3, elu_test1) {
auto x = NDArrayFactory::create<double>('c', {3,3}, {0.1, .2, .3, -.4,-.5,-.6, .7, .8, .9}); auto x = NDArrayFactory::create<double>('c', {3,3}, {0.1, .2, .3, -.4,-.5,-.6, .7, .8, .9});
// auto expS = NDArrayFactory::create<double>('c', {3}); auto exp = NDArrayFactory::create<double>('c', {3,3}, {.1, .2, .3, 0.5*-0.32968, 0.5*-0.393469, 0.5*-0.451188, .7, .8, .9});
// auto expU = NDArrayFactory::create<double>('c', {3,3});
auto exp = NDArrayFactory::create<double>('c', {3,3}, {.1, .2, .3, -0.32968, -0.393469, -0.451188, .7, .8, .9});
nd4j::ops::elu op; nd4j::ops::elu op;
auto results = op.execute({&x}, {}, {}); auto results = op.execute({&x}, {0.5}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto s = results->at(0); auto s = results->at(0);
// auto u = results->at(1);
// auto v = results->at(2);
// s->printIndexedBuffer("ELU");
ASSERT_TRUE(exp.equalsTo(s)); ASSERT_TRUE(exp.equalsTo(s));
delete results; delete results;
} }
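Annotation: the updated expectations follow from the alpha-scaled ELU. For the negative inputs -0.4, -0.5, -0.6, exp(x) - 1 gives -0.32968, -0.393469, -0.451188, and each is multiplied by the alpha of 0.5 that the test now passes as the op's floating-point argument. A quick numeric check of those constants:

```cpp
#include <cassert>
#include <cmath>

int main() {
    const double alpha = 0.5;
    const double xs[]        = {-0.4, -0.5, -0.6};
    const double expMinus1[] = {-0.32968, -0.393469, -0.451188};   // exp(x) - 1, as used in the test
    for (int i = 0; i < 3; ++i) {
        const double y = alpha * (std::exp(xs[i]) - 1.0);          // elu(x) for x < 0
        assert(std::fabs(y - alpha * expMinus1[i]) < 1e-5);
    }
    return 0;
}
```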
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests3, elu_test2) { TEST_F(DeclarableOpsTests3, elu_bp_test1) {
auto x = NDArrayFactory::create<double>('c', {3, 3}, {0.1, .2, .3, -.4, -.5, -.6, .7, .8, .9}); auto x = NDArrayFactory::create<double>('c', {3, 3}, {0.1, .2, .3, -.4, -.5, -.6, .7, .8, .9});
auto eps = NDArrayFactory::create<double>('c', {3,3}); auto eps = NDArrayFactory::create<double>('c', {3,3});
eps.assign(2.); eps.assign(2.);
// auto expU = NDArrayFactory::create<double>('c', {3,3}); auto exp = NDArrayFactory::create<double>('c', {3, 3}, {2, 2, 2, 0.5*1.34064, 0.5*1.213061, 0.5*1.097623, 2, 2, 2});
auto exp = NDArrayFactory::create<double>('c', {3, 3}, {2, 2, 2, 1.34064, 1.213061, 1.097623, 2, 2, 2});
nd4j::ops::elu_bp op; nd4j::ops::elu_bp op;
auto results = op.execute({ &x, &eps }, {}, {}); auto results = op.execute({ &x, &eps }, {0.5}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto s = results->at(0); auto s = results->at(0);
// auto u = results->at(1); ASSERT_TRUE(exp.equalsTo(s));
// auto v = results->at(2);
// s->printIndexedBuffer("ELU_BP");
ASSERT_TRUE(exp.equalsTo(s));
delete results; delete results;
} }
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests3, lrelu_test1) { TEST_F(DeclarableOpsTests3, lrelu_test1) {
auto x = NDArrayFactory::create<double>('c', {3,3}, {1, 2, 3, -4,-5,-6, 7, 8, 9}); auto x = NDArrayFactory::create<double>('c', {3,3}, {1, 2, 3, -4,-5,-6, 7, 8, 9});
// auto expS = NDArrayFactory::create<double>('c', {3});
// auto expU = NDArrayFactory::create<double>('c', {3,3});
auto exp = NDArrayFactory::create<double>('c', {3,3}, {1, 2, 3, -0.8, -1., -1.2, 7, 8, 9}); auto exp = NDArrayFactory::create<double>('c', {3,3}, {1, 2, 3, -0.8, -1., -1.2, 7, 8, 9});
nd4j::ops::lrelu op; nd4j::ops::lrelu op;
@ -2849,20 +2838,16 @@ TEST_F(DeclarableOpsTests3, lrelu_test1) {
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto s = results->at(0); auto s = results->at(0);
// auto u = results->at(1);
// auto v = results->at(2);
// s->printIndexedBuffer("LRELU");
ASSERT_TRUE(exp.equalsTo(s)); ASSERT_TRUE(exp.equalsTo(s));
delete results; delete results;
} }
TEST_F(DeclarableOpsTests3, lrelu_test2) { TEST_F(DeclarableOpsTests3, lrelu_bp_test1) {
auto x = NDArrayFactory::create<double>('c', {3,3}, {1, 2, 3, -4,-5,-6, 7, 8, 9}); auto x = NDArrayFactory::create<double>('c', {3,3}, {1, 2, 3, -4,-5,-6, 7, 8, 9});
// auto expS = NDArrayFactory::create<double>('c', {3});
auto eps = NDArrayFactory::create<double>('c', {3,3}, {2,2,2,2,2,2,2, 2,2}); auto eps = NDArrayFactory::create<double>('c', {3,3}, {2,2,2,2,2,2,2, 2,2});
auto exp = NDArrayFactory::create<double>('c', {3,3}, {2, 2, 2, 0, 0, 0, 2, 2, 2}); auto exp = NDArrayFactory::create<double>('c', {3,3}, {2, 2, 2, 0.4, 0.4, 0.4, 2, 2, 2});
nd4j::ops::lrelu_bp op; nd4j::ops::lrelu_bp op;
auto results = op.execute({&x, &eps}, {0.2}, {}); auto results = op.execute({&x, &eps}, {0.2}, {});
@ -2870,9 +2855,6 @@ TEST_F(DeclarableOpsTests3, lrelu_test2) {
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto s = results->at(0); auto s = results->at(0);
// auto u = results->at(1);
// auto v = results->at(2);
// s->printIndexedBuffer("LRELU_BP");
ASSERT_TRUE(exp.equalsTo(s)); ASSERT_TRUE(exp.equalsTo(s));
delete results; delete results;
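Annotation: the corrected expectation reflects the leaky slope on the negative side rather than zero. With alpha = 0.2 and an incoming gradient of 2, dL/dx is 2 * 1 = 2 where x > 0 and 2 * 0.2 = 0.4 where x < 0, which is exactly the {2, 2, 2, 0.4, 0.4, 0.4, 2, 2, 2} array above. A trivial check:

```cpp
#include <cassert>
#include <cmath>

int main() {
    const double alpha = 0.2, eps = 2.0;
    assert(std::fabs(eps * 1.0   - 2.0) < 1e-12);   // gradient for positive inputs
    assert(std::fabs(eps * alpha - 0.4) < 1e-12);   // gradient for negative inputs
    return 0;
}
```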
@ -2882,8 +2864,6 @@ TEST_F(DeclarableOpsTests3, lrelu_test2) {
TEST_F(DeclarableOpsTests3, selu_test1) { TEST_F(DeclarableOpsTests3, selu_test1) {
auto x = NDArrayFactory::create<double>('c', {3,3}, {1, 2, 3, -4,-5,-6, 7, 8, 9}); auto x = NDArrayFactory::create<double>('c', {3,3}, {1, 2, 3, -4,-5,-6, 7, 8, 9});
// auto expS = NDArrayFactory::create<double>('c', {3});
// auto expU = NDArrayFactory::create<double>('c', {3,3});
auto exp = NDArrayFactory::create<double>('c', {3,3}, {1.050701, 2.101402, 3.152103, -1.725899, -1.746253, -1.753742, 7.354907, 8.405608, 9.456309}); auto exp = NDArrayFactory::create<double>('c', {3,3}, {1.050701, 2.101402, 3.152103, -1.725899, -1.746253, -1.753742, 7.354907, 8.405608, 9.456309});
nd4j::ops::selu op; nd4j::ops::selu op;
@ -2892,7 +2872,6 @@ TEST_F(DeclarableOpsTests3, selu_test1) {
ASSERT_EQ(ND4J_STATUS_OK, results->status()); ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto s = results->at(0); auto s = results->at(0);
// s->printIndexedBuffer("SELU");
ASSERT_TRUE(exp.equalsTo(s)); ASSERT_TRUE(exp.equalsTo(s));
delete results; delete results;

View File

@ -2761,7 +2761,7 @@ TEST_F(DeclarableOpsTests5, ELU_1) {
auto exp = NDArrayFactory::create<double>('c', {2, 2, 2}, { -0.63212055, 2. , 1.5, -0.753403, 1., 2., 2., 1.}); auto exp = NDArrayFactory::create<double>('c', {2, 2, 2}, { -0.63212055, 2. , 1.5, -0.753403, 1., 2., 2., 1.});
auto res = NDArrayFactory::create<double>('c', {2, 2, 2}); auto res = NDArrayFactory::create<double>('c', {2, 2, 2});
input.applyTransform(transform::ELU, &res); input.applyScalar(nd4j::scalar::ELU, 1.f, &res);
ASSERT_TRUE(res.equalsTo(&exp)); ASSERT_TRUE(res.equalsTo(&exp));
} }

View File

@ -204,20 +204,29 @@ import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentSum;
import org.nd4j.linalg.api.ops.impl.transforms.dtype.Cast; import org.nd4j.linalg.api.ops.impl.transforms.dtype.Cast;
import org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt; import org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt;
import org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt; import org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative; import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.DynamicPartitionBp; import org.nd4j.linalg.api.ops.impl.transforms.gradient.DynamicPartitionBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.ELUDerivative; import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker; import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative; import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative; import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LogSoftMaxDerivative; import org.nd4j.linalg.api.ops.impl.transforms.gradient.LogSoftMaxDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhDerivative; import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhDerivative; import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.Relu6Derivative; import org.nd4j.linalg.api.ops.impl.transforms.gradient.Relu6Derivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SELUDerivative; import org.nd4j.linalg.api.ops.impl.transforms.gradient.SELUDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative; import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative; import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp; import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.ThresholdReluBp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.*; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.*;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.*; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.*;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.And; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.And;
@ -1126,10 +1135,26 @@ public class DifferentialFunctionFactory {
return new org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative(sameDiff(), iX, wrt).outputVariable(); return new org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative(sameDiff(), iX, wrt).outputVariable();
} }
public SDVariable tanhRationalBp(SDVariable in, SDVariable epsilon) {
return new RationalTanhBp(sameDiff(), in, epsilon).outputVariable();
}
public SDVariable tanhRectifiedBp(SDVariable in, SDVariable epsilon) {
return new RectifiedTanhBp(sameDiff(), in, epsilon).outputVariable();
}
/**
* Use {@link #tanhRationalBp(SDVariable, SDVariable)}
*/
@Deprecated
public SDVariable tanhRationalDerivative(SDVariable in) { public SDVariable tanhRationalDerivative(SDVariable in) {
return new RationalTanhDerivative(sameDiff(), in, false).outputVariable(); return new RationalTanhDerivative(sameDiff(), in, false).outputVariable();
} }
/**
* Use {@link #tanhRectifiedBp(SDVariable, SDVariable)}
*/
@Deprecated
public SDVariable tanhRectifiedDerivative(SDVariable in) { public SDVariable tanhRectifiedDerivative(SDVariable in) {
return new RectifiedTanhDerivative(sameDiff(), in, false).outputVariable(); return new RectifiedTanhDerivative(sameDiff(), in, false).outputVariable();
} }
@ -1280,6 +1305,14 @@ public class DifferentialFunctionFactory {
return new Cube(sameDiff(), iX, false).outputVariable(); return new Cube(sameDiff(), iX, false).outputVariable();
} }
public SDVariable cubeBp(SDVariable in, SDVariable epsilon) {
return new CubeBp(sameDiff(), in, epsilon).outputVariable();
}
/**
* @deprecated Use {@link #cubeBp(SDVariable, SDVariable)}
*/
@Deprecated
public SDVariable cubeDerivative(SDVariable iX) { public SDVariable cubeDerivative(SDVariable iX) {
return new CubeDerivative(sameDiff(), iX, false).outputVariable(); return new CubeDerivative(sameDiff(), iX, false).outputVariable();
} }
@ -1329,6 +1362,14 @@ public class DifferentialFunctionFactory {
return new RectifiedLinearDerivative(sameDiff(), input, grad).outputVariable(); return new RectifiedLinearDerivative(sameDiff(), input, grad).outputVariable();
} }
public SDVariable thresholdRelu(SDVariable in, SDVariable epsilon, double cutoff){
return new ThresholdRelu(sameDiff(), in, cutoff).outputVariable();
}
public SDVariable thresholdReluBp(SDVariable in, SDVariable epsilon, double cutoff){
return new ThresholdReluBp(sameDiff(), in, epsilon, cutoff).outputVariable();
}
public SDVariable relu6(SDVariable iX, double cutoff) { public SDVariable relu6(SDVariable iX, double cutoff) {
return new Relu6(sameDiff(), iX, false, cutoff).outputVariable(); return new Relu6(sameDiff(), iX, false, cutoff).outputVariable();
} }
@ -1350,6 +1391,14 @@ public class DifferentialFunctionFactory {
return new HardTanh(sameDiff(), iX, false).outputVariable(); return new HardTanh(sameDiff(), iX, false).outputVariable();
} }
public SDVariable hardTanhBp(SDVariable in, SDVariable epsilon) {
return new HardTanhBp(sameDiff(), in, epsilon).outputVariable();
}
/**
* @deprecated Use {@link #hardTanhBp(SDVariable, SDVariable)}
*/
@Deprecated
public SDVariable hardTanhDerivative(SDVariable iX) { public SDVariable hardTanhDerivative(SDVariable iX) {
return new HardTanhDerivative(sameDiff(), iX, false).outputVariable(); return new HardTanhDerivative(sameDiff(), iX, false).outputVariable();
} }
@ -1358,6 +1407,9 @@ public class DifferentialFunctionFactory {
return new HardSigmoid(sameDiff(), in, false).outputVariable(); return new HardSigmoid(sameDiff(), in, false).outputVariable();
} }
public SDVariable hardSigmoidBp(SDVariable in, SDVariable epsilon){
return new HardSigmoidBp(sameDiff(), in, epsilon).outputVariable();
}
public SDVariable sigmoid(SDVariable iX) { public SDVariable sigmoid(SDVariable iX) {
return new Sigmoid(sameDiff(), iX, false).outputVariable(); return new Sigmoid(sameDiff(), iX, false).outputVariable();
@ -1486,10 +1538,16 @@ public class DifferentialFunctionFactory {
} }
public SDVariable softsignBp(SDVariable in, SDVariable epsilon) {
return new SoftSignBp(sameDiff(), in, epsilon).outputVariable();
}
/**
* @deprecated Use {@link #softsignBp(SDVariable, SDVariable)}
*/
@Deprecated
public SDVariable softsignDerivative(SDVariable iX) { public SDVariable softsignDerivative(SDVariable iX) {
return new SoftSignDerivative(sameDiff(), iX, false).outputVariable(); return new SoftSignDerivative(sameDiff(), iX, false).outputVariable();
} }
@ -1500,14 +1558,12 @@ public class DifferentialFunctionFactory {
public SDVariable elu(SDVariable iX) { public SDVariable elu(SDVariable iX) {
return new ELU(sameDiff(), iX, false).outputVariable(); return new ELU(sameDiff(), iX).outputVariable();
} }
public SDVariable eluBp(SDVariable in, SDVariable epsilon) {
public SDVariable eluDerivative(SDVariable iX) { return new EluBp(sameDiff(), in, epsilon).outputVariable();
return new ELUDerivative(sameDiff(), iX, false).outputVariable();
} }
@ -1516,6 +1572,14 @@ public class DifferentialFunctionFactory {
} }
public SDVariable leakyReluBp(SDVariable in, SDVariable epsilon, double cutoff) {
return new LeakyReLUBp(sameDiff(), in, epsilon, cutoff).outputVariable();
}
/**
* @deprecated Use {@link #leakyReluBp(SDVariable, SDVariable, double)}
*/
@Deprecated
public SDVariable leakyReluDerivative(SDVariable iX, double cutoff) { public SDVariable leakyReluDerivative(SDVariable iX, double cutoff) {
return new LeakyReLUDerivative(sameDiff(), iX, false, cutoff).outputVariable(); return new LeakyReLUDerivative(sameDiff(), iX, false, cutoff).outputVariable();
} }
@ -1832,7 +1896,15 @@ public class DifferentialFunctionFactory {
return new SELU(sameDiff(), arg, false).outputVariable(); return new SELU(sameDiff(), arg, false).outputVariable();
} }
public SDVariable seluBp(SDVariable in, SDVariable epsilon) {
validateDifferentialFunctionsameDiff(in);
return new SeluBp(sameDiff(), in, epsilon).outputVariable();
}
/**
* @deprecated Use {@link #seluBp(SDVariable, SDVariable)}
*/
@Deprecated
public SDVariable seluDerivative(SDVariable arg) { public SDVariable seluDerivative(SDVariable arg) {
validateDifferentialFunctionsameDiff(arg); validateDifferentialFunctionsameDiff(arg);
return new SELUDerivative(sameDiff(), arg, false).outputVariable(); return new SELUDerivative(sameDiff(), arg, false).outputVariable();

View File

@ -163,31 +163,6 @@ public class SDNN extends SDOps {
return updateVariableNameAndReference(result, name); return updateVariableNameAndReference(result, name);
} }
/**
* Element-wise derivative exponential linear unit (ELU) function, dOut/dIn given input.
* {@link #elu(SDVariable)}
*
* @param x Input variable
* @return Output variable
*/
public SDVariable eluDerivative(SDVariable x) {
return eluDerivative(null, x);
}
/**
* Element-wise derivative exponential linear unit (ELU) function, dOut/dIn given input.
* {@link #elu(SDVariable)}
*
* @param name Output variable name
* @param x Input variable
* @return Output variable
*/
public SDVariable eluDerivative(String name, SDVariable x) {
validateFloatingPoint("eluDerivative", x);
SDVariable result = f().eluDerivative(x);
return updateVariableNameAndReference(result, name);
}
/** /**
* GELU activation function - Gaussian Error Linear Units<br> * GELU activation function - Gaussian Error Linear Units<br>
* For more details, see <i>Gaussian Error Linear Units (GELUs)</i> - <a href="https://arxiv.org/abs/1606.08415">https://arxiv.org/abs/1606.08415</a> * For more details, see <i>Gaussian Error Linear Units (GELUs)</i> - <a href="https://arxiv.org/abs/1606.08415">https://arxiv.org/abs/1606.08415</a>

View File

@ -255,8 +255,6 @@ public class LegacyOpMapper {
return Abs.class; return Abs.class;
case 2: case 2:
return LogSoftMax.class; return LogSoftMax.class;
case 3:
return ELUDerivative.class;
case 4: case 4:
return org.nd4j.linalg.api.ops.impl.transforms.strict.TanhDerivative.class; return org.nd4j.linalg.api.ops.impl.transforms.strict.TanhDerivative.class;
case 5: case 5:

View File

@ -881,7 +881,6 @@ public class OpValidation {
SoftmaxBp.class, SoftmaxBp.class,
CubeDerivative.class, CubeDerivative.class,
ELUDerivative.class,
GELUDerivative.class, GELUDerivative.class,
PreciseGELUDerivative.class, PreciseGELUDerivative.class,
HardSigmoidDerivative.class, HardSigmoidDerivative.class,
@ -901,6 +900,17 @@ public class OpValidation {
TanhDerivative.class, TanhDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.strict.TanhDerivative.class, org.nd4j.linalg.api.ops.impl.transforms.strict.TanhDerivative.class,
PowDerivative.class, PowDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftPlusBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.ThresholdReluBp.class,
BiasAddGrad.class, BiasAddGrad.class,
ConcatBp.class, ConcatBp.class,

View File

@ -229,6 +229,7 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.scalar.PowDerivative.class, org.nd4j.linalg.api.ops.impl.scalar.PowDerivative.class,
org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear.class, org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear.class,
org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinearDerivative.class, org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinearDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.ThresholdRelu.class,
org.nd4j.linalg.api.ops.impl.scalar.Relu6.class, org.nd4j.linalg.api.ops.impl.scalar.Relu6.class,
org.nd4j.linalg.api.ops.impl.scalar.ReplaceNans.class, org.nd4j.linalg.api.ops.impl.scalar.ReplaceNans.class,
org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd.class, org.nd4j.linalg.api.ops.impl.scalar.ScalarAdd.class,
@ -421,7 +422,6 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt.class, org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative.class, org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.DynamicPartitionBp.class, org.nd4j.linalg.api.ops.impl.transforms.gradient.DynamicPartitionBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.ELUDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker.class, org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidDerivative.class, org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative.class, org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative.class,
@ -433,6 +433,17 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.transforms.gradient.SELUDerivative.class, org.nd4j.linalg.api.ops.impl.transforms.gradient.SELUDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative.class, org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative.class, org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftPlusBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.ThresholdReluBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp.class, org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative.class, org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.pairwise.BinaryMinimalRelativeError.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.BinaryMinimalRelativeError.class,

View File

@ -21,6 +21,8 @@ import lombok.Getter;
import lombok.NonNull; import lombok.NonNull;
import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp;
import org.nd4j.linalg.api.ops.impl.transforms.same.Cube; import org.nd4j.linalg.api.ops.impl.transforms.same.Cube;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative; import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -42,9 +44,9 @@ public class ActivationCube extends BaseActivationFunction {
@Override @Override
public Pair<INDArray, INDArray> backprop(@NonNull INDArray in, @NonNull INDArray epsilon) { public Pair<INDArray, INDArray> backprop(@NonNull INDArray in, @NonNull INDArray epsilon) {
assertShape(in, epsilon); assertShape(in, epsilon);
INDArray dLdz = Nd4j.getExecutioner().exec(new CubeDerivative(in)); Nd4j.getExecutioner().execAndReturn(new CubeBp(in, epsilon, in));
dLdz.muli(epsilon);
return new Pair<>(dLdz, null); return new Pair<>(in, null);
} }
@Override @Override
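Annotation: the rewritten backprop hands `in` to CubeBp as both the input and the output buffer, so the gradient is computed in place and the same array is returned; the other activation classes below are switched to the same pattern. Assuming CubeBp computes the standard cube derivative, a minimal C++ sketch of the element-wise operation it performs (dL/dx = 3x^2 * dL/dy, written back over x):

```cpp
#include <cstddef>
#include <vector>

// In-place cube backprop: overwrite x with 3 * x^2 * dL/dy.
void cubeBpInPlace(std::vector<double>& x, const std::vector<double>& eps) {
    for (std::size_t i = 0; i < x.size(); ++i)
        x[i] = 3.0 * x[i] * x[i] * eps[i];
}
```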

View File

@ -18,11 +18,11 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.Getter; import lombok.Getter;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.strict.ELU; import org.nd4j.linalg.api.ops.impl.transforms.strict.ELU;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.ELUDerivative;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing; import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Conditions; import org.nd4j.linalg.indexing.conditions.Conditions;
@ -57,7 +57,7 @@ public class ActivationELU extends BaseActivationFunction {
public INDArray getActivation(INDArray in, boolean training) { public INDArray getActivation(INDArray in, boolean training) {
// no support in ELU native to override alpha // no support in ELU native to override alpha
if (this.alpha != 1.00) { if (this.alpha != 1.00) {
INDArray alphaMultiple = Nd4j.getExecutioner().exec(new ELU(in.dup())); INDArray alphaMultiple = Nd4j.getExecutioner().exec(new ELU(in.dup()))[0];
alphaMultiple.muli(alpha); alphaMultiple.muli(alpha);
BooleanIndexing.replaceWhere(in, alphaMultiple, Conditions.lessThan(0)); BooleanIndexing.replaceWhere(in, alphaMultiple, Conditions.lessThan(0));
} else { } else {
@ -74,21 +74,8 @@ public class ActivationELU extends BaseActivationFunction {
@Override @Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) { public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
assertShape(in, epsilon); assertShape(in, epsilon);
// no support in ELU native to override alpha Nd4j.getExecutioner().execAndReturn(new EluBp(in, epsilon, in));
if (alpha != 1.00) { return new Pair<>(in, null);
INDArray dLdz = Nd4j.getExecutioner().exec(new ELUDerivative(in.dup()));
dLdz.muli(alpha);
BooleanIndexing.replaceWhere(dLdz, 1, Conditions.equals(alpha));
dLdz.muli(epsilon);
return new Pair<>(dLdz, null);
}
else {
INDArray dLdz = Nd4j.getExecutioner().exec(new ELUDerivative(in));
dLdz.muli(epsilon);
return new Pair<>(dLdz, null);
}
} }
@Override @Override

View File

@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.Getter; import lombok.Getter;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -41,9 +43,9 @@ public class ActivationHardSigmoid extends BaseActivationFunction {
@Override @Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) { public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
assertShape(in, epsilon); assertShape(in, epsilon);
INDArray dLdz = Nd4j.getExecutioner().exec(new HardSigmoidDerivative(in)); Nd4j.getExecutioner().execAndReturn(new HardSigmoidBp(in, epsilon, in));
dLdz.muli(epsilon);
return new Pair<>(dLdz, null); return new Pair<>(in, null);
} }
@Override @Override

View File

@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.Getter; import lombok.Getter;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhBp;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -43,9 +45,10 @@ public class ActivationHardTanH extends BaseActivationFunction {
@Override @Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) { public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
assertShape(in, epsilon); assertShape(in, epsilon);
INDArray dLdz = Nd4j.getExecutioner().exec(new HardTanhDerivative(in));
dLdz.muli(epsilon); Nd4j.getExecutioner().execAndReturn(new HardTanhBp(in, epsilon, in));
return new Pair<>(dLdz, null);
return new Pair<>(in, null);
} }
@Override @Override

View File

@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.Getter; import lombok.Getter;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative; import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.activations.BaseActivationFunction;
@ -54,9 +56,10 @@ public class ActivationLReLU extends BaseActivationFunction {
@Override @Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) { public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
assertShape(in, epsilon); assertShape(in, epsilon);
INDArray dLdz = Nd4j.getExecutioner().exec(new LeakyReLUDerivative(in, alpha));
dLdz.muli(epsilon); Nd4j.getExecutioner().execAndReturn(new LeakyReLUBp(in, epsilon, in, alpha));
return new Pair<>(dLdz, null);
return new Pair<>(in, null);
} }
@Override @Override

View File

@ -63,7 +63,7 @@ public class ActivationRReLU extends BaseActivationFunction {
public INDArray getActivation(INDArray in, boolean training) { public INDArray getActivation(INDArray in, boolean training) {
if (training) { if (training) {
try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) { try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
this.alpha = Nd4j.rand(in.shape(), l, u, Nd4j.getRandom()); this.alpha = Nd4j.rand(l, u, Nd4j.getRandom(), in.shape());
} }
INDArray inTimesAlpha = in.mul(alpha); INDArray inTimesAlpha = in.mul(alpha);
BooleanIndexing.replaceWhere(in, inTimesAlpha, Conditions.lessThan(0)); BooleanIndexing.replaceWhere(in, inTimesAlpha, Conditions.lessThan(0));

View File

@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.Getter; import lombok.Getter;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RationalTanhBp;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -48,9 +50,10 @@ public class ActivationRationalTanh extends BaseActivationFunction {
@Override @Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) { public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
assertShape(in, epsilon); assertShape(in, epsilon);
INDArray dLdz = Nd4j.getExecutioner().exec(new RationalTanhDerivative(in));
dLdz.muli(epsilon); Nd4j.getExecutioner().execAndReturn(new RationalTanhBp(in, epsilon, in));
return new Pair<>(dLdz, null);
return new Pair<>(in, null);
} }
@Override @Override

View File

@ -20,8 +20,10 @@ import lombok.EqualsAndHashCode;
import lombok.Getter; import lombok.Getter;
import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.scalar.Relu6; import org.nd4j.linalg.api.ops.impl.scalar.Relu6;
import org.nd4j.linalg.api.ops.impl.scalar.Step; import org.nd4j.linalg.api.ops.impl.scalar.Step;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.Relu6Derivative;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
@ -41,9 +43,10 @@ public class ActivationReLU6 extends BaseActivationFunction {
@Override @Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) { public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
assertShape(in, epsilon); assertShape(in, epsilon);
INDArray dLdz = Nd4j.getExecutioner().exec(new Step(in));
dLdz.muli(epsilon); Nd4j.getExecutioner().execAndReturn(new Relu6Derivative(in, epsilon, in));
return new Pair<>(dLdz, null);
return new Pair<>(in, null);
} }
@Override @Override

View File

@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.Getter; import lombok.Getter;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.RectifiedTanhBp;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -45,9 +47,10 @@ public class ActivationRectifiedTanh extends BaseActivationFunction {
@Override @Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) { public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
assertShape(in, epsilon); assertShape(in, epsilon);
INDArray dLdz = Nd4j.getExecutioner().exec(new RectifiedTanhDerivative(in));
dLdz.muli(epsilon); Nd4j.getExecutioner().execAndReturn(new RectifiedTanhBp(in, epsilon, in));
return new Pair<>(dLdz, null);
return new Pair<>(in, null);
} }
@Override @Override

View File

@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.Getter; import lombok.Getter;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SeluBp;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -41,9 +43,10 @@ public class ActivationSELU extends BaseActivationFunction {
@Override @Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) { public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
assertShape(in, epsilon); assertShape(in, epsilon);
INDArray dLdz = Nd4j.getExecutioner().exec(new SELUDerivative(in));
dLdz.muli(epsilon); Nd4j.getExecutioner().execAndReturn(new SeluBp(in, epsilon, in));
return new Pair<>(dLdz, null);
return new Pair<>(in, null);
} }
@Override @Override

View File

@ -18,7 +18,8 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.Getter; import lombok.Getter;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SigmoidDerivative; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -41,9 +42,10 @@ public class ActivationSigmoid extends BaseActivationFunction {
@Override @Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) { public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
assertShape(in, epsilon); assertShape(in, epsilon);
INDArray dLdz = Nd4j.getExecutioner().exec(new SigmoidDerivative(in));
dLdz.muli(epsilon); Nd4j.getExecutioner().execAndReturn(new SigmoidDerivative(in, epsilon, in));
return new Pair<>(dLdz, null);
return new Pair<>(in, null);
} }
@Override @Override

View File

@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.Getter; import lombok.Getter;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftPlusBp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid; import org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.activations.BaseActivationFunction;
@ -41,9 +43,10 @@ public class ActivationSoftPlus extends BaseActivationFunction {
@Override @Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) { public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
assertShape(in, epsilon); assertShape(in, epsilon);
INDArray dLdz = Nd4j.getExecutioner().exec(new Sigmoid(in));
dLdz.muli(epsilon); Nd4j.getExecutioner().execAndReturn(new SoftPlusBp(in, epsilon, in));
return new Pair<>(dLdz, null);
return new Pair<>(in, null);
} }
@Override @Override

View File

@ -18,6 +18,8 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.Getter; import lombok.Getter;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignBp;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -41,9 +43,10 @@ public class ActivationSoftSign extends BaseActivationFunction {
@Override @Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) { public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
assertShape(in, epsilon); assertShape(in, epsilon);
INDArray dLdz = Nd4j.getExecutioner().exec(new SoftSignDerivative(in));
dLdz.muli(epsilon); Nd4j.getExecutioner().execAndReturn(new SoftSignBp(in, epsilon, in));
return new Pair<>(dLdz, null);
return new Pair<>(in, null);
} }
@Override @Override

View File

@ -21,7 +21,9 @@ import lombok.Getter;
import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
@ -42,10 +44,10 @@ public class ActivationSoftmax extends BaseActivationFunction {
@Override @Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) { public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
assertShape(in, epsilon); assertShape(in, epsilon);
INDArray out = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(in, in.ulike()))[0];
INDArray x = out.mul(epsilon).sum(1); Nd4j.getExecutioner().execAndReturn(new SoftmaxBp(in, epsilon, in, -1));
INDArray dLdz = out.mul(epsilon.subColumnVector(x));
return new Pair<>(dLdz, null); return new Pair<>(in, null);
} }
@Override @Override
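Annotation: the removed lines computed softmax backprop explicitly. With y = softmax(x), dL/dx_i = y_i * (dL/dy_i - sum_j y_j * dL/dy_j), which is what the removed `out.mul(epsilon.subColumnVector(x))` evaluated row by row (with `x` there holding the per-row sum); the SoftmaxBp op is expected to produce the same result in place. A plain C++ sketch of that per-row formula:

```cpp
#include <cstddef>
#include <vector>

// Softmax backprop for one row: dL/dx_i = y_i * (dL/dy_i - sum_j y_j * dL/dy_j).
std::vector<double> softmaxBpRow(const std::vector<double>& y, const std::vector<double>& eps) {
    double dot = 0.0;
    for (std::size_t j = 0; j < y.size(); ++j)
        dot += y[j] * eps[j];
    std::vector<double> dLdx(y.size());
    for (std::size_t i = 0; i < y.size(); ++i)
        dLdx[i] = y[i] * (eps[i] - dot);
    return dLdx;
}
```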

View File

@ -18,11 +18,11 @@ package org.nd4j.linalg.activations.impl;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import lombok.Getter; import lombok.Getter;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative;
import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction; import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh; import org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh;
import org.nd4j.linalg.api.ops.impl.transforms.strict.TanhDerivative;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
/** /**
@ -41,9 +41,10 @@ public class ActivationTanH extends BaseActivationFunction {
@Override @Override
public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) { public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
assertShape(in, epsilon); assertShape(in, epsilon);
INDArray dLdz = Nd4j.getExecutioner().exec(new TanhDerivative(in));
dLdz.muli(epsilon); Nd4j.getExecutioner().execAndReturn(new TanhDerivative(in, epsilon, in));
return new Pair<>(dLdz, null);
return new Pair<>(in, null);
} }
@Override @Override

View File

@ -2836,156 +2836,78 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return putScalar(i, element.getDouble(0)); return putScalar(i, element.getDouble(0));
} }
/**
* In place addition of a column vector
*
* @param columnVector the column vector to add
* @return the result of the addition
*/
@Override @Override
public INDArray diviColumnVector(INDArray columnVector) { public INDArray diviColumnVector(INDArray columnVector) {
validateNumericalArray("diviColumnVector", false); validateNumericalArray("diviColumnVector", false);
return doColumnWise(columnVector, 'd'); return doColumnWise(columnVector, 'd');
} }
/**
* In place addition of a column vector
*
* @param columnVector the column vector to add
* @return the result of the addition
*/
@Override @Override
public INDArray divColumnVector(INDArray columnVector) { public INDArray divColumnVector(INDArray columnVector) {
validateNumericalArray("divColumnVector", false); validateNumericalArray("divColumnVector", false);
return dup().diviColumnVector(columnVector); return dup().diviColumnVector(columnVector);
} }
/**
* In place addition of a column vector
*
* @param rowVector the row vector to add
* @return the result of the addition
*/
@Override @Override
public INDArray diviRowVector(INDArray rowVector) { public INDArray diviRowVector(INDArray rowVector) {
validateNumericalArray("diviRowVector", false); validateNumericalArray("diviRowVector", false);
return doRowWise(rowVector, 'd'); return doRowWise(rowVector, 'd');
} }
/**
* In place addition of a column vector
*
* @param rowVector the row vector to add
* @return the result of the addition
*/
@Override @Override
public INDArray divRowVector(INDArray rowVector) { public INDArray divRowVector(INDArray rowVector) {
validateNumericalArray("divRowVector", false); validateNumericalArray("divRowVector", false);
return dup().diviRowVector(rowVector); return dup().diviRowVector(rowVector);
} }
/**
* In place addition of a column vector
*
* @param columnVector the column vector to add
* @return the result of the addition
*/
@Override @Override
public INDArray muliColumnVector(INDArray columnVector) { public INDArray muliColumnVector(INDArray columnVector) {
validateNumericalArray("muliColumnVector", false); validateNumericalArray("muliColumnVector", false);
return doColumnWise(columnVector, 'm'); return doColumnWise(columnVector, 'm');
} }
/**
* In place addition of a column vector
*
* @param columnVector the column vector to add
* @return the result of the addition
*/
@Override @Override
public INDArray mulColumnVector(INDArray columnVector) { public INDArray mulColumnVector(INDArray columnVector) {
validateNumericalArray("mulColumnVector", false); validateNumericalArray("mulColumnVector", false);
return dup().muliColumnVector(columnVector); return dup().muliColumnVector(columnVector);
} }
/**
* In place addition of a column vector
*
* @param rowVector the row vector to add
* @return the result of the addition
*/
@Override @Override
public INDArray muliRowVector(INDArray rowVector) { public INDArray muliRowVector(INDArray rowVector) {
validateNumericalArray("muliRowVector", false); validateNumericalArray("muliRowVector", false);
return doRowWise(rowVector, 'm'); return doRowWise(rowVector, 'm');
} }
/**
* In place addition of a column vector
*
* @param rowVector the row vector to add
* @return the result of the addition
*/
@Override @Override
public INDArray mulRowVector(INDArray rowVector) { public INDArray mulRowVector(INDArray rowVector) {
validateNumericalArray("mulRowVector", false); validateNumericalArray("mulRowVector", false);
return dup().muliRowVector(rowVector); return dup().muliRowVector(rowVector);
} }
/**
* In place addition of a column vector
*
* @param columnVector the column vector to add
* @return the result of the addition
*/
@Override @Override
public INDArray subiColumnVector(INDArray columnVector) { public INDArray subiColumnVector(INDArray columnVector) {
validateNumericalArray("subiColumnVector", false); validateNumericalArray("subiColumnVector", false);
return doColumnWise(columnVector, 's'); return doColumnWise(columnVector, 's');
} }
/**
* In place addition of a column vector
*
* @param columnVector the column vector to add
* @return the result of the addition
*/
@Override @Override
public INDArray subColumnVector(INDArray columnVector) { public INDArray subColumnVector(INDArray columnVector) {
validateNumericalArray("subColumnVector", false); validateNumericalArray("subColumnVector", false);
return dup().subiColumnVector(columnVector); return dup().subiColumnVector(columnVector);
} }
/**
* In place addition of a column vector
*
* @param rowVector the row vector to add
* @return the result of the addition
*/
@Override @Override
public INDArray subiRowVector(INDArray rowVector) { public INDArray subiRowVector(INDArray rowVector) {
validateNumericalArray("subiRowVector", false); validateNumericalArray("subiRowVector", false);
return doRowWise(rowVector, 's'); return doRowWise(rowVector, 's');
} }
/**
* In place addition of a column vector
*
* @param rowVector the row vector to add
* @return the result of the addition
*/
@Override @Override
public INDArray subRowVector(INDArray rowVector) { public INDArray subRowVector(INDArray rowVector) {
validateNumericalArray("subRowVector", false); validateNumericalArray("subRowVector", false);
return dup().subiRowVector(rowVector); return dup().subiRowVector(rowVector);
} }
/**
* In place addition of a column vector
*
* @param columnVector the column vector to add
* @return the result of the addition
*/
@Override @Override
public INDArray addiColumnVector(INDArray columnVector) { public INDArray addiColumnVector(INDArray columnVector) {
validateNumericalArray("addiColumnVector", false); validateNumericalArray("addiColumnVector", false);
@ -2997,24 +2919,12 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return doColumnWise(columnVector, 'p'); return doColumnWise(columnVector, 'p');
} }
/**
* In place addition of a column vector
*
* @param columnVector the column vector to add
* @return the result of the addition
*/
@Override @Override
public INDArray addColumnVector(INDArray columnVector) { public INDArray addColumnVector(INDArray columnVector) {
validateNumericalArray("addColumnVector", false); validateNumericalArray("addColumnVector", false);
return dup().addiColumnVector(columnVector); return dup().addiColumnVector(columnVector);
} }
/**
* In place addition of a column vector
*
* @param rowVector the row vector to add
* @return the result of the addition
*/
@Override @Override
public INDArray addiRowVector(INDArray rowVector) { public INDArray addiRowVector(INDArray rowVector) {
validateNumericalArray("addiRowVector", false); validateNumericalArray("addiRowVector", false);
@ -3027,47 +2937,22 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return doRowWise(rowVector, 'p'); return doRowWise(rowVector, 'p');
} }
/**
* In place addition of a column vector
*
* @param rowVector the row vector to add
* @return the result of the addition
*/
@Override @Override
public INDArray addRowVector(INDArray rowVector) { public INDArray addRowVector(INDArray rowVector) {
validateNumericalArray("addRowVector", false); validateNumericalArray("addRowVector", false);
return dup().addiRowVector(rowVector); return dup().addiRowVector(rowVector);
} }
/**
* Perform a copy matrix multiplication
*
* @param other the other matrix to perform matrix multiply with
* @return the result of the matrix multiplication
*/
@Override @Override
public INDArray mmul(INDArray other, INDArray result, MMulTranspose mMulTranspose) { public INDArray mmul(INDArray other, INDArray result, MMulTranspose mMulTranspose) {
return mMulTranspose.exec(this, other, result); return mMulTranspose.exec(this, other, result);
} }
/**
* Perform a copy matrix multiplication
*
* @param other the other matrix to perform matrix multiply with
* @return the result of the matrix multiplication
*/
@Override @Override
public INDArray mmul(INDArray other, MMulTranspose mMulTranspose) { public INDArray mmul(INDArray other, MMulTranspose mMulTranspose) {
return mMulTranspose.exec(this, other, null); return mMulTranspose.exec(this, other, null);
} }
/**
* Perform a copy matrix multiplication
*
* @param other the other matrix to perform matrix multiply with
* @return the result of the matrix multiplication
*/
@Override @Override
public INDArray mmul(INDArray other) { public INDArray mmul(INDArray other) {
Preconditions.checkState(this.dataType() == other.dataType(), "Matrix multiplication: arrays must have same dtype: %s vs. %s", this.dataType(), other.dataType()); Preconditions.checkState(this.dataType() == other.dataType(), "Matrix multiplication: arrays must have same dtype: %s vs. %s", this.dataType(), other.dataType());
@ -3105,8 +2990,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
if(!isVectorOrScalar()) { if(!isVectorOrScalar()) {
throw new ND4JIllegalStateException("Unable to create a 1d array from a non vector! Shape: " + Shape.shapeToStringShort(this)); throw new ND4JIllegalStateException("Unable to create a 1d array from a non vector! Shape: " + Shape.shapeToStringShort(this));
} }
return dup().data().asDouble(); return dup().data().asDouble();
} }
@ -3115,7 +2998,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
if(!isVectorOrScalar()) { if(!isVectorOrScalar()) {
throw new ND4JIllegalStateException("Unable to create a 1d array from a non vector! Shape: " + Shape.shapeToStringShort(this)); throw new ND4JIllegalStateException("Unable to create a 1d array from a non vector! Shape: " + Shape.shapeToStringShort(this));
} }
return dup().data().asFloat(); return dup().data().asFloat();
} }
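For reference, a minimal sketch of the mmul overload that takes an MMulTranspose, as retained above (shapes and variable names are illustrative assumptions, and the lombok-style MMulTranspose builder is assumed; not part of this changeset):

// Multiply a^T by b without materialising a transposed copy of a.
INDArray a = Nd4j.rand(2, 3);
INDArray b = Nd4j.rand(2, 4);
MMulTranspose t = MMulTranspose.builder().transposeA(true).build();
INDArray c = a.mmul(b, t);   // [3,2] x [2,4] -> result shape [3,4]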

View File

@ -1123,14 +1123,6 @@ public class BaseSparseNDArrayCOO extends BaseSparseNDArray {
return null; return null;
} }
/**
* Perform an copy matrix multiplication
*
* @param other the other matrix to perform matrix multiply with
* @param result the result ndarray
* @param mMulTranspose the transpose status of each array
* @return the result of the matrix multiplication
*/
@Override @Override
public INDArray mmul(INDArray other, INDArray result, MMulTranspose mMulTranspose) { public INDArray mmul(INDArray other, INDArray result, MMulTranspose mMulTranspose) {
return null; return null;

View File

@ -1,4 +1,4 @@
/******************************************************************************* /* *****************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc. * Copyright (c) 2015-2018 Skymind, Inc.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
@ -16,9 +16,6 @@
package org.nd4j.linalg.api.ndarray; package org.nd4j.linalg.api.ndarray;
import static org.nd4j.linalg.factory.Nd4j.compressDebug;
import static org.nd4j.linalg.factory.Nd4j.preventUnpack;
import com.google.flatbuffers.FlatBufferBuilder; import com.google.flatbuffers.FlatBufferBuilder;
import lombok.NonNull; import lombok.NonNull;
import org.nd4j.linalg.api.blas.params.MMulTranspose; import org.nd4j.linalg.api.blas.params.MMulTranspose;
@ -52,6 +49,7 @@ public interface INDArray extends Serializable, AutoCloseable {
*/ */
DataBuffer shapeInfoDataBuffer(); DataBuffer shapeInfoDataBuffer();
// TODO: Unused untested method.
/** /**
* Sparse info * Sparse info
* @return Sparse info. * @return Sparse info.
@ -110,12 +108,13 @@ public interface INDArray extends Serializable, AutoCloseable {
*/ */
int elementWiseStride(); int elementWiseStride();
// TODO: Unused untested method.
/** /**
* Get a double at the given linear offset unsafe, without checks. * Get a double at the given linear offset unsafe, without checks.
* @param offset the offset to get at * @param offset the offset to get at
* @return double value at offset * @return double value at offset
*/ */
double getDoubleUnsafe(long offset); //TODO: consider deleting. double getDoubleUnsafe(long offset);
/** /**
* Get string value at given index. * Get string value at given index.
@ -124,13 +123,14 @@ public interface INDArray extends Serializable, AutoCloseable {
*/ */
String getString(long index); String getString(long index);
// TODO: Unused untested method.
/** /**
* Insert a scalar at the given linear offset * Insert a scalar at the given linear offset
* @param offset the offset to insert at * @param offset the offset to insert at
* @param value the value to insert * @param value the value to insert
* @return this * @return this
*/ */
INDArray putScalarUnsafe(long offset, double value); //TODO: consider deleting. INDArray putScalarUnsafe(long offset, double value);
/** /**
* Returns the number of possible vectors for a given dimension * Returns the number of possible vectors for a given dimension
@ -190,6 +190,7 @@ public interface INDArray extends Serializable, AutoCloseable {
*/ */
INDArray assign(INDArray arr); INDArray assign(INDArray arr);
// TODO: Unused untested method.
/** /**
* Assign all elements from given ndarray that are matching given condition, * Assign all elements from given ndarray that are matching given condition,
* ndarray to this ndarray * ndarray to this ndarray
@ -553,7 +554,7 @@ public interface INDArray extends Serializable, AutoCloseable {
* *
* @param n the number to subtract by * @param n the number to subtract by
* @param result the result ndarray * @param result the result ndarray
* @return * @return the result ndarray
*/ */
INDArray rsub(Number n, INDArray result); INDArray rsub(Number n, INDArray result);
@ -1041,7 +1042,6 @@ public interface INDArray extends Serializable, AutoCloseable {
*/ */
INDArray divRowVector(INDArray rowVector); INDArray divRowVector(INDArray rowVector);
/** /**
* In place reverse division of a column vector * In place reverse division of a column vector
* *
@ -1066,6 +1066,7 @@ public interface INDArray extends Serializable, AutoCloseable {
*/ */
INDArray rdiviRowVector(INDArray rowVector); INDArray rdiviRowVector(INDArray rowVector);
//TODO: unused / untested method.
/** /**
* Reverse division of a column vector (copy) * Reverse division of a column vector (copy)
* *
@ -1074,7 +1075,6 @@ public interface INDArray extends Serializable, AutoCloseable {
*/ */
INDArray rdivRowVector(INDArray rowVector); INDArray rdivRowVector(INDArray rowVector);
/** /**
* In place multiplication of a column vector * In place multiplication of a column vector
* *
@ -1107,7 +1107,6 @@ public interface INDArray extends Serializable, AutoCloseable {
*/ */
INDArray mulRowVector(INDArray rowVector); INDArray mulRowVector(INDArray rowVector);
/** /**
* In place reverse subtraction of a column vector * In place reverse subtraction of a column vector
* *
@ -1132,6 +1131,7 @@ public interface INDArray extends Serializable, AutoCloseable {
*/ */
INDArray rsubiRowVector(INDArray rowVector); INDArray rsubiRowVector(INDArray rowVector);
//TODO: unused / untested method.
/** /**
* Reverse subtraction of a row vector (copy) * Reverse subtraction of a row vector (copy)
* *
@ -1180,7 +1180,6 @@ public interface INDArray extends Serializable, AutoCloseable {
*/ */
INDArray addiColumnVector(INDArray columnVector); INDArray addiColumnVector(INDArray columnVector);
/** /**
* In place assignment of a column vector * In place assignment of a column vector
* *
@ -1221,6 +1220,12 @@ public interface INDArray extends Serializable, AutoCloseable {
*/ */
INDArray addRowVector(INDArray rowVector); INDArray addRowVector(INDArray rowVector);
/**
* Perform a copy matrix multiplication
*
* @param other the other matrix to perform matrix multiply with
* @return the result of the matrix multiplication
*/
INDArray mmul(INDArray other, MMulTranspose mMulTranspose); INDArray mmul(INDArray other, MMulTranspose mMulTranspose);
/** /**
@ -1231,8 +1236,6 @@ public interface INDArray extends Serializable, AutoCloseable {
*/ */
INDArray mmul(INDArray other); INDArray mmul(INDArray other);
/** /**
* Convert this ndarray to a 2d double matrix. * Convert this ndarray to a 2d double matrix.
* Note that THIS SHOULD NOT BE USED FOR SPEED. * Note that THIS SHOULD NOT BE USED FOR SPEED.
@ -1283,6 +1286,14 @@ public interface INDArray extends Serializable, AutoCloseable {
*/ */
int[] toIntVector(); int[] toIntVector();
/**
* Convert this ndarray to a 1d long matrix.
* Note that THIS SHOULD NOT BE USED FOR SPEED.
* This is mainly used for integrations with other libraries.
* Due to nd4j's off heap nature, moving data on heap is very expensive
* and should not be used if possible.
* @return a copy of this array as a 1d long array
*/
long[] toLongVector(); long[] toLongVector();
/** /**

View File

@ -16,17 +16,13 @@
package org.nd4j.linalg.api.ops.impl.scalar; package org.nd4j.linalg.api.ops.impl.scalar;
import org.nd4j.autodiff.samediff.SDVariable; import java.util.Collections;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.graph.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseScalarOp;
import org.nd4j.linalg.api.ops.BaseTransformOp;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseScalarOp;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.GraphDef;
@ -108,8 +104,7 @@ public class LeakyReLU extends BaseScalarOp {
@Override @Override
public List<SDVariable> doDiff(List<SDVariable> i_v) { public List<SDVariable> doDiff(List<SDVariable> i_v) {
SDVariable ret = f().leakyReluDerivative(arg(), alpha).mul(i_v.get(0)); return Collections.singletonList(f().leakyReluBp(arg(), i_v.get(0), alpha));
return Arrays.asList(ret);
} }
@Override @Override

View File

@ -75,15 +75,8 @@ public class RectifiedLinear extends BaseScalarOp {
return "Relu"; return "Relu";
} }
@Override @Override
public List<SDVariable> doDiff(List<SDVariable> i_v) { public List<SDVariable> doDiff(List<SDVariable> i_v) {
if(scalarValue.getDouble(0) == 0.0){ return Collections.singletonList(f().thresholdReluBp(arg(), i_v.get(0), scalarValue.getDouble(0)));
return Collections.singletonList(f().reluDerivative(arg(), i_v.get(0)));
} else {
SDVariable step = new Step(sameDiff, arg(), false, scalarValue.getDouble(0)).outputVariables()[0];
SDVariable ret = step.mul(i_v.get(0));
return Collections.singletonList(ret);
}
} }
} }

View File

@ -3,10 +3,10 @@ package org.nd4j.linalg.api.ops.impl.scalar;
import lombok.NonNull; import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.shade.guava.base.Preconditions;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
@ -30,7 +30,8 @@ public class RectifiedLinearDerivative extends DynamicCustomOp {
@Override @Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) { public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
Preconditions.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes); Preconditions
.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes); Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
return Collections.singletonList(dataTypes.get(0)); return Collections.singletonList(dataTypes.get(0));

View File

@ -0,0 +1,77 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import java.util.Collections;
import java.util.List;
import lombok.Getter;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear;
/**
* Threshold ReLU op. The general case of {@link RectifiedLinear}.
*/
public class ThresholdRelu extends DynamicCustomOp {
@Getter
private double cutoff = 0.0;
public ThresholdRelu(){ }
public ThresholdRelu(SameDiff sd, SDVariable input, boolean inPlace, double cutoff){
super(sd, new SDVariable[]{input}, inPlace);
this.cutoff = cutoff;
addTArgument(cutoff);
}
public ThresholdRelu(SameDiff sd, SDVariable input, double cutoff){
super(sd, new SDVariable[]{input});
this.cutoff = cutoff;
addTArgument(cutoff);
}
public ThresholdRelu(@NonNull INDArray input, INDArray output, double cutoff){
super(new INDArray[]{input}, wrapOrNull(output));
this.cutoff = cutoff;
addTArgument(cutoff);
}
@Override
public String opName(){
return "thresholdedrelu";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
Preconditions
.checkArgument(dataTypes != null && dataTypes.size() == 1, "Expected exactly 1 input datatype, got %s", dataTypes);
Preconditions.checkArgument(dataTypes.get(0).isFPType(), "Input datatype must be floating point, got %s", dataTypes);
return Collections.singletonList(dataTypes.get(0));
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
return Collections.singletonList(f().thresholdReluBp(arg(), f1.get(0), cutoff));
}
}
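A minimal eager-mode sketch of the new op, using the Nd4j.exec(...)[0] pattern seen elsewhere in this changeset (the input values and the 0.2 cutoff are illustrative assumptions):

// thresholdedrelu: f(x) = x if x > cutoff, 0 otherwise.
INDArray in = Nd4j.rand(3, 4).subi(0.5);                         // values above and below the cutoff
INDArray out = Nd4j.exec(new ThresholdRelu(in, null, 0.2))[0];   // null output: shape is calculated by the op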

View File

@ -0,0 +1,62 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.linalg.api.ops.impl.transforms.gradient;
import java.util.Collections;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
/**
* Cube backpropagation op - dL/dIn from in and dL/dOut
*/
public class CubeBp extends DynamicCustomOp {
public CubeBp(){ }
public CubeBp(SameDiff sd, SDVariable input, SDVariable gradient){
super(sd, new SDVariable[]{input, gradient});
}
public CubeBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
super(new INDArray[]{input, gradient}, wrapOrNull(output));
}
@Override
public String opName(){
return "cube_bp";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
Preconditions
.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
return Collections.singletonList(dataTypes.get(0));
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
throw new UnsupportedOperationException("Not supported");
}
}
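A minimal eager-mode sketch of the new backprop op (array names and shapes are illustrative assumptions):

// cube_bp: dL/dIn = 3 * x^2 * dL/dOut, elementwise.
INDArray x = Nd4j.rand(2, 3);
INDArray dLdOut = Nd4j.rand(2, 3);
INDArray dLdIn = Nd4j.exec(new CubeBp(x, dLdOut, null))[0];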

View File

@ -27,7 +27,11 @@ import java.util.List;
/** /**
* Cube derivative, e.g. 3x^2 * Cube derivative, e.g. 3x^2
*
* @deprecated Use {@link CubeBp}
*
*/ */
@Deprecated
public class CubeDerivative extends BaseTransformStrictOp { public class CubeDerivative extends BaseTransformStrictOp {
public CubeDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { public CubeDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
super(sameDiff, i_v, inPlace); super(sameDiff, i_v, inPlace);

View File

@ -1,84 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.gradient;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformOp;
import org.nd4j.linalg.api.ops.BaseTransformStrictOp;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
/**
*
* Derivative of ELU: Exponential Linear Unit (alpha=1.0)<br>
* Introduced in paper:<br>
* Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs)<br>
* Djork-Arné Clevert, Thomas Unterthiner, Sepp Hochreiter (2015)<br>
* <a href="http://arxiv.org/abs/1511.07289">http://arxiv.org/abs/1511.07289</a>
*
* @author Alex Black
*/
public class ELUDerivative extends BaseTransformStrictOp {
public ELUDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
super(sameDiff, i_v, inPlace);
}
public ELUDerivative() {
}
public ELUDerivative(INDArray x, INDArray z) {
super(x, z);
}
public ELUDerivative(INDArray x) {
super(x);
}
@Override
public int opNum() {
return 3;
}
@Override
public String opName() {
return "eluderivative";
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
}
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
SDVariable ret = sameDiff.zerosLike(arg());
return Collections.singletonList(ret);
}
}

View File

@ -0,0 +1,67 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.linalg.api.ops.impl.transforms.gradient;
import java.util.Collections;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
/**
* ELU backpropagation op - dL/dIn from in and dL/dOut
*/
public class EluBp extends DynamicCustomOp {
public EluBp(){ }
public EluBp(SameDiff sd, SDVariable input, SDVariable gradient){
super(sd, new SDVariable[]{input, gradient});
}
public EluBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output) {
this(input, gradient, output, 1.0);
}
public EluBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output, double alpha){
super(new INDArray[]{input, gradient}, wrapOrNull(output));
addTArgument(alpha);
}
@Override
public String opName(){
return "elu_bp";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
Preconditions
.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
return Collections.singletonList(dataTypes.get(0));
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
throw new UnsupportedOperationException("Not supported");
}
}
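A minimal sketch showing the two eager constructors above, with the default alpha of 1.0 and an explicit alpha (values are illustrative assumptions):

INDArray x = Nd4j.rand(2, 3).subi(0.5);
INDArray dLdOut = Nd4j.rand(2, 3);
INDArray dLdIn1 = Nd4j.exec(new EluBp(x, dLdOut, null))[0];        // alpha = 1.0
INDArray dLdIn2 = Nd4j.exec(new EluBp(x, dLdOut, null, 0.5))[0];   // explicit alpha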

View File

@ -0,0 +1,62 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.linalg.api.ops.impl.transforms.gradient;
import java.util.Collections;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
/**
* Hard Sigmoid backpropagation op - dL/dIn from in and dL/dOut
*/
public class HardSigmoidBp extends DynamicCustomOp {
public HardSigmoidBp(){ }
public HardSigmoidBp(SameDiff sd, SDVariable input, SDVariable gradient){
super(sd, new SDVariable[]{input, gradient});
}
public HardSigmoidBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
super(new INDArray[]{input, gradient}, wrapOrNull(output));
}
@Override
public String opName(){
return "hardsigmoid_bp";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
Preconditions
.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
return Collections.singletonList(dataTypes.get(0));
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
throw new UnsupportedOperationException("Not supported");
}
}

View File

@ -29,8 +29,11 @@ import java.util.List;
/** /**
* HardSigmoid derivative * HardSigmoid derivative
* *
* @deprecated Use {@link HardSigmoidBp}
*
* @author raver119@gmail.com * @author raver119@gmail.com
*/ */
@Deprecated
public class HardSigmoidDerivative extends BaseTransformStrictOp { public class HardSigmoidDerivative extends BaseTransformStrictOp {
public HardSigmoidDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { public HardSigmoidDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
super(sameDiff, i_v, inPlace); super(sameDiff, i_v, inPlace);

View File

@ -0,0 +1,62 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.linalg.api.ops.impl.transforms.gradient;
import java.util.Collections;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
/**
* Hard Tanh backpropagation op - dL/dIn from in and dL/dOut
*/
public class HardTanhBp extends DynamicCustomOp {
public HardTanhBp(){ }
public HardTanhBp(SameDiff sd, SDVariable input, SDVariable gradient){
super(sd, new SDVariable[]{input, gradient});
}
public HardTanhBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
super(new INDArray[]{input, gradient}, wrapOrNull(output));
}
@Override
public String opName(){
return "hardtanh_bp";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
Preconditions
.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
return Collections.singletonList(dataTypes.get(0));
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
throw new UnsupportedOperationException("Not supported");
}
}

View File

@ -31,8 +31,11 @@ import java.util.List;
/** /**
* Hard tanh elementwise derivative function * Hard tanh elementwise derivative function
* *
* @deprecated Use {@link HardTanhBp}
*
* @author Adam Gibson * @author Adam Gibson
*/ */
@Deprecated
public class HardTanhDerivative extends BaseTransformStrictOp { public class HardTanhDerivative extends BaseTransformStrictOp {
public HardTanhDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { public HardTanhDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
super(sameDiff, i_v, inPlace); super(sameDiff, i_v, inPlace);

View File

@ -0,0 +1,68 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.linalg.api.ops.impl.transforms.gradient;
import java.util.Collections;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
/**
* LReLU backpropagation op - dL/dIn from in and dL/dOut
*/
public class LeakyReLUBp extends DynamicCustomOp {
public static final double DEFAULT_ALPHA = 0.01;
private double alpha = DEFAULT_ALPHA;
public LeakyReLUBp(){ }
public LeakyReLUBp(SameDiff sd, SDVariable input, SDVariable gradient, double alpha){
super(sd, new SDVariable[]{input, gradient});
this.alpha = alpha;
addTArgument(alpha);
}
public LeakyReLUBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output, double alpha){
super(new INDArray[]{input, gradient}, wrapOrNull(output));
this.alpha = alpha;
addTArgument(alpha);
}
@Override
public String opName(){
return "lrelu_bp";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
Preconditions
.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
return Collections.singletonList(dataTypes.get(0));
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
throw new UnsupportedOperationException("Not supported");
}
}
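A minimal eager-mode sketch (names and values are illustrative assumptions); the gradient passes through unchanged where x > 0 and is scaled by alpha where x < 0:

INDArray x = Nd4j.rand(2, 3).subi(0.5);
INDArray dLdOut = Nd4j.rand(2, 3);
INDArray dLdIn = Nd4j.exec(new LeakyReLUBp(x, dLdOut, null, LeakyReLUBp.DEFAULT_ALPHA))[0];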

View File

@ -0,0 +1,62 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.linalg.api.ops.impl.transforms.gradient;
import java.util.Collections;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
/**
* Rational Tanh backpropagation op - dL/dIn from in and dL/dOut
*/
public class RationalTanhBp extends DynamicCustomOp {
public RationalTanhBp(){ }
public RationalTanhBp(SameDiff sd, SDVariable input, SDVariable gradient){
super(sd, new SDVariable[]{input, gradient});
}
public RationalTanhBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
super(new INDArray[]{input, gradient}, wrapOrNull(output));
}
@Override
public String opName(){
return "rationaltanh_bp";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
Preconditions
.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
return Collections.singletonList(dataTypes.get(0));
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
throw new UnsupportedOperationException("Not supported");
}
}

View File

@ -31,9 +31,12 @@ import java.util.List;
* Rational Tanh Derivative, as described at https://github.com/deeplearning4j/libnd4j/issues/351 * Rational Tanh Derivative, as described at https://github.com/deeplearning4j/libnd4j/issues/351
* Calculates dOut/dIn given input, not dL/dIn given dL/dOut and input * Calculates dOut/dIn given input, not dL/dIn given dL/dOut and input
* *
* @deprecated Use {@link RationalTanhBp}
*
* @author raver119@gmail.com * @author raver119@gmail.com
* @author AlexDBlack * @author AlexDBlack
*/ */
@Deprecated
public class RationalTanhDerivative extends BaseTransformStrictOp { public class RationalTanhDerivative extends BaseTransformStrictOp {
public RationalTanhDerivative(SameDiff sameDiff, SDVariable in, boolean inPlace) { public RationalTanhDerivative(SameDiff sameDiff, SDVariable in, boolean inPlace) {
super(sameDiff, in, inPlace); super(sameDiff, in, inPlace);

View File

@ -0,0 +1,62 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.linalg.api.ops.impl.transforms.gradient;
import java.util.Collections;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
/**
* Rectified Tanh backpropagation op - dL/dIn from in and dL/dOut
*/
public class RectifiedTanhBp extends DynamicCustomOp {
public RectifiedTanhBp(){ }
public RectifiedTanhBp(SameDiff sd, SDVariable input, SDVariable gradient){
super(sd, new SDVariable[]{input, gradient});
}
public RectifiedTanhBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
super(new INDArray[]{input, gradient}, wrapOrNull(output));
}
@Override
public String opName(){
return "rectifiedtanh_bp";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
Preconditions
.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
return Collections.singletonList(dataTypes.get(0));
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
throw new UnsupportedOperationException("Not supported");
}
}

View File

@ -30,9 +30,12 @@ import java.util.List;
/** /**
* Rectified Tanh Derivative * Rectified Tanh Derivative
* *
* @deprecated Use {@link RectifiedTanhBp}
*
* @author raver119@gmail.com * @author raver119@gmail.com
* @author AlexDBlack * @author AlexDBlack
*/ */
@Deprecated
public class RectifiedTanhDerivative extends BaseTransformStrictOp { public class RectifiedTanhDerivative extends BaseTransformStrictOp {
public RectifiedTanhDerivative(SameDiff sameDiff, SDVariable in, boolean inPlace) { public RectifiedTanhDerivative(SameDiff sameDiff, SDVariable in, boolean inPlace) {
super(sameDiff, in, inPlace); super(sameDiff, in, inPlace);

View File

@ -16,15 +16,18 @@
package org.nd4j.linalg.api.ops.impl.transforms.gradient; package org.nd4j.linalg.api.ops.impl.transforms.gradient;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import org.nd4j.linalg.api.ops.impl.transforms.same.Identity;
/** /**
* Derivative of Rectified linear unit 6, i.e. min(max(input, cutoff), 6), where cutoff can be chosen. * Derivative of Rectified linear unit 6, i.e. min(max(input, cutoff), 6), where cutoff can be chosen.
@ -33,7 +36,9 @@ import java.util.List;
*/ */
public class Relu6Derivative extends DynamicCustomOp { public class Relu6Derivative extends DynamicCustomOp {
private double cutoff = 0.0; private static final double DEFAULT_CUTOFF = 0.0;
private double cutoff = DEFAULT_CUTOFF;
public Relu6Derivative(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, double cutoff) { public Relu6Derivative(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, double cutoff) {
super("relu6_bp", sameDiff, new SDVariable[]{i_v1, i_v2}); super("relu6_bp", sameDiff, new SDVariable[]{i_v1, i_v2});
@ -45,6 +50,16 @@ public class Relu6Derivative extends DynamicCustomOp {
this.extraArgs = new Object[]{cutoff}; this.extraArgs = new Object[]{cutoff};
} }
public Relu6Derivative(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
this(input, gradient, output, DEFAULT_CUTOFF);
}
public Relu6Derivative(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output, double cutoff){
super(new INDArray[]{input, gradient}, wrapOrNull(output));
this.cutoff = cutoff;
this.extraArgs = new Object[]{cutoff};
}
@Override @Override
public int opNum() { public int opNum() {
return 0; return 0;
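The new array-based constructors allow the relu6 backprop to be run eagerly as well; a minimal sketch (values are illustrative assumptions):

// relu6_bp: gradient passes only where cutoff < x < 6.
INDArray x = Nd4j.rand(2, 3).muli(8);
INDArray dLdOut = Nd4j.rand(2, 3);
INDArray dLdIn = Nd4j.exec(new Relu6Derivative(x, dLdOut, null))[0];   // cutoff defaults to 0.0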

View File

@ -31,8 +31,11 @@ import java.util.List;
* *
* https://arxiv.org/pdf/1706.02515.pdf * https://arxiv.org/pdf/1706.02515.pdf
* *
* @deprecated Use {@link SeluBp}
*
* @author raver119@gmail.com * @author raver119@gmail.com
*/ */
@Deprecated
public class SELUDerivative extends BaseTransformStrictOp { public class SELUDerivative extends BaseTransformStrictOp {
private static final double SELU_ALPHA = 1.6732632423543772848170429916717; private static final double SELU_ALPHA = 1.6732632423543772848170429916717;

View File

@ -0,0 +1,62 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.linalg.api.ops.impl.transforms.gradient;
import java.util.Collections;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
/**
* SELU backpropagation op - dL/dIn from in and dL/dOut
*/
public class SeluBp extends DynamicCustomOp {
public SeluBp(){ }
public SeluBp(SameDiff sd, SDVariable input, SDVariable gradient){
super(sd, new SDVariable[]{input, gradient});
}
public SeluBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
super(new INDArray[]{input, gradient}, wrapOrNull(output));
}
@Override
public String opName(){
return "selu_bp";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
Preconditions
.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
return Collections.singletonList(dataTypes.get(0));
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
throw new UnsupportedOperationException("Not supported");
}
}

View File

@ -0,0 +1,62 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.linalg.api.ops.impl.transforms.gradient;
import java.util.Collections;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
/**
* SoftPlus backpropagation op - dL/dIn from in and dL/dOut
*/
public class SoftPlusBp extends DynamicCustomOp {
public SoftPlusBp(){ }
public SoftPlusBp(SameDiff sd, SDVariable input, SDVariable gradient){
super(sd, new SDVariable[]{input, gradient});
}
public SoftPlusBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
super(new INDArray[]{input, gradient}, wrapOrNull(output));
}
@Override
public String opName(){
return "softplus_bp";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
Preconditions
.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
return Collections.singletonList(dataTypes.get(0));
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
throw new UnsupportedOperationException("Not supported");
}
}

View File

@ -0,0 +1,63 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.linalg.api.ops.impl.transforms.gradient;
import java.util.Collections;
import java.util.List;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
/**
* SoftSign backpropagation op - dL/dIn from in and dL/dOut
*/
public class SoftSignBp extends DynamicCustomOp {
public SoftSignBp(){ }
public SoftSignBp(SameDiff sd, SDVariable input, SDVariable gradient){
super(sd, new SDVariable[]{input, gradient});
}
public SoftSignBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
super(new INDArray[]{input, gradient}, wrapOrNull(output));
}
@Override
public String opName(){
return "softsign_bp";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
Preconditions
.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
return Collections.singletonList(dataTypes.get(0));
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
throw new UnsupportedOperationException("Not supported");
}
}

View File

@ -29,7 +29,10 @@ import java.util.List;
/** /**
* SoftSign derivative. * SoftSign derivative.
*
* @deprecated Use {@link SoftSignBp}
*/ */
@Deprecated
public class SoftSignDerivative extends BaseTransformStrictOp { public class SoftSignDerivative extends BaseTransformStrictOp {
public SoftSignDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { public SoftSignDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
super(sameDiff, i_v, inPlace); super(sameDiff, i_v, inPlace);

View File

@ -16,10 +16,12 @@
package org.nd4j.linalg.api.ops.impl.transforms.gradient; package org.nd4j.linalg.api.ops.impl.transforms.gradient;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions; import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Collections; import java.util.Collections;
@ -40,6 +42,12 @@ public class SoftmaxBp extends DynamicCustomOp {
addIArgument(dimension); addIArgument(dimension);
} }
public SoftmaxBp(@NonNull INDArray input, @NonNull INDArray grad, INDArray output, Integer dimension){
super(new INDArray[]{input, grad}, wrapOrNull(output));
if(dimension != null)
addIArgument(dimension);
}
@Override @Override
public String opName() { public String opName() {
return "softmax_bp"; return "softmax_bp";
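A minimal eager-mode sketch of the new constructor with an explicit dimension (shapes are illustrative assumptions):

// softmax_bp along dimension 1 for a [minibatch, numClasses] activation array.
INDArray preOut = Nd4j.rand(4, 5);
INDArray dLdOut = Nd4j.rand(4, 5);
INDArray dLdIn = Nd4j.exec(new SoftmaxBp(preOut, dLdOut, null, 1))[0];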

View File

@ -35,15 +35,15 @@ public class TanhDerivative extends DynamicCustomOp {
super(sameDiff, new SDVariable[]{i_v1, i_v2}); super(sameDiff, new SDVariable[]{i_v1, i_v2});
} }
public TanhDerivative(INDArray x, INDArray z) { public TanhDerivative(INDArray x, INDArray y, INDArray z) {
super(null, x, z, null, null); super(null, new INDArray[]{x, y}, new INDArray[]{z});
} }
public TanhDerivative() { public TanhDerivative() {
} }
public TanhDerivative(INDArray x) { public TanhDerivative(INDArray x, INDArray y) {
this(x, null); this(x, y, null);
} }
@Override @Override
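With this change the op takes both the forward input and dL/dOut; a minimal eager-mode sketch (names and values are illustrative assumptions):

// tanh_bp: dL/dIn = (1 - tanh(x)^2) * dL/dOut, elementwise.
INDArray x = Nd4j.rand(2, 3);
INDArray dLdOut = Nd4j.rand(2, 3);
INDArray dLdIn = Nd4j.exec(new TanhDerivative(x, dLdOut))[0];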

View File

@ -0,0 +1,74 @@
/*
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.nd4j.linalg.api.ops.impl.transforms.gradient;
import java.util.Collections;
import java.util.List;
import lombok.Getter;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear;
import org.nd4j.linalg.api.ops.impl.transforms.custom.ThresholdRelu;
/**
* Threshold ReLU Backprop op - dL/dIn from in and dL/dOut
*
* For {@link RectifiedLinear} as well as {@link ThresholdRelu}.
*/
public class ThresholdReluBp extends DynamicCustomOp {
@Getter
private double cutoff = 0;
public ThresholdReluBp(){ }
public ThresholdReluBp(SameDiff sd, SDVariable input, SDVariable gradient, double cutoff){
super(sd, new SDVariable[]{input, gradient});
this.cutoff = cutoff;
addTArgument(cutoff);
}
public ThresholdReluBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output, double cutoff){
super(new INDArray[]{input, gradient}, wrapOrNull(output));
this.cutoff = cutoff;
addTArgument(cutoff);
}
@Override
public String opName(){
return "thresholdedrelu_bp";
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
Preconditions
.checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes);
Preconditions.checkArgument(dataTypes.get(0).isFPType() && dataTypes.get(1).isFPType(), "Input datatypes must be floating point, got %s", dataTypes);
return Collections.singletonList(dataTypes.get(0));
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
throw new UnsupportedOperationException("Not supported");
}
}
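A minimal eager-mode sketch of the shared backprop path used for RectifiedLinear (cutoff 0) and ThresholdRelu (values are illustrative assumptions):

// thresholdedrelu_bp: gradient passes where x > cutoff, 0 elsewhere.
INDArray x = Nd4j.rand(2, 3).subi(0.5);
INDArray dLdOut = Nd4j.rand(2, 3);
INDArray dLdIn = Nd4j.exec(new ThresholdReluBp(x, dLdOut, null, 0.0))[0];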

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.same; package org.nd4j.linalg.api.ops.impl.transforms.same;
import java.util.Collections;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.NoOpNameFoundException;
@ -70,7 +71,6 @@ public class Cube extends BaseTransformSameOp {
@Override @Override
public List<SDVariable> doDiff(List<SDVariable> f1) { public List<SDVariable> doDiff(List<SDVariable> f1) {
SDVariable g = f().mul(f().cubeDerivative(arg()),f1.get(0)); return Collections.singletonList(f().cubeBp(arg(), f1.get(0)));
return Arrays.asList(g);
} }
} }

View File

@ -18,14 +18,18 @@ package org.nd4j.linalg.api.ops.impl.transforms.strict;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformFloatOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.BaseTransformOp; import org.tensorflow.framework.AttrValue;
import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
import java.util.Arrays; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map;
/** /**
* ELU: Exponential Linear Unit (alpha=1.0)<br> * ELU: Exponential Linear Unit (alpha=1.0)<br>
@ -36,25 +40,20 @@ import java.util.List;
* *
* @author Alex Black * @author Alex Black
*/ */
public class ELU extends BaseTransformStrictOp { public class ELU extends DynamicCustomOp {
public ELU(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { public ELU(SameDiff sameDiff, SDVariable i_v) {
super(sameDiff, i_v, inPlace); super(sameDiff, new SDVariable[]{i_v});
} }
public ELU() { public ELU() {
} }
public ELU(INDArray x, INDArray z) { public ELU(INDArray x, INDArray z) {
super(x, z); super(null, wrapOrNull(x), wrapOrNull(z));
} }
public ELU(INDArray x) { public ELU(INDArray x) {
super(x); this(x, null);
}
@Override
public int opNum() {
return 35;
} }
@Override @Override
@ -76,8 +75,14 @@ public class ELU extends BaseTransformStrictOp {
public List<SDVariable> doDiff(List<SDVariable> i_v) { public List<SDVariable> doDiff(List<SDVariable> i_v) {
//ELU: e^x-1 if x<0, x otherwise //ELU: e^x-1 if x<0, x otherwise
//dL/dIn = dL/Out * dOut/dIn //dL/dIn = dL/Out * dOut/dIn
SDVariable ret = f().eluDerivative(arg()).mul(i_v.get(0)); return Collections.singletonList(f().eluBp(arg(), i_v.get(0)));
return Arrays.asList(ret);
} }
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) {
Preconditions.checkState(dataTypes != null && dataTypes.size() == 1, "Expected exactly 1 datatype for ELU, got %s", dataTypes);
Preconditions.checkState(dataTypes.get(0).isFPType(), "Expected floating point input type for ELU, got %s", dataTypes);
return dataTypes;
}
} }
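With ELU now a DynamicCustomOp, the forward op can be run eagerly in the same way, and its gradient routes through EluBp via doDiff above; a minimal sketch (input values are illustrative assumptions):

INDArray x = Nd4j.rand(2, 3).subi(0.5);
INDArray y = Nd4j.exec(new ELU(x))[0];   // e^x - 1 where x < 0, x where x >= 0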

View File

@ -69,9 +69,7 @@ public class HardSigmoid extends BaseTransformStrictOp {
@Override @Override
public List<SDVariable> doDiff(List<SDVariable> f1) { public List<SDVariable> doDiff(List<SDVariable> f1) {
SDVariable in = arg(); return Collections.singletonList(f().hardSigmoidBp(arg(), f1.get(0)));
SDVariable dOutdIn = new HardSigmoidDerivative(sameDiff, in, false).outputVariables()[0];
return Collections.singletonList(dOutdIn.mul(f1.get(0)));
} }

View File

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.strict; package org.nd4j.linalg.api.ops.impl.transforms.strict;
import java.util.Collections;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -70,7 +71,6 @@ public class HardTanh extends BaseTransformStrictOp {
@Override @Override
public List<SDVariable> doDiff(List<SDVariable> i_v) { public List<SDVariable> doDiff(List<SDVariable> i_v) {
SDVariable ret = f().hardTanhDerivative(arg()).mul(i_v.get(0)); return Collections.singletonList(f().hardTanhBp(arg(), i_v.get(0)));
return Arrays.asList(ret);
} }
} }

View File

@ -16,13 +16,10 @@
package org.nd4j.linalg.api.ops.impl.transforms.strict; package org.nd4j.linalg.api.ops.impl.transforms.strict;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformFloatOp;
import org.nd4j.linalg.api.ops.BaseTransformOp;
import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp;
import java.util.Collections; import java.util.Collections;
@ -71,6 +68,6 @@ public class RationalTanh extends BaseTransformStrictOp {
@Override @Override
public List<SDVariable> doDiff(List<SDVariable> f1) { public List<SDVariable> doDiff(List<SDVariable> f1) {
return Collections.singletonList(f().tanhRationalDerivative(arg()).mul(f1.get(0))); return Collections.singletonList(f().tanhRationalBp(arg(), f1.get(0)));
} }
} }

View File

@ -17,13 +17,10 @@
package org.nd4j.linalg.api.ops.impl.transforms.strict; package org.nd4j.linalg.api.ops.impl.transforms.strict;
import onnx.Onnx; import onnx.Onnx;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException; import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseTransformFloatOp;
import org.nd4j.linalg.api.ops.BaseTransformOp;
import org.nd4j.linalg.api.ops.BaseTransformStrictOp; import org.nd4j.linalg.api.ops.BaseTransformStrictOp;
import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.GraphDef;
@ -88,6 +85,6 @@ public class RectifiedTanh extends BaseTransformStrictOp {
@Override @Override
public List<SDVariable> doDiff(List<SDVariable> f1) { public List<SDVariable> doDiff(List<SDVariable> f1) {
return Collections.singletonList(f().tanhRectifiedDerivative(arg()).mul(f1.get(0))); return Collections.singletonList(f().tanhRectifiedBp(arg(), f1.get(0)));
} }
} }

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.strict; package org.nd4j.linalg.api.ops.impl.transforms.strict;
import java.util.Collections;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -76,8 +77,7 @@ public class SELU extends BaseTransformStrictOp {
@Override @Override
public List<SDVariable> doDiff(List<SDVariable> i_v) { public List<SDVariable> doDiff(List<SDVariable> i_v) {
SDVariable ret = f().seluDerivative(arg()).mul(i_v.get(0)); return Collections.singletonList(f().seluBp(arg(), i_v.get(0)));
return Arrays.asList(ret);
} }
} }
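Aside (definition assumed, not taken from this diff): with selu(x) = lambda * x for x > 0 and lambda * alpha * (exp(x) - 1) for x <= 0, where alpha ≈ 1.6732632 and lambda ≈ 1.0507010, seluBp would be expected to return dL/dx = lambda * dL/dy for x > 0 and lambda * alpha * exp(x) * dL/dy otherwise.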

@ -28,8 +28,11 @@ import java.util.List;
/** /**
* Sigmoid derivative * Sigmoid derivative
* *
* @deprecated Use {@link org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative}
*
* @author Adam Gibson * @author Adam Gibson
*/ */
@Deprecated
public class SigmoidDerivative extends BaseTransformStrictOp { public class SigmoidDerivative extends BaseTransformStrictOp {
public SigmoidDerivative(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) { public SigmoidDerivative(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) {
super(sameDiff, i_v1, i_v2); super(sameDiff, i_v1, i_v2);
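For context (hedged; only the deprecation itself comes from the diff): the sigmoid derivative is sigma'(x) = sigma(x) * (1 - sigma(x)), and the replacement class in the gradient package is the backprop form, i.e. it is expected to fold the multiplication by the upstream gradient into the op itself rather than leaving it to the caller.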

@ -16,6 +16,7 @@
package org.nd4j.linalg.api.ops.impl.transforms.strict; package org.nd4j.linalg.api.ops.impl.transforms.strict;
import java.util.Collections;
import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -73,8 +74,7 @@ public class SoftSign extends BaseTransformStrictOp {
@Override @Override
public List<SDVariable> doDiff(List<SDVariable> i_v) { public List<SDVariable> doDiff(List<SDVariable> i_v) {
SDVariable ret = f().softsignDerivative(arg()).mul(i_v.get(0)); return Collections.singletonList(f().softsignBp(arg(), i_v.get(0)));
return Arrays.asList(ret);
} }
} }
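For reference (definition assumed, not from the diff): softsign(x) = x / (1 + |x|) has derivative 1 / (1 + |x|)^2, so softsignBp should reduce to dL/dx = dL/dy / (1 + |x|)^2.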

@ -27,7 +27,10 @@ import java.util.List;
/** /**
* Tanh derivative * Tanh derivative
*
* @deprecated Use {@link org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative}.
*/ */
@Deprecated
public class TanhDerivative extends BaseTransformStrictOp { public class TanhDerivative extends BaseTransformStrictOp {
public TanhDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { public TanhDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
super(sameDiff, i_v, inPlace); super(sameDiff, i_v, inPlace);
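Likewise (hedged): tanh'(x) = 1 - tanh(x)^2, and the non-deprecated class in the gradient package is the backprop variant that multiplies this derivative by the incoming gradient internally.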

@ -37,7 +37,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.floating.*; import org.nd4j.linalg.api.ops.impl.transforms.floating.*;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.*; import org.nd4j.linalg.api.ops.impl.transforms.comparison.*;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.ELUDerivative; import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative; import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative; import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative; import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative;
@ -438,16 +438,16 @@ public class Transforms {
public static INDArray elu(INDArray in, boolean copy) { public static INDArray elu(INDArray in, boolean copy) {
return Nd4j.getExecutioner().exec(new ELU(in, (copy ? in.ulike() : in))); return Nd4j.getExecutioner().exec(new ELU(in, (copy ? in.ulike() : in)))[0];
} }
public static INDArray eluDerivative(INDArray arr) { public static INDArray eluDerivative(INDArray arr, INDArray grad) {
return eluDerivative(arr, true); return eluDerivative(arr, grad,true);
} }
public static INDArray eluDerivative(INDArray in, boolean copy) { public static INDArray eluDerivative(INDArray in, INDArray grad, boolean copy) {
return Nd4j.getExecutioner().exec(new ELUDerivative(in, (copy ? in.ulike() : in))); return Nd4j.getExecutioner().exec(new EluBp(in, grad, (copy ? in.ulike() : in)))[0];
} }
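Usage sketch (illustrative only; the class and variable names are invented, and the only signatures relied on are the ones visible in this hunk plus long-standing Nd4j factory methods such as Nd4j.rand(int, int) and Nd4j.onesLike):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.ops.transforms.Transforms;

    public class EluBackpropSketch {
        public static void main(String[] args) {
            INDArray in = Nd4j.rand(3, 4);                        // forward input
            INDArray dLdOut = Nd4j.onesLike(in);                  // upstream gradient dL/dOut
            INDArray out = Transforms.elu(in, true);              // forward pass (copy)
            // eluDerivative now also takes the upstream gradient and returns dL/dIn directly
            INDArray dLdIn = Transforms.eluDerivative(in, dLdOut, true);
            System.out.println(out);
            System.out.println(dLdIn);
        }
    }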

@ -42,7 +42,7 @@ public class LecunUniformInitScheme extends BaseWeightInitScheme {
@Override @Override
public INDArray doCreate(DataType dataType, long[] shape, INDArray paramsView) { public INDArray doCreate(DataType dataType, long[] shape, INDArray paramsView) {
double b = 3.0 / Math.sqrt(fanIn); double b = 3.0 / Math.sqrt(fanIn);
return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-b, b)); return Nd4j.rand(Nd4j.getDistributions().createUniform(-b, b), shape);
} }

@ -43,7 +43,7 @@ public class ReluUniformInitScheme extends BaseWeightInitScheme {
@Override @Override
public INDArray doCreate(DataType dataType, long[] shape, INDArray paramsView) { public INDArray doCreate(DataType dataType, long[] shape, INDArray paramsView) {
double u = Math.sqrt(6.0 / fanIn); double u = Math.sqrt(6.0 / fanIn);
return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-u, u)); //U(-sqrt(6/fanIn), sqrt(6/fanIn) return Nd4j.rand(Nd4j.getDistributions().createUniform(-u, u), shape); //U(-sqrt(6/fanIn), sqrt(6/fanIn)
} }

@ -46,7 +46,7 @@ public class SigmoidUniformInitScheme extends BaseWeightInitScheme {
@Override @Override
public INDArray doCreate(DataType dataType, long[] shape, INDArray paramsView) { public INDArray doCreate(DataType dataType, long[] shape, INDArray paramsView) {
double r = 4.0 * Math.sqrt(6.0 / (fanIn + fanOut)); double r = 4.0 * Math.sqrt(6.0 / (fanIn + fanOut));
return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-r, r)); return Nd4j.rand(Nd4j.getDistributions().createUniform(-r, r), shape);
} }

@ -43,7 +43,7 @@ public class UniformInitScheme extends BaseWeightInitScheme {
@Override @Override
public INDArray doCreate(DataType dataType, long[] shape, INDArray paramsView) { public INDArray doCreate(DataType dataType, long[] shape, INDArray paramsView) {
double a = 1.0 / Math.sqrt(fanIn); double a = 1.0 / Math.sqrt(fanIn);
return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-a, a)); return Nd4j.rand(Nd4j.getDistributions().createUniform(-a, a), shape);
} }
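Across the four init schemes above, the substantive change is the argument order of Nd4j.rand: distribution first, shape second. A minimal sketch (hypothetical class and variable names; the bound matches UniformInitScheme's 1/sqrt(fanIn)):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class UniformInitSketch {
        public static void main(String[] args) {
            long[] shape = {256, 128};
            double fanIn = shape[0];
            double a = 1.0 / Math.sqrt(fanIn);   // U(-1/sqrt(fanIn), 1/sqrt(fanIn))
            // new ordering: distribution first, then shape
            INDArray w = Nd4j.rand(Nd4j.getDistributions().createUniform(-a, a), shape);
            System.out.println(java.util.Arrays.toString(w.shape()));
        }
    }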
