diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/vptree/VPTree.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/vptree/VPTree.java index 24754d50b..417154cf2 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/vptree/VPTree.java +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/vptree/VPTree.java @@ -380,19 +380,23 @@ public class VPTree implements Serializable { private Node buildFromPoints(INDArray items) { if (executorService == null && items == this.items && workers > 1) { + final val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread(); executorService = Executors.newFixedThreadPool(workers, new ThreadFactory() { @Override - public Thread newThread(Runnable r) { - Thread t = Executors.defaultThreadFactory().newThread(r); + public Thread newThread(final Runnable r) { + Thread t = new Thread(new Runnable() { + + @Override + public void run() { + Nd4j.getAffinityManager().unsafeSetDevice(deviceId); + r.run(); + } + }); t.setDaemon(true); t.setName("VPTree thread"); - // we don't want threads to be working on different devices - Nd4j.getAffinityManager().attachThreadToDevice(t, - Nd4j.getAffinityManager().getDeviceForCurrentThread()); - return t; } }); diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelInference.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelInference.java index 0e2fd339a..d5afd549b 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelInference.java +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelInference.java @@ -132,9 +132,8 @@ public class ParallelInference { boolean cRoot = !assignedRoot.get() && cDevice == currentDevice; assignedRoot.compareAndSet(false, cRoot); - zoo[i] = new InferenceWorker(i, model, observables, cRoot); + zoo[i] = new InferenceWorker(i, model, observables, cRoot, cDevice); - Nd4j.getAffinityManager().attachThreadToDevice(zoo[i], cDevice); zoo[i].setDaemon(true); zoo[i].start(); } @@ -425,13 +424,15 @@ public class ParallelInference { private Model replicatedModel; private AtomicLong counter = new AtomicLong(0); private boolean rootDevice; + private int deviceId; private ReentrantReadWriteLock modelLock = new ReentrantReadWriteLock(); - private InferenceWorker(int id, @NonNull Model model, @NonNull BlockingQueue inputQueue, boolean rootDevice) { + private InferenceWorker(int id, @NonNull Model model, @NonNull BlockingQueue inputQueue, boolean rootDevice, int deviceId) { this.inputQueue = inputQueue; this.protoModel = model; this.rootDevice = rootDevice; + this.deviceId = deviceId; this.setDaemon(true); this.setName("InferenceThread-" + id); @@ -491,6 +492,7 @@ public class ParallelInference { @Override public void run() { + Nd4j.getAffinityManager().unsafeSetDevice(deviceId); try { // model should be replicated & initialized here initializeReplicaModel(); diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelWrapper.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelWrapper.java index 4c8320a9d..0b4f29de2 100644 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelWrapper.java +++ b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-scaleout-parallelwrapper/src/main/java/org/deeplearning4j/parallelism/ParallelWrapper.java @@ -151,18 +151,21 @@ public class ParallelWrapper implements AutoCloseable { workerCounter.set(0); this.executorService = (ThreadPoolExecutor) Executors.newFixedThreadPool(workers, new ThreadFactory() { @Override - public Thread newThread(@NonNull Runnable r) { - Thread t = Executors.defaultThreadFactory().newThread(r); + public Thread newThread(@NonNull final Runnable r) { + final int cThread = workerCounter.getAndIncrement(); - int cThread = workerCounter.getAndIncrement(); + Thread t = new Thread(new Runnable() { + @Override + public void run() { + Nd4j.getAffinityManager().unsafeSetDevice(cThread % Nd4j.getAffinityManager().getNumberOfDevices()); + r.run(); + } + }); t.setName("ParallelWrapper training thread " + cThread); t.setDaemon(true); t.setUncaughtExceptionHandler(handler); - Nd4j.getAffinityManager().attachThreadToDevice(t, - cThread % Nd4j.getAffinityManager().getNumberOfDevices()); - return t; } }); diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecPerformer.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecPerformer.java index 2ae1c6f23..45eca7327 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecPerformer.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecPerformer.java @@ -80,14 +80,9 @@ public class Word2VecPerformer implements VoidFunction, Ato initExpTable(); if (negative > 0 && conf.contains(Word2VecVariables.TABLE)) { - try { - ByteArrayInputStream bis = new ByteArrayInputStream(conf.get(Word2VecVariables.TABLE).getBytes()); - DataInputStream dis = new DataInputStream(bis); - table = Nd4j.read(dis); - } catch (IOException e) { - e.printStackTrace(); - } - + ByteArrayInputStream bis = new ByteArrayInputStream(conf.get(Word2VecVariables.TABLE).getBytes()); + DataInputStream dis = new DataInputStream(bis); + table = Nd4j.read(dis); } } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecPerformerVoid.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecPerformerVoid.java index 4d182b90f..539755ee6 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecPerformerVoid.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-nlp/src/main/java/org/deeplearning4j/spark/models/embeddings/word2vec/Word2VecPerformerVoid.java @@ -95,16 +95,10 @@ public class Word2VecPerformerVoid implements VoidFunction, initExpTable(); if (negative > 0 && conf.contains(TABLE)) { - try { - ByteArrayInputStream bis = new ByteArrayInputStream(conf.get(TABLE).getBytes()); - DataInputStream dis = new DataInputStream(bis); - table = Nd4j.read(dis); - } catch (IOException e) { - e.printStackTrace(); - } - + ByteArrayInputStream bis = new ByteArrayInputStream(conf.get(TABLE).getBytes()); + DataInputStream dis = new DataInputStream(bis); + table = Nd4j.read(dis); } - } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingWorker.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingWorker.java index cbca7d52b..8cadbea43 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingWorker.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/main/java/org/deeplearning4j/spark/parameterserver/training/SharedTrainingWorker.java @@ -86,7 +86,7 @@ public class SharedTrainingWorker extends BaseTrainingWorkerdevice affinity, as master thread */ - Nd4j.getAffinityManager().attachThreadToDevice(thread, deviceId); thread.setDaemon(true); thread.start(); } @@ -116,9 +115,8 @@ public class SparkADSI extends AsyncDataSetIterator { public class SparkPrefetchThread extends AsyncPrefetchThread { - protected SparkPrefetchThread(BlockingQueue queue, DataSetIterator iterator, DataSet terminator, - MemoryWorkspace workspace) { - super(queue, iterator, terminator, workspace); + protected SparkPrefetchThread(BlockingQueue queue, DataSetIterator iterator, DataSet terminator, MemoryWorkspace workspace, int deviceId) { + super(queue, iterator, terminator, workspace, deviceId); } diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/SparkAMDSI.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/SparkAMDSI.java index fabec0587..44b8d3ee1 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/SparkAMDSI.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark/src/main/java/org/deeplearning4j/spark/iterator/SparkAMDSI.java @@ -97,15 +97,10 @@ public class SparkAMDSI extends AsyncMultiDataSetIterator { if (iterator.resetSupported()) this.backedIterator.reset(); - this.thread = new SparkPrefetchThread(buffer, iterator, terminator); + this.thread = new SparkPrefetchThread(buffer, iterator, terminator, Nd4j.getAffinityManager().getDeviceForCurrentThread()); context = TaskContext.get(); - /** - * We want to ensure, that background thread will have the same thread->device affinity, as master thread - */ - Nd4j.getAffinityManager().attachThreadToDevice(thread, deviceId); - thread.setDaemon(true); thread.start(); } @@ -117,9 +112,8 @@ public class SparkAMDSI extends AsyncMultiDataSetIterator { protected class SparkPrefetchThread extends AsyncPrefetchThread { - protected SparkPrefetchThread(@NonNull BlockingQueue queue, - @NonNull MultiDataSetIterator iterator, @NonNull MultiDataSet terminator) { - super(queue, iterator, terminator); + protected SparkPrefetchThread(@NonNull BlockingQueue queue, @NonNull MultiDataSetIterator iterator, @NonNull MultiDataSet terminator, int deviceId) { + super(queue, iterator, terminator, deviceId); } } } diff --git a/libnd4j/blas/NDArray.hpp b/libnd4j/blas/NDArray.hpp index 726549415..1404afc96 100644 --- a/libnd4j/blas/NDArray.hpp +++ b/libnd4j/blas/NDArray.hpp @@ -2462,7 +2462,7 @@ double NDArray::getTrace() const { double sum = 0.; -PRAGMA_OMP_PARALLEL_FOR_ARGS(reduction(OMP_SUMT:sum) if(minDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) +PRAGMA_OMP_PARALLEL_FOR_ARGS(reduction(OMP_SUMT:sum) OMP_IF(minDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) for(int i = 0; i < minDim; ++i) sum += e(i * offset); diff --git a/libnd4j/blas/cpu/NDArray.cpp b/libnd4j/blas/cpu/NDArray.cpp index a79f81612..2a843f956 100644 --- a/libnd4j/blas/cpu/NDArray.cpp +++ b/libnd4j/blas/cpu/NDArray.cpp @@ -100,7 +100,7 @@ void NDArray::fillAsTriangular(const float val, int lower, int upper, const char std::vector coords(zRank); - PRAGMA_OMP_PARALLEL_FOR_ARGS(if(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(coords)) + PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(coords)) for (Nd4jLong i = 0; i < zLen; ++i) { shape::index2coords(zRank, target->shapeOf(), i, zLen, coords.data()); @@ -141,7 +141,7 @@ void NDArray::setIdentity() { minDim = shape[i]; float v = 1.0f; - PRAGMA_OMP_PARALLEL_FOR_ARGS(if(minDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) + PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(minDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) for(int i = 0; i < minDim; ++i) templatedSet(buffer(), i*offset, this->dataType(), &v); } diff --git a/libnd4j/blas/cpu/NDArrayFactory.cpp b/libnd4j/blas/cpu/NDArrayFactory.cpp index ec99ef7db..d8b686b12 100644 --- a/libnd4j/blas/cpu/NDArrayFactory.cpp +++ b/libnd4j/blas/cpu/NDArrayFactory.cpp @@ -172,7 +172,9 @@ template void NDArrayFactory::memcpyFromVector(void *ptr, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context); template NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context); template NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context); + template NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context); template NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context); + template NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context); template NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context); template NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context); template NDArray NDArrayFactory::create(const char order, const std::vector &shape, const std::initializer_list& data, nd4j::LaunchContext * context); diff --git a/libnd4j/blas/cuda/NativeOpExecutioner.cu b/libnd4j/blas/cuda/NativeOpExecutioner.cu index 69f15c69b..b3573c7ab 100644 --- a/libnd4j/blas/cuda/NativeOpExecutioner.cu +++ b/libnd4j/blas/cuda/NativeOpExecutioner.cu @@ -785,48 +785,9 @@ void NativeOpExecutioner::execTransformAny(nd4j::LaunchContext *lc, auto xType = ArrayOptions::dataType(hXShapeInfo); auto zType = ArrayOptions::dataType(hZShapeInfo); - switch (opNum) { - case transform::IsMax: { - bool scalarCheat = false; - if (extraParams == nullptr) { - scalarCheat = true; - } + dim3 launchDims(512, 512, 2048); - void* special = lc->getAllocationPointer(); - - if (scalarCheat) { - auto scalarShape = nd4j::ConstantShapeHelper::getInstance()->bufferForShapeInfo(ShapeDescriptor::scalarDescriptor(nd4j::DataType::INT64)); //ShapeBuilders::createScalarShapeInfo(nd4j::DataType::INT64); - /** - * In case of vector-input for IsMax, it just turns into IndexReduce call + further filler call - */ - execIndexReduceScalar(lc, indexreduce::IndexMax, nullptr, hXShapeInfo, dX, dXShapeInfo, extraParams, nullptr, scalarShape.primaryAsT(), special, scalarShape.specialAsT()); - Nd4jLong maxIdx = -119; - nd4j::DebugHelper::checkErrorCode(stream, "IsMax: execIndexReduce(...) failed"); - - cudaMemcpyAsync(&maxIdx, special, sizeof(Nd4jLong), cudaMemcpyDeviceToHost, *stream); - nd4j::DebugHelper::checkErrorCode(stream, "IsMax: cudaMemcpyAsync(...) failed"); - int targetIdx = 0; - - if (shape::order(hXShapeInfo) == 'c' || shape::order(hXShapeInfo) == 'f' && maxIdx * shape::stride(hXShapeInfo)[shape::rank(hXShapeInfo) - 1] >= shape::length(hXShapeInfo)) - targetIdx = maxIdx; - else - targetIdx = maxIdx * shape::stride(hXShapeInfo)[shape::rank(hXShapeInfo) - 1]; - - dim3 launchDims(1, 512, 1024); - BUILD_SINGLE_SELECTOR(zType, fillIsMaxGeneric, (launchDims, stream, dZ, shape::length(hZShapeInfo), targetIdx), LIBND4J_TYPES); - - nd4j::DebugHelper::checkErrorCode(stream, "Legacy IsMax(...) failed"); - - //delete[] scalarShape; - } - } - break; - default: { - dim3 launchDims(512, 512, 16384); - - BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, LIBND4J_TYPES); - } - } + BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformAny, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, LIBND4J_TYPES); // TODO: remove after the release auto res = cudaStreamSynchronize(*stream); @@ -884,7 +845,7 @@ void NativeOpExecutioner::execTransformFloat(nd4j::LaunchContext *lc, if (!DataTypeUtils::isR(zType)) throw datatype_exception::build("NativeOpExecutioner::execTransformFloat requires Z to have floating point type", zType); - dim3 launchDims(512, 512, 16384); + dim3 launchDims(512, 512, 2048); BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformFloat, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, FLOAT_TYPES); // TODO: remove after the release diff --git a/libnd4j/blas/cuda/NativeOps.cu b/libnd4j/blas/cuda/NativeOps.cu index af9fc6776..b85ce8760 100755 --- a/libnd4j/blas/cuda/NativeOps.cu +++ b/libnd4j/blas/cuda/NativeOps.cu @@ -653,36 +653,7 @@ void execTransformAny(Nd4jPointer *extraPointers,int opNum, auto streamSpecial = reinterpret_cast(extraPointers[4]); LaunchContext lc(stream, streamSpecial, extraPointers[5], extraPointers[3], reinterpret_cast(extraPointers[6])); - // FIXME: remove this once all operations are enabled - if (opNum == nd4j::transform::IsMax && extraParams != nullptr) { - auto hostYShapeInfo = reinterpret_cast(extraPointers[7]); - auto hostTShapeInfo = reinterpret_cast(extraPointers[19]); - auto tadMaxShapeInfo = reinterpret_cast (extraPointers[10]); - auto tadMaxOffsets = reinterpret_cast (extraPointers[11]); - int *dimension = reinterpret_cast (extraPointers[15]); - int *hDimension = reinterpret_cast (extraPointers[16]); - int dimensionLength = getDeviceId(extraPointers[18]); - auto special = reinterpret_cast(extraPointers[17]); - - auto cshape = ShapeBuilders::createVectorShapeInfo(nd4j::DataType::INT32, dimensionLength); - - // we call for IMax on specified dimension - execIndexReduce(extraPointers, indexreduce::IndexMax, nullptr, hXShapeInfo, dX, dXShapeInfo, extraParams, nullptr, hostTShapeInfo, special, hostYShapeInfo, hDimension, cshape, dimension, nullptr); - - DEBUG_KERNEL(stream, opNum); - - dim3 launchDims(256, 256, 16384); - auto zType = ArrayOptions::dataType(hZShapeInfo); - - // at this point, all IMax indexes are gathered, and we execute filler - BUILD_SINGLE_SELECTOR(zType, fillDimensionalIsMaxGeneric, (launchDims, stream, special, dZ, dZShapeInfo, tadMaxShapeInfo, dimension, dimensionLength, tadMaxOffsets), LIBND4J_TYPES); - - nd4j::DebugHelper::checkErrorCode(stream, "Legacy IsMax(...) failed"); - - delete[] cshape; - } else { - NativeOpExecutioner::execTransformAny(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraParams, nullptr, nullptr); - } + NativeOpExecutioner::execTransformAny(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraParams, nullptr, nullptr); } //////////////////////////////////////////////////////////////////////// @@ -712,7 +683,7 @@ void execTransformFloat(Nd4jPointer *extraPointers,int opNum, auto tadOffsets = reinterpret_cast(extraPointers != nullptr ? extraPointers[11] : nullptr); LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); - NativeOpExecutioner::execTransformFloat(&lc, opNum, hX, hXShapeInfo, dZ, dXShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraParams, tadShapeInfo, tadOffsets); + NativeOpExecutioner::execTransformFloat(&lc, opNum, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraParams, tadShapeInfo, tadOffsets); } diff --git a/libnd4j/include/array/impl/NDArrayList.cpp b/libnd4j/include/array/impl/NDArrayList.cpp index 1762565a1..75df72e70 100644 --- a/libnd4j/include/array/impl/NDArrayList.cpp +++ b/libnd4j/include/array/impl/NDArrayList.cpp @@ -137,8 +137,8 @@ namespace nd4j { auto newAxis = ShapeUtils::evalDimsToExclude(array->rankOf(), args); auto result = array->allTensorsAlongDimension(newAxis); for (int e = 0; e < result->size(); e++) { - auto chunk = result->at(e)->dup(array->ordering()); - write(e, chunk); + auto chunk = result->at(e);//->dup(array->ordering()); + write(e, chunk->dup(array->ordering())); } delete result; } diff --git a/libnd4j/include/helpers/Loops.h b/libnd4j/include/helpers/Loops.h index 3b5627eec..bda04414f 100644 --- a/libnd4j/include/helpers/Loops.h +++ b/libnd4j/include/helpers/Loops.h @@ -922,7 +922,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo, //*********************************************// case LoopKind::EWS1: { - PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams)) + PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams)) for (uint i = 0; i < zLen; ++i) { extraParams[0] = param0; @@ -944,7 +944,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo, //*********************************************// case LoopKind::EWSNONZERO: { - PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams)) + PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams)) for (uint i = 0; i < zLen; ++i) { extraParams[0] = param0; @@ -966,7 +966,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo, //*********************************************// case LoopKind::RANK1: { - PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams)) + PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams)) for (uint i = 0; i < zLen; i++) { extraParams[0] = param0; @@ -990,7 +990,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo, //*********************************************// case LoopKind::RANK2: { - PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams)) + PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams)) for (uint i = 0; i < zLen; i++) { extraParams[0] = param0; @@ -1016,7 +1016,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo, //*********************************************// case LoopKind::RANK3: { - PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams)) + PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams)) for (uint i = 0; i < zLen; i++) { extraParams[0] = param0; @@ -1044,7 +1044,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo, //*********************************************// case LoopKind::RANK4: { - PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams)) + PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams)) for (uint i = 0; i < zLen; i++) { extraParams[0] = param0; @@ -1074,7 +1074,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo, //*********************************************// case LoopKind::RANK5: { - PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams)) + PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams)) for (uint i = 0; i < zLen; i++) { extraParams[0] = param0; @@ -1111,7 +1111,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo, if(shape::haveSameShapeAndStrides(xTadShapeInfo, yTadShapeInfo)) { - PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams)) + PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams)) for (uint i = 0; i < zLen; ++i) { extraParams[0] = param0; @@ -1135,7 +1135,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo, uint castYTadShapeInfo[MAX_RANK]; const bool canCastYTad = nd4j::DataTypeUtils::castShapeInfo(yTadShapeInfo, castYTadShapeInfo); - PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) if(numThreads > 1) private(extraParams)) + PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams)) for (uint i = 0; i < zLen; ++i) { extraParams[0] = param0; @@ -1199,7 +1199,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo, //*********************************************// case LoopKind::EWS1: { - PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams)) + PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams)) for (uint ix = 0; ix < numXTads; ++ix) { for (uint iy = 0; iy < numYTads; ++iy) { @@ -1224,7 +1224,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo, //*********************************************// case LoopKind::EWSNONZERO: { - PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams)) + PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams)) for (uint ix = 0; ix < numXTads; ++ix) { for (uint iy = 0; iy < numYTads; ++iy) { @@ -1249,7 +1249,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo, //*********************************************// case LoopKind::RANK1: { - PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams)) + PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams)) for (uint ix = 0; ix < numXTads; ++ix) { for (uint iy = 0; iy < numYTads; ++iy) { @@ -1276,7 +1276,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo, //*********************************************// case LoopKind::RANK2: { - PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams)) + PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams)) for (uint ix = 0; ix < numXTads; ++ix) { for (uint iy = 0; iy < numYTads; ++iy) { @@ -1305,7 +1305,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo, //*********************************************// case LoopKind::RANK3: { - PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams)) + PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams)) for (uint ix = 0; ix < numXTads; ++ix) { for (uint iy = 0; iy < numYTads; ++iy) { @@ -1336,7 +1336,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo, //*********************************************// case LoopKind::RANK4: { - PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams)) + PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams)) for (uint ix = 0; ix < numXTads; ++ix) { for (uint iy = 0; iy < numYTads; ++iy) { @@ -1369,7 +1369,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo, //*********************************************// case LoopKind::RANK5: { - PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams)) + PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams)) for (uint ix = 0; ix < numXTads; ++ix) { for (uint iy = 0; iy < numYTads; ++iy) { @@ -1409,7 +1409,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo, if(shape::haveSameShapeAndStrides(xTadShapeInfo, yTadShapeInfo)) { - PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams)) + PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams)) for (uint ix = 0; ix < numXTads; ++ix) { for (uint iy = 0; iy < numYTads; ++iy) { @@ -1435,7 +1435,7 @@ void Loops::loopXYZ(const X* x, const Nd4jLong* xShapeInfo, uint castYTadShapeInfo[MAX_RANK]; const bool canCastYTad = nd4j::DataTypeUtils::castShapeInfo(yTadShapeInfo, castYTadShapeInfo); - PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) if(numThreads > 1) private(extraParams)) + PRAGMA_OMP_PARALLEL_FOR_SIMD_ARGS(collapse(2) num_threads(numThreads) OMP_IF(numThreads > 1) private(extraParams)) for (uint ix = 0; ix < numXTads; ++ix) { for (uint iy = 0; iy < numYTads; ++iy) { diff --git a/libnd4j/include/helpers/cpu/MmulHelper.cpp b/libnd4j/include/helpers/cpu/MmulHelper.cpp index d17d2c021..fbf2fbc20 100644 --- a/libnd4j/include/helpers/cpu/MmulHelper.cpp +++ b/libnd4j/include/helpers/cpu/MmulHelper.cpp @@ -40,7 +40,7 @@ static void usualGemm(const char cOrder, const bool transA, const bool transB, c const bool flagA = (flagC && transA) || (!flagC && !transA); const bool flagB = (flagC && transB) || (!flagC && !transB); - // PRAGMA_OMP_PARALLEL_FOR_ARGS(if(M*N > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) + // PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(M*N > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) // for(uint row = 0; row < M; ++row) { // T3* c = flagC ? (C + row) : (C + row * ldc); @@ -74,7 +74,7 @@ static void usualGemm(const char cOrder, const bool transA, const bool transB, c // } // } - PRAGMA_OMP_PARALLEL_FOR_ARGS(if(M*N > Environment::getInstance()->elementwiseThreshold()) schedule(guided) collapse(2)) + PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(M*N > Environment::getInstance()->elementwiseThreshold()) schedule(guided) collapse(2)) for(uint row = 0; row < M; ++row) { for(uint col = 0; col < N; ++col) { @@ -108,7 +108,7 @@ static void usualGemv(const char aOrder, const int M, const int N, const double const bool flagA = aOrder == 'f'; - PRAGMA_OMP_PARALLEL_FOR_ARGS(if(M > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) + PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(M > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) for(int row = 0; row < M; ++row) { T3* y = Y + row * incy; @@ -139,7 +139,7 @@ static void usualDot(const Nd4jLong length, const double alpha, const void* vX, T3 alphaZ(alpha), betaZ(beta); T3 sum = 0; - PRAGMA_OMP_PARALLEL_FOR_ARGS(if(length > Environment::getInstance()->elementwiseThreshold()) schedule(guided) reduction(OMP_SUMT:sum)) + PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(length > Environment::getInstance()->elementwiseThreshold()) schedule(guided) reduction(OMP_SUMT:sum)) for(int i = 0; i < length; ++i) sum = sum + X[i * incx] * Y[i * incy]; diff --git a/libnd4j/include/loops/cuda/specials/fillIsMax.cu b/libnd4j/include/loops/cuda/specials/fillIsMax.cu index b80f0036c..0851968ba 100644 --- a/libnd4j/include/loops/cuda/specials/fillIsMax.cu +++ b/libnd4j/include/loops/cuda/specials/fillIsMax.cu @@ -25,21 +25,21 @@ namespace nd4j { //////////////////////////////////////////////////////////////////////// template - __global__ void execFillIsMax(void *vdZ, Nd4jLong length, long idx) { + __global__ void execFillIsMax(void *vdZ, Nd4jLong *xShapeInfo, Nd4jLong length, long idx) { auto dz = reinterpret_cast(vdZ); int tid = blockIdx.x * blockDim.x + threadIdx.x; for (Nd4jLong i = tid; i < length; i += blockDim.x * gridDim.x) - dz[i] = (i == idx ? (T) 1 : (T) 0); + dz[shape::getIndexOffset(i, xShapeInfo, length)] = (i == idx ? (T) 1 : (T) 0); } //////////////////////////////////////////////////////////////////////// template - __host__ void fillIsMaxGeneric(dim3 &launchDims, cudaStream_t *stream, void *dx, Nd4jLong length, long idx) { - execFillIsMax<<>>(dx, length, idx); + __host__ void fillIsMaxGeneric(dim3 &launchDims, cudaStream_t *stream, void *dx, Nd4jLong *xShapeInfo, Nd4jLong length, long idx) { + execFillIsMax<<>>(dx, xShapeInfo, length, idx); nd4j::DebugHelper::checkErrorCode(stream, "fillIsMax(...) failed"); } - BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT fillIsMaxGeneric, (dim3& launchDims, cudaStream_t *stream, void* dz, Nd4jLong length, long idx), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template void ND4J_EXPORT fillIsMaxGeneric, (dim3& launchDims, cudaStream_t *stream, void* dz, Nd4jLong *zShapeInfo, Nd4jLong length, long idx), LIBND4J_TYPES); } \ No newline at end of file diff --git a/libnd4j/include/loops/cuda/transform/transform_float.cu b/libnd4j/include/loops/cuda/transform/transform_float.cu index e1cd36256..6fe7b18d1 100644 --- a/libnd4j/include/loops/cuda/transform/transform_float.cu +++ b/libnd4j/include/loops/cuda/transform/transform_float.cu @@ -99,18 +99,18 @@ namespace functions { if(xEws > 0 && zEws > 0 && xOrder == zOrder) { - for (int i = tid; i < length; i += totalThreads) - z[i * zEws] = OpType::op(x[i * xEws], params); + for (Nd4jLong i = tid; i < length; i += totalThreads) + z[i * zEws] = OpType::op(x[i * xEws], params); } else { if(vx == vz) { - for (Nd4jLong i = tid; i < length; i+= gridDim.x * blockDim.x) { + for (Nd4jLong i = tid; i < length; i+= totalThreads) { auto xOffset = shape::getIndexOffset(i, xShapeInfo, length); z[xOffset] = OpType::op(x[xOffset], params); } } else { - for (Nd4jLong i = tid; i < length; i+= gridDim.x * blockDim.x) { + for (Nd4jLong i = tid; i < length; i+= totalThreads) { auto xOffset = shape::getIndexOffset(i, xShapeInfo, length); auto zOffset = shape::getIndexOffset(i, zShapeInfo, length); z[zOffset] = OpType::op(x[xOffset], params); diff --git a/libnd4j/include/loops/legacy_ops.h b/libnd4j/include/loops/legacy_ops.h index 760b30bbd..b3096ac0e 100644 --- a/libnd4j/include/loops/legacy_ops.h +++ b/libnd4j/include/loops/legacy_ops.h @@ -92,8 +92,7 @@ (21, Copy) #define TRANSFORM_ANY_OPS \ - (0, Assign) , \ - (1, IsMax) + (0, Assign) // these ops return bool #define TRANSFORM_BOOL_OPS \ diff --git a/libnd4j/include/loops/special_kernels.h b/libnd4j/include/loops/special_kernels.h index 4dc9b083c..37356efcd 100644 --- a/libnd4j/include/loops/special_kernels.h +++ b/libnd4j/include/loops/special_kernels.h @@ -36,7 +36,7 @@ namespace nd4j { template - _CUDA_H void fillIsMaxGeneric(dim3 &launchDims, cudaStream_t *stream, void *dx, Nd4jLong length, long idx); + _CUDA_H void fillIsMaxGeneric(dim3 &launchDims, cudaStream_t *stream, void *dx, Nd4jLong *xShapeInfo, Nd4jLong length, long idx); template _CUDA_H void fillDimensionalIsMaxGeneric(dim3 &launchDims, cudaStream_t *stream, void *dX, void *dZ, Nd4jLong *zShapeInfo, Nd4jLong *tadOnlyShapeInfo, int *dimension, int dimensionLength, Nd4jLong *tadOffsets); diff --git a/libnd4j/include/op_boilerplate.h b/libnd4j/include/op_boilerplate.h index ad37a1618..d9c8dee62 100644 --- a/libnd4j/include/op_boilerplate.h +++ b/libnd4j/include/op_boilerplate.h @@ -1328,7 +1328,8 @@ REGISTER_C(NAME) \ nd4j::ShapeList* nd4j::ops::NAME::calculateOutputShape(nd4j::ShapeList* inputShape, nd4j::graph::Context& block) { \ auto shapeList = SHAPELIST(); \ - for (int e = 0; e < this->getOpDescriptor()->getNumberOfOutputs(); e++) { \ + auto opLimit = this->getOpDescriptor()->getNumberOfOutputs() < 1 ? block.width() : this->getOpDescriptor()->getNumberOfOutputs(); \ + for (int e = 0; e < opLimit; e++) { \ auto newshape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShape->at(e)), shape::order(inputShape->at(e)), shape::rank(inputShape->at(e)), shape::shapeOf(inputShape->at(e))); \ shapeList->push_back(newshape); \ } \ @@ -1365,7 +1366,8 @@ REGISTER_C(NAME) \ nd4j::ShapeList* nd4j::ops::NAME::calculateOutputShape(nd4j::ShapeList* inputShape, nd4j::graph::Context& block) { \ auto shapeList = SHAPELIST(); \ - for (int e = 0; e < this->getOpDescriptor()->getNumberOfOutputs(); e++) { \ + auto opLimit = this->getOpDescriptor()->getNumberOfOutputs() < 1 ? block.width() : this->getOpDescriptor()->getNumberOfOutputs(); \ + for (int e = 0; e < opLimit; e++) { \ Nd4jLong* newshape; \ COPY_SHAPE(inputShape->at(0), newshape); \ shapeList->push_back(CONSTANT(newshape)); \ @@ -1388,7 +1390,8 @@ REGISTER_C(NAME) \ nd4j::ShapeList* nd4j::ops::NAME::calculateOutputShape(nd4j::ShapeList* inputShape, nd4j::graph::Context& block) { \ auto shapeList = SHAPELIST(); \ - for (int e = 0; e < this->getOpDescriptor()->getNumberOfOutputs(); e++) { \ + auto opLimit = this->getOpDescriptor()->getNumberOfOutputs() < 1 ? block.width() : this->getOpDescriptor()->getNumberOfOutputs(); \ + for (int e = 0; e < opLimit; e++) { \ auto newshape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShape->at(e)), shape::order(inputShape->at(e)), shape::rank(inputShape->at(e)), shape::shapeOf(inputShape->at(e))); \ shapeList->push_back(newshape); \ } \ diff --git a/libnd4j/include/ops/declarable/generic/bitwise/cyclic_rshift.cpp b/libnd4j/include/ops/declarable/generic/bitwise/cyclic_rshift.cpp new file mode 100644 index 000000000..2aac5c6f9 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/bitwise/cyclic_rshift.cpp @@ -0,0 +1,58 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#if NOT_EXCLUDED(OP_cyclic_rshift_bits) + +#include +#include +#include + +namespace nd4j { + namespace ops { + CONFIGURABLE_OP_IMPL(cyclic_rshift_bits, 1, 1, true, 0, -2) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(block.numI() > 0 || block.width() > 1, 0, "cyclic_rshift_bits: actual shift value is missing"); + + uint32_t shift = 0; + if (block.width() > 1) { + shift = INPUT_VARIABLE(1)->e(0); + } else if (block.numI() > 0) { + shift = INT_ARG(0); + }; + + helpers::cyclic_rshift_bits(block.launchContext(), *input, *output, shift); + + REQUIRE_TRUE(shift > 0 && shift < input->sizeOfT() * 8, 0, "cyclic_rshift_bits: can't shift beyond size of data type") + + return Status::OK(); + } + + DECLARE_TYPES(cyclic_rshift_bits) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS}) + ->setSameMode(true); + } + } +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/bitwise/cyclic_shift.cpp b/libnd4j/include/ops/declarable/generic/bitwise/cyclic_shift.cpp new file mode 100644 index 000000000..0bdb9503d --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/bitwise/cyclic_shift.cpp @@ -0,0 +1,58 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#if NOT_EXCLUDED(OP_cyclic_shift_bits) + +#include +#include +#include + +namespace nd4j { + namespace ops { + CONFIGURABLE_OP_IMPL(cyclic_shift_bits, 1, 1, true, 0, -2) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(block.numI() > 0 || block.width() > 1, 0, "cyclic_shift_bits: actual shift value is missing"); + + uint32_t shift = 0; + if (block.width() > 1) { + shift = INPUT_VARIABLE(1)->e(0); + } else if (block.numI() > 0) { + shift = INT_ARG(0); + }; + + helpers::cyclic_shift_bits(block.launchContext(), *input, *output, shift); + + REQUIRE_TRUE(shift > 0 && shift < input->sizeOfT() * 8, 0, "cyclic_shift_bits: can't shift beyond size of data type") + + return Status::OK(); + } + + DECLARE_TYPES(cyclic_shift_bits) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS}) + ->setSameMode(true); + } + } +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/bitwise/rshift.cpp b/libnd4j/include/ops/declarable/generic/bitwise/rshift.cpp new file mode 100644 index 000000000..4068351a2 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/bitwise/rshift.cpp @@ -0,0 +1,58 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#if NOT_EXCLUDED(OP_rshift_bits) + +#include +#include +#include + +namespace nd4j { + namespace ops { + CONFIGURABLE_OP_IMPL(rshift_bits, 1, 1, true, 0, -2) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(block.numI() > 0 || block.width() > 1, 0, "rshift_bits: actual shift value is missing"); + + uint32_t shift = 0; + if (block.width() > 1) { + shift = INPUT_VARIABLE(1)->e(0); + } else if (block.numI() > 0) { + shift = INT_ARG(0); + }; + + REQUIRE_TRUE(shift > 0 && shift < input->sizeOfT() * 8, 0, "rshift_bits: can't shift beyond size of data type") + + helpers::rshift_bits(block.launchContext(), *input, *output, shift); + + return Status::OK(); + } + + DECLARE_TYPES(rshift_bits) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS}) + ->setSameMode(true); + } + } +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/bitwise/shift.cpp b/libnd4j/include/ops/declarable/generic/bitwise/shift.cpp new file mode 100644 index 000000000..f79da1024 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/bitwise/shift.cpp @@ -0,0 +1,58 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#if NOT_EXCLUDED(OP_shift_bits) + +#include +#include +#include + +namespace nd4j { + namespace ops { + CONFIGURABLE_OP_IMPL(shift_bits, 1, 1, true, 0, -2) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(block.numI() > 0 || block.width() > 1, 0, "shift_bits: actual shift value is missing"); + + uint32_t shift = 0; + if (block.width() > 1) { + shift = INPUT_VARIABLE(1)->e(0); + } else if (block.numI() > 0) { + shift = INT_ARG(0); + }; + + REQUIRE_TRUE(shift > 0 && shift < input->sizeOfT() * 8, 0, "shift_bits: can't shift beyond size of data type") + + helpers::shift_bits(block.launchContext(), *input, *output, shift); + + return Status::OK(); + } + + DECLARE_TYPES(shift_bits) { + getOpDescriptor() + ->setAllowedInputTypes({ALL_INTS}) + ->setSameMode(true); + } + } +} + +#endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/generic/bitwise/toggle_bits.cpp b/libnd4j/include/ops/declarable/generic/bitwise/toggle_bits.cpp index d554d46ef..4aaae3c0d 100644 --- a/libnd4j/include/ops/declarable/generic/bitwise/toggle_bits.cpp +++ b/libnd4j/include/ops/declarable/generic/bitwise/toggle_bits.cpp @@ -34,7 +34,7 @@ namespace nd4j { auto z = OUTPUT_VARIABLE(i); REQUIRE_TRUE(x->dataType() == z->dataType(), 0, "Toggle bits requires input and output to have same type"); - REQUIRE_TRUE(x->isR(),0, "Toggle bits requires input and output to be integer type (int8, int16, int32, int64)"); + REQUIRE_TRUE(x->isZ(),0, "Toggle bits requires input and output to be integer type (int8, int16, int32, int64)"); helpers::__toggle_bits(block.launchContext(), *x, *z); } @@ -44,7 +44,8 @@ namespace nd4j { DECLARE_TYPES(toggle_bits) { getOpDescriptor() ->setAllowedInputTypes({ALL_INTS}) - ->setSameMode(true); + ->setAllowedOutputTypes({ALL_INTS}) + ->setSameMode(false); } } } diff --git a/libnd4j/include/ops/declarable/generic/convo/ismax.cpp b/libnd4j/include/ops/declarable/generic/convo/ismax.cpp index 616f9842b..ad5a485e1 100644 --- a/libnd4j/include/ops/declarable/generic/convo/ismax.cpp +++ b/libnd4j/include/ops/declarable/generic/convo/ismax.cpp @@ -28,7 +28,7 @@ namespace nd4j { namespace ops { -CONFIGURABLE_OP_IMPL(ismax, 1, 1, true, 0, -1) { +CONFIGURABLE_OP_IMPL(ismax, 1, 1, true, 0, -2) { auto x = INPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0); diff --git a/libnd4j/include/ops/declarable/generic/datatypes/to_double.cpp b/libnd4j/include/ops/declarable/generic/datatypes/to_double.cpp index 59a7b7546..9d3e5aaed 100644 --- a/libnd4j/include/ops/declarable/generic/datatypes/to_double.cpp +++ b/libnd4j/include/ops/declarable/generic/datatypes/to_double.cpp @@ -25,7 +25,7 @@ namespace nd4j { namespace ops { - OP_IMPL(to_double, 1, 1, true) { + CUSTOM_OP_IMPL(to_double, 1, 1, true, 0, 0) { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); @@ -42,6 +42,12 @@ namespace nd4j { ->setAllowedInputTypes(nd4j::DataType::ANY) ->setAllowedOutputTypes(nd4j::DataType::DOUBLE); } + + DECLARE_SHAPE_FN(to_double) { + auto outShape = ShapeBuilders::copyShapeInfoAndType(inputShape->at(0), DataType::DOUBLE, true, block.workspace()); + return SHAPELIST(CONSTANT(outShape)); + } + } } diff --git a/libnd4j/include/ops/declarable/generic/datatypes/to_float16.cpp b/libnd4j/include/ops/declarable/generic/datatypes/to_float16.cpp index 8cdd38e4f..d6818bec4 100644 --- a/libnd4j/include/ops/declarable/generic/datatypes/to_float16.cpp +++ b/libnd4j/include/ops/declarable/generic/datatypes/to_float16.cpp @@ -25,7 +25,7 @@ namespace nd4j { namespace ops { - OP_IMPL(to_float16, 1, 1, true) { + CUSTOM_OP_IMPL(to_float16, 1, 1, true, 0, 0) { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); @@ -42,6 +42,12 @@ namespace nd4j { ->setAllowedInputTypes(nd4j::DataType::ANY) ->setAllowedOutputTypes(nd4j::DataType::HALF); } + + DECLARE_SHAPE_FN(to_float16) { + auto outShape = ShapeBuilders::copyShapeInfoAndType(inputShape->at(0), DataType::HALF, true, block.workspace()); + return SHAPELIST(CONSTANT(outShape)); + } + } } diff --git a/libnd4j/include/ops/declarable/generic/datatypes/to_float32.cpp b/libnd4j/include/ops/declarable/generic/datatypes/to_float32.cpp index 3fdcafaab..4ca46bb82 100644 --- a/libnd4j/include/ops/declarable/generic/datatypes/to_float32.cpp +++ b/libnd4j/include/ops/declarable/generic/datatypes/to_float32.cpp @@ -25,7 +25,7 @@ namespace nd4j { namespace ops { - OP_IMPL(to_float32, 1, 1, true) { + CUSTOM_OP_IMPL(to_float32, 1, 1, true, 0, 0) { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); @@ -42,6 +42,12 @@ namespace nd4j { ->setAllowedInputTypes(nd4j::DataType::ANY) ->setAllowedOutputTypes(nd4j::DataType::FLOAT32); } + + DECLARE_SHAPE_FN(to_float32) { + auto outShape = ShapeBuilders::copyShapeInfoAndType(inputShape->at(0), DataType::FLOAT32, true, block.workspace()); + return SHAPELIST(CONSTANT(outShape)); + } + } } diff --git a/libnd4j/include/ops/declarable/generic/datatypes/to_int32.cpp b/libnd4j/include/ops/declarable/generic/datatypes/to_int32.cpp index a5eef8595..897868be5 100644 --- a/libnd4j/include/ops/declarable/generic/datatypes/to_int32.cpp +++ b/libnd4j/include/ops/declarable/generic/datatypes/to_int32.cpp @@ -25,7 +25,7 @@ namespace nd4j { namespace ops { - OP_IMPL(to_int32, 1, 1, true) { + CUSTOM_OP_IMPL(to_int32, 1, 1, true, 0, 0) { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); @@ -42,6 +42,11 @@ namespace nd4j { ->setAllowedInputTypes(nd4j::DataType::ANY) ->setAllowedOutputTypes(nd4j::DataType::INT32); } + DECLARE_SHAPE_FN(to_int32) { + auto outShape = ShapeBuilders::copyShapeInfoAndType(inputShape->at(0), DataType::INT32, true, block.workspace()); + return SHAPELIST(CONSTANT(outShape)); + } + } } diff --git a/libnd4j/include/ops/declarable/generic/datatypes/to_int64.cpp b/libnd4j/include/ops/declarable/generic/datatypes/to_int64.cpp index 450c57c1d..6fa728254 100644 --- a/libnd4j/include/ops/declarable/generic/datatypes/to_int64.cpp +++ b/libnd4j/include/ops/declarable/generic/datatypes/to_int64.cpp @@ -25,7 +25,7 @@ namespace nd4j { namespace ops { - OP_IMPL(to_int64, 1, 1, true) { + CUSTOM_OP_IMPL(to_int64, 1, 1, true, 0, 0) { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); @@ -42,6 +42,11 @@ namespace nd4j { ->setAllowedInputTypes(nd4j::DataType::ANY) ->setAllowedOutputTypes(nd4j::DataType::INT64); } + DECLARE_SHAPE_FN(to_int64) { + auto outShape = ShapeBuilders::copyShapeInfoAndType(inputShape->at(0), DataType::INT64, true, block.workspace()); + return SHAPELIST(CONSTANT(outShape)); + } + } } diff --git a/libnd4j/include/ops/declarable/generic/datatypes/to_uint32.cpp b/libnd4j/include/ops/declarable/generic/datatypes/to_uint32.cpp index 5b6822797..6805855f1 100644 --- a/libnd4j/include/ops/declarable/generic/datatypes/to_uint32.cpp +++ b/libnd4j/include/ops/declarable/generic/datatypes/to_uint32.cpp @@ -25,7 +25,7 @@ namespace nd4j { namespace ops { - OP_IMPL(to_uint32, 1, 1, true) { + CUSTOM_OP_IMPL(to_uint32, 1, 1, true, 0, 0) { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); @@ -40,8 +40,13 @@ namespace nd4j { DECLARE_TYPES(to_uint32) { getOpDescriptor() ->setAllowedInputTypes(nd4j::DataType::ANY) - ->setAllowedOutputTypes(nd4j::DataType::INT16); + ->setAllowedOutputTypes(nd4j::DataType::INT32); } + DECLARE_SHAPE_FN(to_uint32) { + auto outShape = ShapeBuilders::copyShapeInfoAndType(inputShape->at(0), DataType::UINT32, true, block.workspace()); + return SHAPELIST(CONSTANT(outShape)); + } + } } diff --git a/libnd4j/include/ops/declarable/generic/datatypes/to_uint64.cpp b/libnd4j/include/ops/declarable/generic/datatypes/to_uint64.cpp index a0402cdb7..fe61821a5 100644 --- a/libnd4j/include/ops/declarable/generic/datatypes/to_uint64.cpp +++ b/libnd4j/include/ops/declarable/generic/datatypes/to_uint64.cpp @@ -25,7 +25,7 @@ namespace nd4j { namespace ops { - OP_IMPL(to_uint64, 1, 1, true) { + CUSTOM_OP_IMPL(to_uint64, 1, 1, true, 0, 0) { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); @@ -42,6 +42,10 @@ namespace nd4j { ->setAllowedInputTypes(nd4j::DataType::ANY) ->setAllowedOutputTypes(nd4j::DataType::INT8); } + DECLARE_SHAPE_FN(to_uint64) { + auto outShape = ShapeBuilders::copyShapeInfoAndType(inputShape->at(0), DataType::UINT64, true, block.workspace()); + return SHAPELIST(CONSTANT(outShape)); + } } } diff --git a/libnd4j/include/ops/declarable/generic/list/unstack_list.cpp b/libnd4j/include/ops/declarable/generic/list/unstack_list.cpp index 886959943..b5e5f207e 100644 --- a/libnd4j/include/ops/declarable/generic/list/unstack_list.cpp +++ b/libnd4j/include/ops/declarable/generic/list/unstack_list.cpp @@ -26,13 +26,19 @@ namespace nd4j { namespace ops { LIST_OP_IMPL(unstack_list, 1, 1, 0, 0) { - auto input = INPUT_VARIABLE(0); + auto outputList = INPUT_LIST(0); + auto input = INPUT_VARIABLE(int(outputList != nullptr) ); - auto list = new NDArrayList(0, true); - list->unstack(input, 0); + if (outputList == nullptr) { + outputList = new NDArrayList(0, true); + //block.trackList(outputList); + setupResultList(outputList, block); + } + outputList->unstack(input, INT_ARG(0)); //OVERWRITE_RESULT(list); - setupResultList(list, block); + + // return Status::OK(); } } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/roll.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/roll.cpp index 7c4a52f9c..61f592f1d 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/roll.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/roll.cpp @@ -23,40 +23,74 @@ #include #include +#include namespace nd4j { namespace ops { - CONFIGURABLE_OP_IMPL(roll, 1, 1, true, 0, 1) { + CONFIGURABLE_OP_IMPL(roll, 1, 1, true, 0, 0) { auto output = OUTPUT_VARIABLE(0); auto input = INPUT_VARIABLE(0); - bool shiftIsLinear = true; - //std::vector axes(input->rankOf()); - int shift = INT_ARG(0); int inputLen = input->lengthOf(); - if (block.isInplace()) output = input; - if (shift < 0) { - // convert shift to positive value between 1 and inputLen - 1 - shift -= inputLen * (shift / inputLen - 1); - } - else - // cut shift to value between 1 and inputLen - 1 - shift %= inputLen; - if (block.numI() > 1) - shiftIsLinear = false; - if (shiftIsLinear) { - helpers::rollFunctorLinear(block.launchContext(), input, output, shift, block.isInplace()); + bool shiftIsLinear = block.width() == 1; + std::vector axes; + std::vector shifts; + if (block.width() > 1) { + REQUIRE_TRUE(block.width() == 3, 0, "roll: 3 arguments required for roll - input, shifts and axes. But %i given.", block.width()); + auto axesI = INPUT_VARIABLE(2); + auto shiftsI = INPUT_VARIABLE(1); + REQUIRE_TRUE(axesI->rankOf() == shiftsI->rankOf(), 0, "roll: shifts and axes should be the same rank, but %i and %i given.", (int)shiftsI->rankOf(), (int)axesI->rankOf()); + REQUIRE_TRUE(axesI->lengthOf() == shiftsI->lengthOf(), 0, "roll: shifts and axes should be the same length, but %i and %i given.", (int)shiftsI->lengthOf(), (int)axesI->lengthOf()); + helpers::adjustAxis(axesI->lengthOf(), axesI, axes ); + shifts.resize(shiftsI->lengthOf()); + for (Nd4jLong i = 0; i < shiftsI->lengthOf(); i++) { + auto shift = shiftsI->e(i); + if (shift < 0) { + shift -= input->sizeAt(i) * (shift / inputLen - 1); + } + else { + shift %= input->sizeAt(i); + } + shifts[i] = shift; + } + } else { - std::vector axes(block.numI() - 1); - for (unsigned e = 0; e < axes.size(); ++e) { - int axe = INT_ARG(e + 1); - REQUIRE_TRUE(axe < input->rankOf() && axe >= -input->rankOf(), 0, "roll: axe value should be between -%i and %i, but %i was given.", - input->rankOf(), input->rankOf() - 1, axe); - axes[e] = (axe < 0? (input->rankOf() + axe) : axe); + int shift = INT_ARG(0); + if (shift < 0) { + // convert shift to positive value between 1 and inputLen - 1 + shift -= inputLen * (shift / inputLen - 1); } - helpers::rollFunctorFull(block.launchContext(), input, output, shift, axes, block.isInplace()); + else + // cut shift to value between 1 and inputLen - 1 + shift %= inputLen; + axes.resize(block.getIArguments()->size() - 1); + if (axes.size()) + shifts.resize(axes.size());//emplace_back(shift); + else + shifts.push_back(shift); + + for (auto& s: shifts) + s = shift; + + for (unsigned e = 0; e < axes.size(); ++e) { + int axis = INT_ARG(e + 1); + REQUIRE_TRUE(axis < input->rankOf() && axis >= -input->rankOf(), 0, "roll: axe value should be between -%i and %i, but %i was given.", + input->rankOf(), input->rankOf() - 1, axis); + axes[e] = (axis < 0? (input->rankOf() + axis) : axis); + } + } + + if (block.isInplace()) output = input; + + shiftIsLinear = axes.size() == 0; + + if (shiftIsLinear) { + helpers::rollFunctorLinear(block.launchContext(), input, output, shifts[0], block.isInplace()); + } + else { + helpers::rollFunctorFull(block.launchContext(), input, output, shifts, axes, block.isInplace()); } return Status::OK(); @@ -64,7 +98,9 @@ namespace ops { DECLARE_TYPES(roll) { getOpDescriptor() - ->setAllowedInputTypes(nd4j::DataType::ANY) + ->setAllowedInputTypes(0,nd4j::DataType::ANY) + ->setAllowedInputTypes(1,nd4j::DataType::INT32) // TODO: all ints in future + ->setAllowedInputTypes(2,nd4j::DataType::INT32) ->setAllowedOutputTypes(nd4j::DataType::ANY) ->setSameMode(true); } diff --git a/libnd4j/include/ops/declarable/generic/random/get_seed.cpp b/libnd4j/include/ops/declarable/generic/random/get_seed.cpp index e8acfa067..2161a2378 100644 --- a/libnd4j/include/ops/declarable/generic/random/get_seed.cpp +++ b/libnd4j/include/ops/declarable/generic/random/get_seed.cpp @@ -26,11 +26,11 @@ namespace nd4j { namespace ops { CUSTOM_OP_IMPL(get_seed, -2, 1, false, 0, 0) { - REQUIRE_TRUE(block.getRNG() != nullptr, 0, "RNG should be defined in Graph"); - auto rng = block.getRNG(); +// REQUIRE_TRUE(block.getRNG() != nullptr, 0, "RNG should be defined in Graph"); + auto rng = block.getRng(); auto z = OUTPUT_VARIABLE(0); - z->p(Nd4jLong(0), rng->getSeed()); + z->p(Nd4jLong(0), rng.rootState()); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/random/set_seed.cpp b/libnd4j/include/ops/declarable/generic/random/set_seed.cpp index b42c7c763..fa9dcf992 100644 --- a/libnd4j/include/ops/declarable/generic/random/set_seed.cpp +++ b/libnd4j/include/ops/declarable/generic/random/set_seed.cpp @@ -27,8 +27,9 @@ namespace nd4j { namespace ops { CUSTOM_OP_IMPL(set_seed, -2, 1, false, 0, -2) { - REQUIRE_TRUE(block.getRNG() != nullptr, 0, "RNG should be defined in Graph"); - auto rng = block.getRNG(); +// REQUIRE_TRUE(block.getRNG() != nullptr, 0, "RNG should be defined in Graph"); + auto rng = block.getRng(); //.getRNG(); + Nd4jLong seed = 0; if (block.getIArguments()->size() > 0) { seed = INT_ARG(0); @@ -41,8 +42,8 @@ namespace nd4j { } // FIXME: this approach isn't really good for cuda, since it'll assume that CUDA might get nullptr instead of stream - refreshBuffer(nullptr, seed, (Nd4jPointer) rng); - + //refreshBuffer(nullptr, seed, (Nd4jPointer) rng); + rng.setSeed((int)seed); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/generic/transforms/log1p.cpp b/libnd4j/include/ops/declarable/generic/transforms/log1p.cpp index 1398eae47..3d45bcf42 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/log1p.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/log1p.cpp @@ -25,7 +25,7 @@ namespace nd4j { namespace ops { - OP_IMPL(Log1p, 2, 1, true) { + OP_IMPL(Log1p, 1, 1, true) { auto x = INPUT_VARIABLE(0); auto z = OUTPUT_VARIABLE(0); diff --git a/libnd4j/include/ops/declarable/generic/transforms/merge_max_idx.cpp b/libnd4j/include/ops/declarable/generic/transforms/merge_max_idx.cpp index f087eaf1b..e48761f8f 100644 --- a/libnd4j/include/ops/declarable/generic/transforms/merge_max_idx.cpp +++ b/libnd4j/include/ops/declarable/generic/transforms/merge_max_idx.cpp @@ -27,7 +27,7 @@ namespace nd4j { namespace ops { -OP_IMPL(mergemaxindex, -1, 1, false) { +CUSTOM_OP_IMPL(mergemaxindex, -1, 1, false, 0, 0) { REQUIRE_OK(this->validateInputDimensionsMatch(block)); auto output = OUTPUT_VARIABLE(0); @@ -49,6 +49,15 @@ DECLARE_SYN(MergeMaxIndex, mergemaxindex); ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS}); } } +DECLARE_SHAPE_FN(mergemaxindex) { + auto in = inputShape->at(0); + auto dtype = DataType::INT32; + if (block.getIArguments()->size()> 0) + dtype = (DataType)INT_ARG(0); + + auto resShape = ShapeBuilders::copyShapeInfoAndType(in, dtype, block.workspace()); + return SHAPELIST(CONSTANT(resShape)); +} } #endif \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/headers/bitwise.h b/libnd4j/include/ops/declarable/headers/bitwise.h index 466431c83..900d42816 100644 --- a/libnd4j/include/ops/declarable/headers/bitwise.h +++ b/libnd4j/include/ops/declarable/headers/bitwise.h @@ -28,13 +28,58 @@ namespace nd4j { /** * This operation toggles individual bits of each element in array * - * PLEASE NOTE: This operation is possible only on integer datatypes + * PLEASE NOTE: This operation is possible only on integer data types * * @tparam T */ #if NOT_EXCLUDED(OP_toggle_bits) DECLARE_OP(toggle_bits, -1, -1, true); #endif + + + /** + * This operation shift individual bits of each element in array to the left: << + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * @tparam T + */ + #if NOT_EXCLUDED(OP_shift_bits) + DECLARE_CONFIGURABLE_OP(shift_bits, 1, 1, true, 0, -2); + #endif + + /** + * This operation shift individual bits of each element in array to the right: >> + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * @tparam T + */ + #if NOT_EXCLUDED(OP_rshift_bits) + DECLARE_CONFIGURABLE_OP(rshift_bits, 1, 1, true, 0, -2); + #endif + + /** + * This operation shift individual bits of each element in array, shifting to the left + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * @tparam T + */ + #if NOT_EXCLUDED(OP_cyclic_shift_bits) + DECLARE_CONFIGURABLE_OP(cyclic_shift_bits, 1, 1, true, 0, -2); + #endif + + /** + * This operation shift individual bits of each element in array, shifting to the right + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * @tparam T + */ + #if NOT_EXCLUDED(OP_cyclic_rshift_bits) + DECLARE_CONFIGURABLE_OP(cyclic_rshift_bits, 1, 1, true, 0, -2); + #endif } } diff --git a/libnd4j/include/ops/declarable/headers/convo.h b/libnd4j/include/ops/declarable/headers/convo.h index ee1417386..bd262a7c1 100644 --- a/libnd4j/include/ops/declarable/headers/convo.h +++ b/libnd4j/include/ops/declarable/headers/convo.h @@ -260,7 +260,7 @@ namespace nd4j { * 0: axis */ #if NOT_EXCLUDED(OP_ismax) - DECLARE_CONFIGURABLE_OP(ismax, 1, 1, true, 0, -1); + DECLARE_CONFIGURABLE_OP(ismax, 1, 1, true, 0, -2); #endif /** diff --git a/libnd4j/include/ops/declarable/headers/datatypes.h b/libnd4j/include/ops/declarable/headers/datatypes.h index 43983ecb6..d8ff39d48 100644 --- a/libnd4j/include/ops/declarable/headers/datatypes.h +++ b/libnd4j/include/ops/declarable/headers/datatypes.h @@ -30,7 +30,7 @@ namespace nd4j { * PLEASE NOTE: This op is disabled atm, and reserved for future releases. */ #if NOT_EXCLUDED(OP_to_double) - DECLARE_OP(to_double, 1, 1, true); + DECLARE_CUSTOM_OP(to_double, 1, 1, true, 0, 0); #endif /** @@ -39,7 +39,7 @@ namespace nd4j { * PLEASE NOTE: This op is disabled atm, and reserved for future releases. */ #if NOT_EXCLUDED(OP_to_float16) - DECLARE_OP(to_float16, 1, 1, true); + DECLARE_CUSTOM_OP(to_float16, 1, 1, true, 0, 0); #endif /** @@ -48,7 +48,7 @@ namespace nd4j { * PLEASE NOTE: This op is disabled atm, and reserved for future releases. */ #if NOT_EXCLUDED(OP_to_float32) - DECLARE_OP(to_float32, 1, 1, true); + DECLARE_CUSTOM_OP(to_float32, 1, 1, true, 0, 0); #endif /** @@ -57,7 +57,7 @@ namespace nd4j { * PLEASE NOTE: This op is disabled atm, and reserved for future releases. */ #if NOT_EXCLUDED(OP_to_int32) - DECLARE_OP(to_int32, 1, 1, true); + DECLARE_CUSTOM_OP(to_int32, 1, 1, true, 0, 0); #endif /** @@ -66,7 +66,7 @@ namespace nd4j { * PLEASE NOTE: This op is disabled atm, and reserved for future releases. */ #if NOT_EXCLUDED(OP_to_int64) - DECLARE_OP(to_int64, 1, 1, true); + DECLARE_CUSTOM_OP(to_int64, 1, 1, true, 0, 0); #endif /** @@ -75,7 +75,7 @@ namespace nd4j { * PLEASE NOTE: This op is disabled atm, and reserved for future releases. */ #if NOT_EXCLUDED(OP_to_uint32) - DECLARE_OP(to_uint32, 1, 1, true); + DECLARE_CUSTOM_OP(to_uint32, 1, 1, true, 0, 0); #endif /** @@ -84,7 +84,7 @@ namespace nd4j { * PLEASE NOTE: This op is disabled atm, and reserved for future releases. */ #if NOT_EXCLUDED(OP_to_uint64) - DECLARE_OP(to_uint64, 1, 1, true); + DECLARE_CUSTOM_OP(to_uint64, 1, 1, true, 0, 0); #endif /** diff --git a/libnd4j/include/ops/declarable/headers/transforms.h b/libnd4j/include/ops/declarable/headers/transforms.h index b24fad482..75715f78e 100644 --- a/libnd4j/include/ops/declarable/headers/transforms.h +++ b/libnd4j/include/ops/declarable/headers/transforms.h @@ -65,9 +65,15 @@ namespace nd4j { #if NOT_EXCLUDED(OP_mergemax) DECLARE_OP(mergemax, -1, 1, false); #endif - + /* + * Complete tensor with max indices merged from all input tensors list + * + * INPUT: tensors with the same shape + * OUTPUT: integer tensor with the same shape + * INT_ARG: result type (one of int), INT32 by default + */ #if NOT_EXCLUDED(OP_mergemaxindex) - DECLARE_OP(mergemaxindex, -1, 1, false); + DECLARE_CUSTOM_OP(mergemaxindex, -1, 1, false, 0, 0); #endif #if NOT_EXCLUDED(OP_mergeadd) diff --git a/libnd4j/include/ops/declarable/helpers/cpu/axis.cpp b/libnd4j/include/ops/declarable/helpers/cpu/axis.cpp index 6f0ed5f27..eb56acb9c 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/axis.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/axis.cpp @@ -43,7 +43,6 @@ namespace helpers { axisVector[e] = a + rank; } } - } } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/roll.cpp b/libnd4j/include/ops/declarable/helpers/cpu/roll.cpp index e363fd8fa..da3cb3259 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/roll.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/roll.cpp @@ -85,18 +85,19 @@ namespace helpers { } } - void rollFunctorFull(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int shift, std::vector const& axes, bool inplace){ + void rollFunctorFull(nd4j::LaunchContext * context, NDArray* input, NDArray* output, std::vector const& shifts, std::vector const& axes, bool inplace){ if (!inplace) output->assign(input); auto source = output; //input; - for (int axe: axes) { + for (auto i = 0; i < axes.size(); i++) { + int axe = axes[i]; if (axe == source->rankOf() - 1) {// last dimension std::unique_ptr listOfTensors(source->allTensorsAlongDimension({axe})); std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension({axe})); int fullLen = listOfTensors->size(); - int theShift = shift; + int theShift = shifts[i]; if (theShift > 0) { theShift %= fullLen; } @@ -118,7 +119,7 @@ namespace helpers { int fullLen = listOfTensors->size(); int sizeAt = input->sizeAt(axe); - int theShift = shift; + int theShift = shifts[i]; if (theShift > 0) { theShift %= sizeAt; diff --git a/libnd4j/include/ops/declarable/helpers/cpu/scatter.cpp b/libnd4j/include/ops/declarable/helpers/cpu/scatter.cpp index 20d9dd05d..c1d01930c 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/scatter.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/scatter.cpp @@ -35,8 +35,8 @@ void scatter(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& ind if(outRank == 1) { -// PRAGMA_OMP_PARALLEL_FOR_ARGS(if(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) -PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided)) +// PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) +PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(!lock) schedule(guided)) for(Nd4jLong i = 0; i < indLen; ++i) { Nd4jLong idx = indices.e(i); @@ -54,8 +54,8 @@ PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided)) std::vector dimsToExcludeUpd(sizeOfDims); std::iota(dimsToExcludeUpd.begin(), dimsToExcludeUpd.end(), 0); -// PRAGMA_OMP_PARALLEL_FOR_ARGS(if(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) // causes known openMP asan bug ! -PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided)) +// PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) // causes known openMP asan bug ! +PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(!lock) schedule(guided)) for(Nd4jLong i = 0; i < indLen; ++i) { NDArray outSubArr = output(indices.e(i), std::vector({0})); @@ -76,8 +76,8 @@ void scatterND(nd4j::LaunchContext *context, pairwise::Ops op, const NDArray& i if(outRank == 1) { -// PRAGMA_OMP_PARALLEL_FOR_ARGS(if(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) -PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided)) +// PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(indLen > Environment::getInstance()->elementwiseThreshold()) schedule(guided)) +PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(!lock) schedule(guided)) for(Nd4jLong i = 0; i < indLen; ++i) { Nd4jLong idx = indices.e(i); @@ -93,8 +93,8 @@ PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided)) std::iota(dimsToExcludeUpd.begin(), dimsToExcludeUpd.end(), 0); std::vector idxRangeOut(2*outRank, 0); -// PRAGMA_OMP_PARALLEL_FOR_ARGS(if(indLen/indLastDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided) firstprivate(idxRangeOut)) -PRAGMA_OMP_PARALLEL_FOR_ARGS(if(!lock) schedule(guided) firstprivate(idxRangeOut)) +// PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(indLen/indLastDim > Environment::getInstance()->elementwiseThreshold()) schedule(guided) firstprivate(idxRangeOut)) +PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(!lock) schedule(guided) firstprivate(idxRangeOut)) for(Nd4jLong i = 0; i < indLen/indLastDim; ++i) { NDArray indSubArr = indices(i, dimsToExcludeInd); diff --git a/libnd4j/include/ops/declarable/helpers/cpu/segment.cpp b/libnd4j/include/ops/declarable/helpers/cpu/segment.cpp index 261a742da..6ebfd9b07 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/segment.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/segment.cpp @@ -479,7 +479,7 @@ namespace helpers { for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { auto outputT = listOfOutTensors->at(fi->first); outputT->assign(listOfTensors->at(fi->second.at(0))); - auto loopSize = fi->second.size(); + Nd4jLong loopSize = fi->second.size(); PRAGMA_OMP_PARALLEL_FOR for (Nd4jLong idx = 1; idx < loopSize; ++idx) { auto current = listOfTensors->at(fi->second.at(idx)); diff --git a/libnd4j/include/ops/declarable/helpers/cpu/shift.cpp b/libnd4j/include/ops/declarable/helpers/cpu/shift.cpp new file mode 100644 index 000000000..7a9b77b66 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/shift.cpp @@ -0,0 +1,81 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include + +namespace nd4j { + namespace ops { + namespace helpers { + template + void rshift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) { + auto lambda = LAMBDA_T(x, shift) { + return x >> shift; + }; + + input.applyLambda(lambda, &output); + } + + void rshift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { + BUILD_SINGLE_SELECTOR(x.dataType(), rshift_bits_, (launchContext, x, z, shift), INTEGER_TYPES); + } + + template + void shift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) { + auto lambda = LAMBDA_T(x, shift) { + return x << shift; + }; + + input.applyLambda(lambda, &output); + } + + void shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { + BUILD_SINGLE_SELECTOR(x.dataType(), shift_bits_, (launchContext, x, z, shift), INTEGER_TYPES); + } + + template + void cyclic_rshift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) { + auto step = (sizeof(T) * 8) - shift; + auto lambda = LAMBDA_T(x, shift, step) { + return x >> shift | x << step; + }; + + input.applyLambda(lambda, &output); + } + + void cyclic_rshift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { + BUILD_SINGLE_SELECTOR(x.dataType(), cyclic_rshift_bits_, (launchContext, x, z, shift), INTEGER_TYPES); + } + + template + void cyclic_shift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) { + auto step = (sizeof(T) * 8) - shift; + auto lambda = LAMBDA_T(x, shift, step) { + return x << shift | x >> step; + }; + + input.applyLambda(lambda, &output); + } + + void cyclic_shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { + BUILD_SINGLE_SELECTOR(x.dataType(), cyclic_shift_bits_, (launchContext, x, z, shift), INTEGER_TYPES); + } + } + } +} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/transforms.cpp b/libnd4j/include/ops/declarable/helpers/cpu/transforms.cpp index 3536f9f62..9b96f34d2 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/transforms.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/transforms.cpp @@ -562,7 +562,7 @@ static void gatherND_(NDArray& input, NDArray& indices, NDArray& output) { std::vector coords(maxRank); - PRAGMA_OMP_PARALLEL_FOR_ARGS(if(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(coords)) + PRAGMA_OMP_PARALLEL_FOR_ARGS(OMP_IF(zLen > Environment::getInstance()->elementwiseThreshold()) firstprivate(coords)) for (Nd4jLong i = 0; i < zLen; ++i) { Nd4jLong *zCoordStart, *xCoordStart; diff --git a/libnd4j/include/ops/declarable/helpers/cuda/axis.cu b/libnd4j/include/ops/declarable/helpers/cuda/axis.cu index 6f0ed5f27..a3b2bcd32 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/axis.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/axis.cu @@ -27,6 +27,8 @@ namespace helpers { void adjustAxis(Nd4jLong rank, NDArray* axisVector, std::vector& output) { output.resize(axisVector->lengthOf()); + axisVector->tickReadDevice(); + axisVector->syncToHost(); for (int e = 0; e < axisVector->lengthOf(); e++) { auto ca = axisVector->e(e); if (ca < 0) diff --git a/libnd4j/include/ops/declarable/helpers/cuda/concat.cu b/libnd4j/include/ops/declarable/helpers/cuda/concat.cu index 3051de448..94675c587 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/concat.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/concat.cu @@ -29,104 +29,102 @@ #include #include -namespace nd4j { - namespace ops { - namespace helpers { +namespace nd4j { +namespace ops { +namespace helpers { + /////////////////////////////////////////////////////////////////// - template - __global__ static void concatCuda(const int numOfArrs, void* pVx, void* pxShapeInfo, void* pVz, void* pzShapeInfo) { +template +__global__ static void concatCuda(void* pVx, void* pxShapeInfo, void* vz, Nd4jLong* zShapeInfo, const int axis) { - __shared__ int arrIdx, blocksPerArr; + T* z = reinterpret_cast(vz); + __shared__ Nd4jLong zLen, totalThreads, *sharedMem; + __shared__ int rank; - if (threadIdx.x == 0) { + if (threadIdx.x == 0) { + extern __shared__ unsigned char shmem[]; + sharedMem = reinterpret_cast(shmem); - blocksPerArr = (gridDim.x + numOfArrs - 1) / numOfArrs; // ceil - arrIdx = blockIdx.x / blocksPerArr; - } - - __syncthreads(); - - for(int j = arrIdx; j < numOfArrs; j += gridDim.x) { - - const auto* x = reinterpret_cast(reinterpret_cast(pVx)[j]); - auto* z = reinterpret_cast(reinterpret_cast(pVz)[j]); - const auto* xShapeInfo = reinterpret_cast(pxShapeInfo)[j]; - const auto* zShapeInfo = reinterpret_cast(pzShapeInfo)[j]; - - const auto arrLen = shape::length(xShapeInfo); - - const auto arrLenPerBlock = (arrLen + blocksPerArr - 1) / blocksPerArr; // ceil - - const auto start = (blockIdx.x % blocksPerArr) * arrLenPerBlock; - const auto end = (start + arrLenPerBlock) > arrLen ? arrLen : (start + arrLenPerBlock); - - for (Nd4jLong i = start + threadIdx.x; i < end; i += blockDim.x) - z[shape::getIndexOffset(i, zShapeInfo, arrLen)] = x[shape::getIndexOffset(i, xShapeInfo, arrLen)]; - } - } - -/////////////////////////////////////////////////////////////////// - template - __host__ static void concatCudaLauncher(const int numOfArrs, const cudaStream_t *stream, void* pVx, void* pxShapeInfo, void* pVz, void* pzShapeInfo) { - - concatCuda<<<512, 512, 512, *stream>>>(numOfArrs, pVx, pxShapeInfo, pVz, pzShapeInfo); - } - BUILD_SINGLE_TEMPLATE(template void concatCudaLauncher, (const int numOfArrs, const cudaStream_t *stream, void* pVx, void* pxShapeInfo, void* pVz, void* pzShapeInfo), LIBND4J_TYPES); - - ////////////////////////////////////////////////////////////////////////// - void concat(nd4j::LaunchContext * context, const std::vector& inArrs, NDArray& output, const int axis) { - - const int numOfArrs = inArrs.size(); - for(int i = 0; i < numOfArrs; ++i) - if(!inArrs[i]->isActualOnDeviceSide()) inArrs[i]->syncToDevice(); - - const int rank = inArrs[0]->rankOf(); - const int rank2 = 2*rank; - std::vector> indices(numOfArrs, std::vector(rank2,0)); - - // take into account indices for first array - indices[0][2 * axis + 1] = inArrs[0]->sizeAt(axis); - - // loop through the rest of input arrays - for(int i = 1; i < numOfArrs; ++i) { - indices[i][2 * axis] = indices[i-1][2 * axis + 1]; // index start from - indices[i][2 * axis + 1] = indices[i-1][2 * axis + 1] + inArrs[i]->sizeAt(axis); // index end with (excluding) - } - - std::vector outSubArrs(numOfArrs); - for(int i = 0; i < numOfArrs; ++i) - outSubArrs[i] = new NDArray(output(indices[i], true)); - - // prepare arrays of pointers on buffers and shapes - std::vector hOutBuffers(numOfArrs), hInBuffers(numOfArrs); - std::vector hOutShapeInfo(numOfArrs), hInShapeInfo(numOfArrs); - for(int i = 0; i < numOfArrs; ++i) { - hOutBuffers[i] = outSubArrs[i]->getSpecialBuffer(); - hInBuffers[i] = inArrs[i]->getSpecialBuffer(); - hOutShapeInfo[i] = outSubArrs[i]->getSpecialShapeInfo(); - hInShapeInfo[i] = inArrs[i]->getSpecialShapeInfo(); - } - - // allocate and copy all buffers and shapes arrays to global memory - PointersManager manager(context, "helpers::concat"); - void* dOutBuffers = manager.replicatePointer(hOutBuffers.data(), hOutBuffers.size() * sizeof(void*)); - void* dInBuffers = manager.replicatePointer(hInBuffers.data(), hInBuffers.size() * sizeof(void*)); - void* dInShapeInfo = manager.replicatePointer(hInShapeInfo.data(), hInShapeInfo.size() * sizeof(Nd4jLong*)); - void* dOutShapeInfo = manager.replicatePointer(hOutShapeInfo.data(), hOutShapeInfo.size() * sizeof(Nd4jLong*)); - - BUILD_SINGLE_SELECTOR(inArrs[0]->dataType(), concatCudaLauncher, (numOfArrs, context->getCudaStream(), dInBuffers, dInShapeInfo, dOutBuffers, dOutShapeInfo), LIBND4J_TYPES); - - manager.synchronize(); - - for(int i = 0; i < numOfArrs; ++i) - delete outSubArrs[i]; - - for(int i = 0; i < numOfArrs; ++i) - inArrs[i]->tickReadHost(); - - output.tickWriteDevice(); - } - } + zLen = shape::length(zShapeInfo); + rank = shape::rank(zShapeInfo); + totalThreads = gridDim.x * blockDim.x; } + + __syncthreads(); + + const auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + if(tid >= zLen) + return; + + auto coords = sharedMem + threadIdx.x * rank; + + shape::index2coords(rank, zShapeInfo + 1, tid, zLen, coords); + + const auto zOffset = shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank); + + int inArrIdx = 0; + Nd4jLong *xShapeInfo = reinterpret_cast(pxShapeInfo)[inArrIdx]; + + while(coords[axis] >= xShapeInfo[axis + 1]) { + coords[axis] -= xShapeInfo[axis + 1]; + xShapeInfo = reinterpret_cast(pxShapeInfo)[++inArrIdx]; + } + + const auto* x = reinterpret_cast(reinterpret_cast(pVx)[inArrIdx]); + const auto xOffset = shape::getOffset(0, xShapeInfo + 1, xShapeInfo + rank + 1, coords, rank); + + z[zOffset] = x[xOffset]; +} + +/////////////////////////////////////////////////////////////////// +template +__host__ static void concatCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, + void* pVx, void* pxShapeInfo, void* vz, Nd4jLong* zShapeInfo, const int axis) { + + concatCuda<<>>(pVx, pxShapeInfo, vz, zShapeInfo, axis); +} +BUILD_SINGLE_TEMPLATE(template void concatCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, void* pVx, void* pxShapeInfo, void* vz, Nd4jLong* zShapeInfo, const int axis), LIBND4J_TYPES); + +////////////////////////////////////////////////////////////////////////// +void concat(nd4j::LaunchContext * context, const std::vector& inArrs, NDArray& output, const int axis) { + + const int threadsPerBlock = MAX_NUM_THREADS / 4; + const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = threadsPerBlock * sizeof(Nd4jLong) * output.rankOf() + 128; + + const int numOfArrs = inArrs.size(); + + for(int i = 0; i < numOfArrs; ++i) + inArrs[i]->syncToDevice(); + + output.syncToDevice(); + + // prepare arrays of pointers on buffers and shapes + std::vector hInBuffers(numOfArrs); + std::vector hInShapeInfo(numOfArrs); + + for(int i = 0; i < numOfArrs; ++i) { + hInBuffers[i] = inArrs[i]->getSpecialBuffer(); + hInShapeInfo[i] = inArrs[i]->getSpecialShapeInfo(); + } + + PointersManager manager(context, "helpers::concat"); + + void* dInBuffers = manager.replicatePointer(hInBuffers.data(), hInBuffers.size() * sizeof(void*)); + void* dInShapeInfo = manager.replicatePointer(hInShapeInfo.data(), hInShapeInfo.size() * sizeof(Nd4jLong*)); + + BUILD_SINGLE_SELECTOR(inArrs[0]->dataType(), concatCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), dInBuffers, dInShapeInfo, output.specialBuffer(), output.specialShapeInfo(), axis), LIBND4J_TYPES); + + manager.synchronize(); + + for(int i = 0; i < numOfArrs; ++i) + inArrs[i]->tickReadDevice(); + + output.tickWriteDevice(); +} + +} +} } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/histogram.cu b/libnd4j/include/ops/declarable/helpers/cuda/histogram.cu index eda19ccd8..52b059dad 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/histogram.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/histogram.cu @@ -25,7 +25,7 @@ namespace nd4j { namespace ops { namespace helpers { template - void _CUDA_G histogramKernel(void *xBuffer, Nd4jLong *xShapeInfo, void *zBuffer, Nd4jLong *zShapeInfo, void *allocationPointer, void *reductionPointer, Nd4jLong numBins, double min_val, double max_val) { + void _CUDA_G histogramKernel(void *xBuffer, Nd4jLong *xShapeInfo, void *zBuffer, Nd4jLong *zShapeInfo, void *allocationPointer, void *reductionPointer, Nd4jLong numBins, X* min_val, X* max_val) { int tid = blockIdx.x * blockDim.x + threadIdx.x; auto dx = reinterpret_cast(xBuffer); auto result = reinterpret_cast(zBuffer); @@ -42,19 +42,19 @@ namespace nd4j { } __syncthreads(); - Z binSize = (max_val - min_val) / (numBins); + X binSize = X((*max_val - *min_val) / numBins); for (int e = threadIdx.x; e < numBins; e += blockDim.x) { - bins[e] = (Z) 0.0f; + bins[e] = (Z) 0; } __syncthreads(); - for (int e = tid; e < length; e+= blockDim.x * gridDim.x) { - int idx = (int) ((dx[e] - min_val) / binSize); - if (idx < 0) idx = 0; - else if (idx >= numBins) idx = numBins - 1; - - nd4j::math::atomics::nd4j_atomicAdd(&bins[idx], (Z) 1.0f); + for (int e = tid; e < length; e += blockDim.x * gridDim.x) { + int idx = int((dx[e] - *min_val) / binSize); + idx = math::nd4j_max(idx, 0); //atomicMax(&idx, 0);//atomicMax(&idx, 0); + idx = math::nd4j_min(idx, int(numBins - 1)); //atomicMin(&idx, int(numBins - 1)); + nd4j::math::atomics::nd4j_atomicAdd(&bins[idx], (Z)1); +// bins[idx]++; } __syncthreads(); @@ -82,7 +82,7 @@ namespace nd4j { // nullify shared memory for future accumulation for (int e = threadIdx.x; e < numBins; e += blockDim.x) { - bins[e] = (Z) 0.0f; + bins[e] = (Z) 0; } // accumulate reduced bins @@ -90,7 +90,7 @@ namespace nd4j { Z *ptrBuf = ((Z *)allocationPointer) + (r * numBins); for (int e = threadIdx.x; e < numBins; e += blockDim.x) { - bins[e] += ptrBuf[e]; + math::atomics::nd4j_atomicAdd(&bins[e], ptrBuf[e]); } } __syncthreads(); @@ -109,24 +109,26 @@ namespace nd4j { } template - static void histogram_(nd4j::LaunchContext *context, void *xBuffer, Nd4jLong *xShapeInfo, void *zBuffer, Nd4jLong *zShapeInfo, Nd4jLong numBins, double min_val, double max_val) { + static void histogram_(nd4j::LaunchContext *context, void *xBuffer, Nd4jLong *xShapeInfo, Nd4jLong *dxShapeInfo, void *zBuffer, Nd4jLong *zShapeInfo, Nd4jLong numBins, void* min_val, void* max_val) { int numThreads = 256; int numBlocks = nd4j::math::nd4j_max(256, nd4j::math::nd4j_min(1, shape::length(xShapeInfo) / numThreads)); int workspaceSize = numBlocks * numBins; - auto tmp = NDArrayFactory::create('c',{workspaceSize}); + auto tmp = NDArrayFactory::create('c', {workspaceSize}); - histogramKernel<<getCudaStream()>>>(xBuffer, xShapeInfo, zBuffer, zShapeInfo, tmp.getSpecialBuffer(), context->getReductionPointer(), numBins, min_val, max_val); + histogramKernel<<getCudaStream()>>>(xBuffer, dxShapeInfo, zBuffer, zShapeInfo, tmp.getSpecialBuffer(), context->getReductionPointer(), numBins, reinterpret_cast(min_val), reinterpret_cast(max_val)); cudaStreamSynchronize(*context->getCudaStream()); } void histogramHelper(nd4j::LaunchContext *context, NDArray &input, NDArray &output) { Nd4jLong numBins = output.lengthOf(); - double min_val = input.reduceNumber(reduce::SameOps::Min).e(0); - double max_val = input.reduceNumber(reduce::SameOps::Max).e(0); - - BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), histogram_, (context, input.specialBuffer(), input.specialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), numBins, min_val, max_val), LIBND4J_TYPES, INDEXING_TYPES); + NDArray::registerSpecialUse({&output}, {&input}); + auto min_val = input.reduceNumber(reduce::SameOps::Min); + auto max_val = input.reduceNumber(reduce::SameOps::Max); +// min_val.printIndexedBuffer("MIN"); +// max_val.printIndexedBuffer("MAX"); + BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), histogram_, (context, input.specialBuffer(), input.shapeInfo(), input.specialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), numBins, min_val.specialBuffer(), max_val.specialBuffer()), LIBND4J_TYPES, INTEGER_TYPES); NDArray::registerSpecialUse({&output}, {&input}); } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/image_suppression.cu b/libnd4j/include/ops/declarable/helpers/cuda/image_suppression.cu index 2cec0a065..0da1fbc28 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/image_suppression.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/image_suppression.cu @@ -68,21 +68,21 @@ namespace helpers { static __global__ void shouldSelectKernel(T* boxesBuf, Nd4jLong* boxesShape, I* indexBuf, I* selectedIndicesData, double threshold, int numSelected, int i, bool* shouldSelect) { auto tid = blockIdx.x * blockDim.x + threadIdx.x; auto step = gridDim.x * blockDim.x; - __shared__ bool shouldSelectShared; + __shared__ unsigned int shouldSelectShared; if (threadIdx.x == 0) { - shouldSelectShared = shouldSelect[0]; + shouldSelectShared = (unsigned int)shouldSelect[0]; } __syncthreads(); for (int j = numSelected - 1 - tid; j >= 0; j -= step) { if (shouldSelectShared) { if (needToSuppressWithThreshold(boxesBuf, boxesShape, indexBuf[i], indexBuf[selectedIndicesData[j]], T(threshold))) - shouldSelectShared = false; + atomicCAS(&shouldSelectShared, 1, 0); } } __syncthreads(); if (threadIdx.x == 0) { - *shouldSelect = shouldSelectShared; + *shouldSelect = shouldSelectShared > 0; } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/ismax.cu b/libnd4j/include/ops/declarable/helpers/cuda/ismax.cu index b43a29bd8..26ed9780c 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/ismax.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/ismax.cu @@ -34,11 +34,6 @@ namespace helpers { template static void ismax_(nd4j::LaunchContext * context, const NDArray* input, NDArray* output, const std::vector& dimensions) { - void* extraParams = nullptr; - bool scalarCheat = false; - if (extraParams == nullptr) { - scalarCheat = true; - } auto stream = context->getCudaStream(); auto xRank = input->rankOf(); @@ -49,29 +44,16 @@ static void ismax_(nd4j::LaunchContext * context, const NDArray* input, NDArray* Nd4jLong* special = nullptr; PointersManager manager(context, "IsMaxHelper"); if (dimensions.size() == 0) { -// auto scalarShape = ShapeBuilders::createScalarShapeInfo(nd4j::DataType::INT64); /** - * In case of vector-input for IsMax, it just turns into IndexReduce call + further filler call + * In case of vector-input for IsMax, it just turns into IndexReduce call + subsequent filler call */ auto indexMax = input->applyIndexReduce(indexreduce::IndexMax, dimensions); - //NativeOpExecutioner::execIndexReduceScalar(context, indexreduce::IndexMax, nullptr, input->getShapeInfo(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), extraParams, nullptr, scalarShape, special, nullptr); - //Nd4jLong maxIdx = -119; - //checkCudaErrors(cudaStreamSynchronize(*stream)); - //cudaMemcpyAsync(&maxIdx, special, sizeof(Nd4jLong), cudaMemcpyDeviceToHost, *stream); - //checkCudaErrors(cudaStreamSynchronize(*stream)); - int targetIdx = 0; + auto targetIdx = indexMax->e(0); - if (input->ordering() == 'c' || input->ordering() == 'f' && indexMax->e(0) * shape::stride(input->getShapeInfo())[input->rankOf() - 1] >= input->lengthOf()) - targetIdx = indexMax->e(0); - else - targetIdx = indexMax->e(0) * shape::stride(input->getShapeInfo())[input->rankOf() - 1]; + dim3 launchDims(128, 512, 1024); + BUILD_SINGLE_SELECTOR(zType, fillIsMaxGeneric, (launchDims, stream, output->specialBuffer(), output->specialShapeInfo(), output->lengthOf(), targetIdx), LIBND4J_TYPES); + manager.synchronize(); - dim3 launchDims(1, 512, 1024); - BUILD_SINGLE_SELECTOR(zType, fillIsMaxGeneric, (launchDims, stream, output->specialBuffer(), output->lengthOf(), targetIdx), LIBND4J_TYPES); - - nd4j::DebugHelper::checkErrorCode(stream, "Legacy IsMax(...) failed"); - - //delete[] scalarShape; delete indexMax; } else { Nd4jLong* hostYShapeInfo = nullptr; @@ -82,13 +64,7 @@ static void ismax_(nd4j::LaunchContext * context, const NDArray* input, NDArray* auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), copy.data(), copy.size()); - auto indexMaxArr = input->applyIndexReduce(indexreduce::IndexMax, dimensions); - //indexMaxArr->printIndexedBuffer("Index max!!!"); - // we call for IMax on specified dimension - //NativeOpExecutioner::execIndexReduce(context, indexreduce::IndexMax, nullptr, input->getShapeInfo(), input->getSpecialBuffer(), input->getSpecialShapeInfo(), extraParams, nullptr, hostTShapeInfo, special, hostYShapeInfo, const_cast(dimensions.data()), (int)dimensions.size(), nullptr, nullptr); - - //DEBUG_KERNEL(stream, opNum); dim3 launchDims(256, 256, 16384); dimension = (int *) manager.replicatePointer(dimensions.data(), dimensions.size() * sizeof(int)); @@ -103,7 +79,11 @@ static void ismax_(nd4j::LaunchContext * context, const NDArray* input, NDArray* void ismax(nd4j::LaunchContext * context, const NDArray *input, NDArray *output, const std::vector& dimensions) { + NDArray::prepareSpecialUse({output}, {input}); + BUILD_SINGLE_SELECTOR(input->dataType(), ismax_, (context, input, output, dimensions), LIBND4J_TYPES); + + NDArray::registerSpecialUse({output}, {input}); } BUILD_SINGLE_TEMPLATE(template void ismax_, (nd4j::LaunchContext * context, const NDArray *input, NDArray *output, const std::vector& dimensions), LIBND4J_TYPES); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/merge.cu b/libnd4j/include/ops/declarable/helpers/cuda/merge.cu index ceb748453..082472fce 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/merge.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/merge.cu @@ -48,8 +48,10 @@ namespace nd4j { auto x = reinterpret_cast(inArrs[i]); auto xShape = reinterpret_cast(inShapes[i]); auto val = x[shape::getIndexOffset(e, xShape, length)];; - if (mVal < val) - mIdx = static_cast(e); + if (mVal < val) { + mIdx = static_cast(i); + mVal = val; + } } __syncthreads(); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/roll.cu b/libnd4j/include/ops/declarable/helpers/cuda/roll.cu index 6bdd87650..216c6b7a0 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/roll.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/roll.cu @@ -228,22 +228,23 @@ namespace helpers { } template - static void rollFunctorFull_(NDArray* input, NDArray* output, int shift, std::vector const& axis, bool inplace){ + static void rollFunctorFull_(NDArray* input, NDArray* output, std::vector const& shifts, std::vector const& axes, bool inplace){ if (!inplace) output->assign(input); - for (int axe: axis) { + for (size_t i = 0; i < axes.size(); i++) { + int axe = axes[i]; if (axe == input->rankOf() - 1) { // last dimension std::unique_ptr listOfTensors(output->allTensorsAlongDimension({axe})); std::unique_ptr listOfOutTensors(output->allTensorsAlongDimension({axe})); int fullLen = listOfTensors->size(); - int theShift = shift; - if (theShift > 0) { - theShift %= fullLen; - } - else { - theShift -= fullLen * (theShift / fullLen - 1); - } + int theShift = shifts[i]; +// if (theShift > 0) { +// theShift %= fullLen; +// } +// else { +// theShift -= fullLen * (theShift / fullLen - 1); +// } for (int k = 0; k < fullLen; k++) { rollFunctorLinear(output->getContext(), listOfTensors->at(k), listOfOutTensors->at(k), theShift, true); } @@ -258,12 +259,12 @@ namespace helpers { int sizeAt = input->sizeAt(axe); auto tadLength = shape::length(packZ.primaryShapeInfo()); - int theShift = shift; + int theShift = shifts[i]; - if (theShift > 0) - theShift %= sizeAt; - else - theShift -= sizeAt * (theShift / sizeAt - 1); +// if (theShift > 0) +// theShift %= sizeAt; +// else +// theShift -= sizeAt * (theShift / sizeAt - 1); if (theShift) { for (int dim = 0; dim < numTads / sizeAt; ++dim) { @@ -307,10 +308,10 @@ namespace helpers { } } - void rollFunctorFull(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int shift, std::vector const& axis, bool inplace){ + void rollFunctorFull(nd4j::LaunchContext * context, NDArray* input, NDArray* output, std::vector const& shifts, std::vector const& axes, bool inplace){ input->syncToDevice(); - BUILD_SINGLE_SELECTOR(input->dataType(), rollFunctorFull_, (input, output, shift, axis, inplace), LIBND4J_TYPES); + BUILD_SINGLE_SELECTOR(input->dataType(), rollFunctorFull_, (input, output, shifts, axes, inplace), LIBND4J_TYPES); output->tickWriteDevice(); } @@ -324,7 +325,7 @@ namespace helpers { } BUILD_SINGLE_TEMPLATE(template void rollFunctorLinear_, (NDArray* input, NDArray* output, int shift, bool inplace), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void rollFunctorFull_, (NDArray* input, NDArray* output, int shift, std::vector const& axis, bool inplace), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template void rollFunctorFull_, (NDArray* input, NDArray* output, std::vector const& shifts, std::vector const& axes, bool inplace), LIBND4J_TYPES); } } } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/sg_cb.cu b/libnd4j/include/ops/declarable/helpers/cuda/sg_cb.cu index da212d287..6a9fd28e6 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/sg_cb.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/sg_cb.cu @@ -123,14 +123,236 @@ namespace nd4j { nSamplingKernel<<<1,1,128, *stream>>>(vsyn0, vsyn1Neg, vexpTable, vneu1e, alpha, vectorLength, code, expLength, isInference); } + /* + * binarySearch - find element in haystack buffer (haystack - sorted device memory) + * */ int binarySearch(const int *haystack, const int needle, const int totalElements) { - return 0; + int firstIndex = 0; + int lastIndex = totalElements - 1; + int halfIndex = nd4j::math::nd4j_floor((lastIndex + firstIndex) / (float) 2); + + while(haystack[halfIndex] != needle && firstIndex < lastIndex) { + if (needle < haystack[halfIndex]) { + lastIndex = halfIndex - 1; + } else if (needle > haystack[halfIndex]) { + firstIndex = halfIndex + 1; + } + halfIndex = nd4j::math::nd4j_floor((lastIndex + firstIndex) / (float) 2); + } + + return (haystack[halfIndex] == needle) ? halfIndex : -1; + } + template + __global__ void addInfVectorKernel(T* neu1, T* infVector, int vectorLength) { + auto start = blockIdx.x * blockDim.x + threadIdx.x; + auto step = blockDim.x * gridDim.x; + + for (auto i = start; i < vectorLength; i += step) { + neu1[i] += infVector[i]; + } } - void skipgram(NDArray &syn0, NDArray &syn1, NDArray &syn1Neg, NDArray &expTable, NDArray &negTable, NDArray &target, NDArray &ngStarter, int nsRounds, NDArray &indices, NDArray &codes, NDArray &alpha, NDArray &randomValue, NDArray &inferenceVector, const bool preciseMode, const int numWorkers) { + template + void skipgram_(NDArray& s0, NDArray& s1, NDArray& s1n, NDArray& expTableV, NDArray& negTableV, NDArray& infV, int target, int ngStarter, NDArray& indices, NDArray& codes, double alpha, Nd4jLong randomValue, const int hsRounds, const int nsRounds) { +// void *vsyn0, void *vsyn1, void *vsyn1Neg, void *vexpTable, void *vnegTable, void *vinfVector, int target, int ngStarter, int *indices, int8_t *codes, double alpha, Nd4jLong randomValue, const int hsRounds, const int nsRounds, const int vocabSize, const int vectorLength, const int expLength, const int negLength) { + auto syn0 = reinterpret_cast(s0.specialBuffer()); + auto syn1 = reinterpret_cast(s1.specialBuffer()); + auto syn1Neg = reinterpret_cast(s1n.specialBuffer()); + auto expTable = reinterpret_cast(expTableV.specialBuffer()); + auto negTable = reinterpret_cast(negTableV.specialBuffer()); + auto infVector = reinterpret_cast(infV.specialBuffer()); + const int vocabSize = s0.sizeAt(0); + const int vectorLength = s0.sizeAt(1); + const int expLength = expTableV.lengthOf(); + const int negLength = negTableV.lengthOf(); + indices.tickReadDevice(); + indices.syncToHost(); + codes.tickReadDevice(); + codes.syncToHost(); + auto stream = s0.getContext()->getCudaStream(); + + T* neu1e; // = new T[vectorLength]; + //memset(neu1e, 0, vectorLength * sizeof(T)); + auto err = cudaMalloc(&neu1e, sizeof(T) * vectorLength); + err = cudaMemset(neu1e, 0, sizeof(T) * vectorLength); + // hierarchic softmax goes first (if enabled) + + auto syn0row = infVector != nullptr ? infVector : syn0 + (target * vectorLength); + auto irow = 0; + if (hsRounds > 0) { + for (int r = 0; r < hsRounds; r++) { + irow = indices.t(r); + if (irow < 0 || irow >= vocabSize) + break; + + hSoftmax_(syn0row, syn1 + (irow * vectorLength), expTable, neu1e, alpha, vectorLength, codes.t(r), expLength, infVector != nullptr, stream); + } + } + + // negative sampling goes second (if enabled) + auto nsStarter = ngStarter; + irow = nsStarter; + if (nsRounds > 0) { + for (int r = 0; r < nsRounds + 1; r++) { + if (r == 0) { + // target is known in advance + } else { + randomValue = randomValue * (unsigned long long) 25214903917 + 11; + auto idx = nd4j::math::nd4j_abs((randomValue >> 16) % negLength); + irow = idx >= negLength ? -1 : negTableV.e(idx); + + if (irow < 0 || irow >= vocabSize) irow = randomValue % (vocabSize - 1) + 1; + if (irow == nsStarter) + continue; + } + + nSampling_(syn0row, syn1Neg + (irow * vectorLength), expTable, neu1e, alpha, vectorLength, r == 0 ? 1 : 0, expLength, infVector != nullptr, stream); + } + } + + if (infVector == nullptr) { + addInfVectorKernel<<<128, 256, 256, *stream>>>(syn0row, neu1e, vectorLength); + } else { + addInfVectorKernel<<<128, 256, 256, *stream>>>(infVector, neu1e, vectorLength); + } + + err = cudaFree(neu1e); + if (0 != err) { + throw cuda_exception::build("helpers::skipgram_: Cannot deallocate temp memory for lingual net", err); + } + } + BUILD_SINGLE_TEMPLATE(template void skipgram_, (NDArray& syn0, NDArray& syn1, NDArray& syn1Neg, NDArray& expTable, NDArray& negTable, NDArray& infVector, int target, int ngStarter, NDArray& indices, NDArray& codes, double alpha, Nd4jLong randomValue, const int hsRounds, const int nsRounds), FLOAT_TYPES); + + /* + * batched version of skipgram routine + * */ + template + void skipgramBatchExec_(NDArray &s0, NDArray &s1, NDArray &s1n, NDArray& expTableV, NDArray& negTableV, NDArray &targets, NDArray &negStarters, NDArray &indices, NDArray &codes, NDArray &lr, NDArray &nextRandom, const int nsRounds, const bool preciseMode, const int numThreads) { +// (NDArray &s0, NDArray &s1, NDArray &s1n, NDArray& expTable, NDArray& negTable, NDArray& infVector, NDArray& targets, NDArray& negStarters, NDArray& indices, NDArray& codes, NDArray& lr, NDArray& nextRandom, const int nsRounds, const bool preciseMode, const int numThreads) { + //auto syn0 = reinterpret_cast(vsyn0); + //auto syn1 = reinterpret_cast(vsyn1); + //auto syn1Neg = reinterpret_cast(vsyn1Neg); + auto stream = s0.getContext()->getCudaStream(); + negTableV.tickReadDevice(); + negTableV.syncToHost(); + const auto expTable = reinterpret_cast(expTableV.specialBuffer()); + const auto negTable = reinterpret_cast(negTableV.buffer()); + const auto infVector = (T*)nullptr; //reinterpret_cast(infVector.specialBuffer()); + + const int vocabSize = s0.sizeAt(0); + const int vectorLength = s0.sizeAt(1); + const int expLength = expTableV.lengthOf(); + const int negLength = negTableV.lengthOf(); + + //T sneu1e[600]; + + //const auto numThreads = omp_get_max_threads(); + const auto idxShift = indices.isEmpty() ? 0 : indices.sizeAt(1); + const auto hsRounds = codes.isEmpty() ? 0 : codes.sizeAt(1); + + // regular mode provides 0 guarantees for reproducibility + auto numTargets = targets.lengthOf(); + targets.syncToHost(); + indices.syncToHost(); + codes.syncToHost(); + lr.syncToHost(); + nextRandom.syncToHost(); + negStarters.tickReadDevice(); + negStarters.syncToHost(); + auto bTarget = reinterpret_cast(targets.buffer()); //targets.bufferAsT(); + auto bIndices = reinterpret_cast(indices.buffer()); //indices.bufferAsT(); + auto bCodes = reinterpret_cast(codes.buffer()); //codes.bufferAsT(); + +// PRAGMA_OMP_PARALLEL_FOR_ARGS(num_threads(numThreads)) + for (int t = 0; t < numTargets; t++) { + T* neu1e;//lvectorLength <= 600 ? sneu1e : new T[vectorLength]; + auto err = cudaMalloc(&neu1e, vectorLength * sizeof(T)); + err = cudaMemset(neu1e, 0, vectorLength * sizeof(T)); + //memset(neu1e, 0, vectorLength * sizeof(T)); + + auto target = bTarget[t]; + auto alpha = lr.e(t); + unsigned long long randomValue = nextRandom.e(t); + + auto syn0row = reinterpret_cast(s0.specialBuffer()) + (target * vectorLength); + + if (hsRounds > 0) { + int irow = 0; + auto cShift = t * idxShift; + + for (int e = 0; e < hsRounds; e++) { + irow = bIndices[e + cShift]; + if (irow < 0 || irow >= vocabSize) + continue; + + auto syn1row = reinterpret_cast(s1.getSpecialBuffer()) + (irow * vectorLength); + auto code = bCodes[e + cShift]; + + //nd4j_printf("syn0: [%i]; syn1: [%i]; code: [%i]\n", target, irow, code); + hSoftmax_(syn0row, syn1row, expTable, neu1e, alpha, vectorLength, code, expLength, false, stream); + } + } + + + if (nsRounds > 0) { + int irow = negStarters.e(t); + int nsStarter = irow; + for (int r = 0; r < nsRounds + 1; r++) { + if (r == 0) { + // target is known in advance + } else { + randomValue = randomValue * (unsigned long long) 25214903917 + 11; + auto idx = nd4j::math::nd4j_abs((randomValue >> 16) % negLength); + irow = idx >= negLength ? -1 : static_cast(negTable[idx]); + + if (irow < 0 || irow >= vocabSize) + irow = randomValue % (vocabSize - 1) + 1; + + if (irow == nsStarter) + continue; + } + auto syn1row = reinterpret_cast(s1n.getSpecialBuffer()) + (irow * vectorLength); + + nSampling_(syn0row, syn1row, expTable, neu1e, alpha, vectorLength, r == 0 ? 1 : 0, expLength, false, stream); + } + } + addInfVectorKernel<<<128, 256, 256, *stream>>>(syn0row, neu1e, vectorLength); + + // optionally release temp arrays + err = cudaFree(neu1e); + if (err != 0) { + break; + } +// if (vectorLength > 600) +// delete[] neu1e; + } + } + BUILD_SINGLE_TEMPLATE(template void skipgramBatchExec_, (NDArray &s0, NDArray &s1, NDArray &s1n, NDArray& expTable, NDArray& negTable, NDArray &targets, NDArray &negStarters, NDArray &indices, NDArray &codes, NDArray &lr, NDArray &nextRandom, const int nsRounds, const bool preciseMode, const int numThreads), FLOAT_TYPES); + + void skipgram(NDArray &syn0, NDArray &syn1, NDArray &syn1Neg, NDArray &expTable, NDArray &negTable, + NDArray &target, NDArray &ngStarter, int nsRounds, NDArray &indices, NDArray &codes, NDArray &alpha, NDArray &randomValue, NDArray &inferenceVector, const bool preciseMode, const int numWorkers) { auto xType = syn0.dataType(); - + // single round case + if ((ngStarter.isScalar() && !ngStarter.isEmpty())|| (target.isScalar() && !target.isEmpty())) { + auto hsRounds = codes.lengthOf(); + target.syncToHost(); + ngStarter.syncToHost(); + alpha.syncToHost(); + randomValue.syncToHost(); + + auto targetV = target.isEmpty() ? -1 : target.e(0); + auto starterV = ngStarter.isEmpty() ? -1 : ngStarter.e(0); + auto alphaV = alpha.e(0); + auto randomV = randomValue.e(0); + BUILD_SINGLE_SELECTOR(xType, skipgram_, (syn0, syn1, syn1Neg, expTable, negTable, inferenceVector, targetV, starterV, indices, codes, alphaV, randomV, hsRounds, nsRounds), FLOAT_TYPES); + } else if (ngStarter.isVector() || target.isVector()){ + // batch mode +// NDArray& infVector, NDArray &targets, NDArray &negStarters, NDArray &indices, NDArray &codes, NDArray &lr, NDArray &nextRandom, const int nsRounds, const bool preciseMode, const int numThreads) + BUILD_SINGLE_SELECTOR(xType, skipgramBatchExec_, (syn0, syn1, syn1Neg, expTable, negTable, target, ngStarter, indices, codes, alpha, randomValue, nsRounds, preciseMode, numWorkers), FLOAT_TYPES); + } else + throw std::runtime_error("SkipGram: target must have rank 0 or 1"); } + template static __global__ void checkContextKernel(int* context, T* syn0, T* neu1, int contextWidth, int vectorLength, int vocabSize) { __shared__ bool hasError; @@ -157,16 +379,6 @@ namespace nd4j { } } - template - __global__ void addInfVectorKernel(T* neu1, T* infVector, int vectorLength) { - auto start = blockIdx.x * blockDim.x + threadIdx.x; - auto step = blockDim.x * gridDim.x; - - for (auto i = start; i < vectorLength; i += step) { - neu1[i] += infVector[i]; - } - } - template __global__ void shiftKernel(T* neu1, T* infVector, int contextWidth, int vectorLength) { auto start = blockIdx.x * blockDim.x + threadIdx.x; diff --git a/libnd4j/include/ops/declarable/helpers/cuda/shift.cu b/libnd4j/include/ops/declarable/helpers/cuda/shift.cu new file mode 100644 index 000000000..49d388b2a --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/shift.cu @@ -0,0 +1,81 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include + +namespace nd4j { + namespace ops { + namespace helpers { + template + void rshift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) { + auto lambda = LAMBDA_T(x, shift) { + return x >> shift; + }; + + input.applyLambda(lambda, &output); + } + + void rshift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { + BUILD_SINGLE_SELECTOR(x.dataType(), rshift_bits_, (launchContext, x, z, shift), INTEGER_TYPES); + } + + template + void shift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) { + auto lambda = LAMBDA_T(x, shift) { + return x << shift; + }; + + input.applyLambda(lambda, &output); + } + + void shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { + BUILD_SINGLE_SELECTOR(x.dataType(), shift_bits_, (launchContext, x, z, shift), INTEGER_TYPES); + } + + template + void cyclic_rshift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) { + auto step = (sizeof(T) * 8) - shift; + auto lambda = LAMBDA_T(x, shift, step) { + return x >> shift | x << step; + }; + + input.applyLambda(lambda, &output); + } + + void cyclic_rshift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { + BUILD_SINGLE_SELECTOR(x.dataType(), cyclic_rshift_bits_, (launchContext, x, z, shift), INTEGER_TYPES); + } + + template + void cyclic_shift_bits_(LaunchContext* launchContext, NDArray &input, NDArray &output, uint32_t shift) { + auto step = (sizeof(T) * 8) - shift; + auto lambda = LAMBDA_T(x, shift, step) { + return x << shift | x >> step; + }; + + input.applyLambda(lambda, &output); + } + + void cyclic_shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift) { + BUILD_SINGLE_SELECTOR(x.dataType(), cyclic_shift_bits_, (launchContext, x, z, shift), INTEGER_TYPES); + } + } + } +} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/toggle_bits.cu b/libnd4j/include/ops/declarable/helpers/cuda/toggle_bits.cu index c8db7b0e1..f90c9f77f 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/toggle_bits.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/toggle_bits.cu @@ -26,7 +26,13 @@ namespace nd4j { namespace helpers { template void toggle_bits__(NDArray &in, NDArray &out) { + NDArray::prepareSpecialUse({&out}, {&in}); + auto lambda = LAMBDA_T(_x) { + return ~_x;//eUtils::flip_bits(_x); + }; + in.applyLambda(lambda, &out); + NDArray::registerSpecialUse({&out}, {&in}); } BUILD_SINGLE_TEMPLATE(template void toggle_bits__, (NDArray &in, NDArray &out), INTEGER_TYPES); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu b/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu index bb311ed01..5406e8bbd 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu @@ -685,13 +685,12 @@ void clipByNormBP(nd4j::LaunchContext* context, const NDArray& input, const NDAr BUILD_SINGLE_TEMPLATE(template void randomShuffle_, (nd4j::LaunchContext * context, NDArray& input, NDArray& output, nd4j::graph::RandomGenerator& rng, const bool isInplace), LIBND4J_TYPES); - ////////////////////////////////////////////////////////////////////////// void eye(nd4j::LaunchContext * context, NDArray& output) { + output.setIdentity(); } - //////////////////////////////////////////////////////////////////////////////////////////////////////////////////// template static __global__ void clipByNormInplaceKernel(Nd4jLong numOfSubArrs, T* inputBuffer, Nd4jLong* shape, Nd4jLong* inputOffsets, T* norm2Buf, Nd4jLong* norm2shape, T clipNorm) { @@ -807,7 +806,6 @@ void clipByNormBP(nd4j::LaunchContext* context, const NDArray& input, const NDAr void clipByGlobalNorm_(nd4j::LaunchContext * context, std::vector const& inputs, double clipNorm, nd4j::memory::Workspace* workspace, std::vector& outputs, bool isInplace) { NDArray globalNorm = NDArrayFactory::create(0, inputs[0]->getContext()); //sqrt(sum([l2norm(t)**2 for t in t_list])) - PRAGMA_OMP_PARALLEL_FOR for (auto i = 0; i < inputs.size(); i++) { auto input = inputs[i]; auto l2norm = input->reduceNumber(reduce::Norm2); @@ -819,7 +817,6 @@ void clipByNormBP(nd4j::LaunchContext* context, const NDArray& input, const NDAr globalNorm.syncToHost(); const T factor = clipNorm / globalNorm.e(0); - PRAGMA_OMP_PARALLEL_FOR for (size_t e = 0; e < inputs.size(); e++) { // all-reduce auto input = inputs[e]; diff --git a/libnd4j/include/ops/declarable/helpers/roll.h b/libnd4j/include/ops/declarable/helpers/roll.h index ff6c67a57..b20367c0d 100644 --- a/libnd4j/include/ops/declarable/helpers/roll.h +++ b/libnd4j/include/ops/declarable/helpers/roll.h @@ -26,7 +26,7 @@ namespace ops { namespace helpers { void rollFunctorLinear(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int shift, bool inplace = false); - void rollFunctorFull(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int shift, std::vector const& axes, bool inplace = false); + void rollFunctorFull(nd4j::LaunchContext * context, NDArray* input, NDArray* output, std::vector const& shifts, std::vector const& axes, bool inplace = false); } } } diff --git a/libnd4j/include/ops/declarable/helpers/shift.h b/libnd4j/include/ops/declarable/helpers/shift.h new file mode 100644 index 000000000..e07a0e992 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/shift.h @@ -0,0 +1,42 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef DEV_TESTS_SHIFT_H +#define DEV_TESTS_SHIFT_H + +#include +#include +#include + +namespace nd4j { + namespace ops { + namespace helpers { + void rshift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift); + + void shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift); + + void cyclic_rshift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift); + + void cyclic_shift_bits(LaunchContext* launchContext, NDArray &x, NDArray &z, uint32_t shift); + } + } +} + +#endif //DEV_TESTS_SHIFT_H diff --git a/libnd4j/tests_cpu/layers_tests/DataTypesValidationTests.cpp b/libnd4j/tests_cpu/layers_tests/DataTypesValidationTests.cpp index 501a29b8c..9de87b584 100644 --- a/libnd4j/tests_cpu/layers_tests/DataTypesValidationTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/DataTypesValidationTests.cpp @@ -113,6 +113,14 @@ TEST_F(DataTypesValidationTests, test_bfloat16_rand_1) { ASSERT_TRUE(x.sumNumber().e(0) > 0); } +TEST_F(DataTypesValidationTests, test_bfloat16_rand_2) { + auto x = NDArrayFactory::create('c', {5, 10}); + RandomGenerator gen(119, 120); + RandomLauncher::fillGaussian(LaunchContext::defaultContext(), gen, &x, 0, 1); + + ASSERT_TRUE(x.sumNumber().e(0) > 0); +} + TEST_F(DataTypesValidationTests, cast_1) { float16 x = static_cast(1.f); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp index 6fe3dfac6..7fbc309d5 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests12.cpp @@ -750,59 +750,6 @@ TEST_F(DeclarableOpsTests12, tensormmul_6) { } -//////////////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests12, concat_test10) { - - NDArray x0('c', {1,4,5}, nd4j::DataType::FLOAT32); - NDArray x1('c', {2,4,5}, nd4j::DataType::FLOAT32); - NDArray z('f', {3,4,5}, nd4j::DataType::FLOAT32); - - x0 = 0.; - x1 = 1.; - - nd4j::ops::concat op; - auto status = op.execute({&x0, &x1}, {&z}, {}, {0}, {}); - ASSERT_EQ(ND4J_STATUS_OK, status); -} - -//////////////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests12, concat_14) { - - NDArray x0('c', {1,6}, {1,2,3,4,5,6}); - NDArray x1('c', {1,6}, {7,8,9,10,11,12}); - NDArray output('f', {2,6}, nd4j::DataType::DOUBLE); - NDArray exp('c', {2,6}, {1,2,3,4,5,6,7,8,9,10,11,12}); - - nd4j::ops::concat op; - - auto status = op.execute({&x0, &x1}, {&output}, {}, {0}, {}); - ASSERT_EQ(ND4J_STATUS_OK, status); - // output.printBuffer(); - // output.printIndexedBuffer(); - - ASSERT_TRUE(exp.equalsTo(output)); -} - - -//////////////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests12, concat_15) { - - NDArray x0('c', {1,4}, {1,2,3,4}); - NDArray x1('c', {1,4}, {5,6,7,8}); - NDArray output('c', {2,4}, nd4j::DataType::DOUBLE); - NDArray exp('c', {2,4}, {1,2,3,4,5,6,7,8}); - - nd4j::ops::concat op; - - auto status = op.execute({&x0, &x1}, {&output}, {}, {0}, {}); - ASSERT_EQ(ND4J_STATUS_OK, status); - // output.printBuffer(); - // output.printIndexedBuffer(); - - ASSERT_TRUE(exp.equalsTo(output)); -} - - //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests12, reduceMeanBp_4) { diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp index 81e441477..014719270 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests13.cpp @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2015-2019 Skymind, Inc. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -33,8 +33,8 @@ class DeclarableOpsTests13 : public testing::Test { public: DeclarableOpsTests13() { - printf("\n"); - fflush(stdout); + //printf("\n"); + //fflush(stdout); } }; @@ -103,8 +103,9 @@ TEST_F(DeclarableOpsTests13, test_argmax_edge_1) { nd4j::ops::argmax op; auto result = op.execute(ctx); + ASSERT_EQ(Status::OK(), result); - nd4j_printf("Done\n",""); + //nd4j_printf("Done\n",""); delete ctx; } @@ -258,7 +259,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_1) { ASSERT_EQ(result->status(), Status::OK()); - result->at(0)->printBuffer("Output"); + //result->at(0)->printBuffer("Output"); ASSERT_TRUE(exp1.equalsTo(result->at(0))); delete result; } @@ -306,8 +307,8 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_3) { //nd4j_printf("rows %lld, cols %lld, vals %lld, res full %lld\n", rows.lengthOf(), cols.lengthOf(), vals.lengthOf(), exp1.lengthOf()); ASSERT_EQ(result->status(), Status::OK()); - result->at(0)->printBuffer("Output"); - exp.printBuffer("Expect"); + //result->at(0)->printBuffer("Output"); + //exp.printBuffer("Expect"); //result->at(0)->printShapeInfo("Shape output"); ASSERT_TRUE(exp.equalsTo(result->at(0))); delete result; @@ -327,7 +328,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_1) { nd4j::ops::barnes_symmetrized op; auto result = op.execute({&rows, &cols, &vals}, {}, {1}); ASSERT_EQ(result->status(), Status::OK()); - result->at(2)->printBuffer("Symmetrized1"); + //result->at(2)->printBuffer("Symmetrized1"); ASSERT_TRUE(exp.equalsTo(result->at(2))); delete result; @@ -346,7 +347,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_2) { nd4j::ops::barnes_symmetrized op; auto result = op.execute({&rows, &cols, &vals}, {}, {3}); ASSERT_EQ(result->status(), Status::OK()); - result->at(2)->printBuffer("Symmetrized2"); + //result->at(2)->printBuffer("Symmetrized2"); // ASSERT_TRUE(exp[i]->equalsTo(result->at(i))); ASSERT_TRUE(exp.equalsTo(result->at(2))); delete result; @@ -365,7 +366,7 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_3) { nd4j::ops::barnes_symmetrized op; auto result = op.execute({&rows, &cols, &vals}, {}, {11}); ASSERT_EQ(result->status(), Status::OK()); - result->at(2)->printBuffer("Symmetrized3"); + //result->at(2)->printBuffer("Symmetrized3"); //exp.printBuffer("EXPect symm3"); // ASSERT_TRUE(exp[i]->equalsTo(result->at(i))); //ASSERT_TRUE(exp.equalsTo(result->at(0))); @@ -390,10 +391,10 @@ TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_4) { auto result = op.execute({&rows, &cols, &vals}, {}, {11}); ASSERT_EQ(result->status(), Status::OK()); auto res = result->at(2); - res->printBuffer("Symmetrized4"); - exp4.printBuffer("Expected sym"); - nd4j_printf("Total res is {1, %lld}\n", res->lengthOf()); - nd4j_printf("Expected is {1, %lld}\n", exp4.lengthOf()); + // res->printBuffer("Symmetrized4"); + // exp4.printBuffer("Expected sym"); + // nd4j_printf("Total res is {1, %lld}\n", res->lengthOf()); + // nd4j_printf("Expected is {1, %lld}\n", exp4.lengthOf()); //exp.printBuffer("EXPect symm3"); // ASSERT_TRUE(exp[i]->equalsTo(result->at(i))); @@ -619,3 +620,72 @@ TEST_F(DeclarableOpsTests13, adjustSaturation_5) { delete results; } + +TEST_F(DeclarableOpsTests13, shift_bits_1) { + auto x = NDArrayFactory::create('c', {5}); + auto e = x.ulike(); + x.assign(32); + e.assign(512); + + nd4j::ops::shift_bits op; + auto result = op.execute({&x}, {}, {4}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + ASSERT_EQ(e, *z); + + delete result; +} + +TEST_F(DeclarableOpsTests13, rshift_bits_1) { + auto x = NDArrayFactory::create('c', {5}); + auto e = x.ulike(); + x.assign(512); + e.assign(32); + + nd4j::ops::rshift_bits op; + auto result = op.execute({&x}, {}, {4}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + ASSERT_EQ(e, *z); + + delete result; +} + +TEST_F(DeclarableOpsTests13, cyclic_shift_bits_1) { + auto x = NDArrayFactory::create('c', {5}); + auto e = x.ulike(); + x.assign(32); + e.assign(512); + + nd4j::ops::cyclic_shift_bits op; + auto result = op.execute({&x}, {}, {4}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + ASSERT_EQ(e, *z); + + delete result; +} + +TEST_F(DeclarableOpsTests13, cyclic_rshift_bits_1) { + auto x = NDArrayFactory::create('c', {5}); + auto e = x.ulike(); + x.assign(512); + e.assign(32); + + nd4j::ops::cyclic_rshift_bits op; + auto result = op.execute({&x}, {}, {4}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + ASSERT_EQ(e, *z); + + delete result; +} + diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp index 06d677b27..df1421d71 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp @@ -364,77 +364,6 @@ TEST_F(DeclarableOpsTests15, test_rank_2) { delete result; } -TEST_F(DeclarableOpsTests15, test_concat_column_1) { - auto x = NDArrayFactory::create('c', {2, 1}, {1, 1}); - auto y = NDArrayFactory::create('c', {2, 1}, {0, 0}); - auto e = NDArrayFactory::create('c', {2, 2}, {1, 0, 1, 0}); - auto z = NDArrayFactory::create('c', {2, 2}); - - nd4j::ops::concat op; - auto status = op.execute({&x, &y}, {&z}, {}, {1}, {}); - ASSERT_EQ(Status::OK(), status); - - z.printIndexedBuffer("z"); - - ASSERT_EQ(e, z); -} - -TEST_F(DeclarableOpsTests15, test_concat_large_1) { - std::array arrays; - Context context(1); - Nd4jLong axis = 0; - - // we crate bunch of arrays, filled with specific values - for (int e = 0; e < arrays.size(); e++) { - auto array = NDArrayFactory::create_('c', {1, 300}); - array->assign(e); - context.setInputArray(e, array, true); - } - - auto z = NDArrayFactory::create('c', {2000, 300}); - context.setOutputArray(0, &z, false); - context.setIArguments(&axis, 1); - - nd4j::ops::concat op; - op.execute(&context); - - for (int e = 0; e < arrays.size(); e++) { - auto row = z.tensorAlongDimension(e, {1}); - - ASSERT_NEAR((float) e, row->e(0), 1e-5f); - - delete row; - } -} - -TEST_F(DeclarableOpsTests15, test_concat_large_2) { - std::array arrays; - Context context(1); - Nd4jLong axis = 0; - - // we crate bunch of arrays, filled with specific values - for (int e = 0; e < arrays.size(); e++) { - auto array = NDArrayFactory::create_('c', {1, 5, 20}); - array->assign(e); - context.setInputArray(e, array, true); - } - - auto z = NDArrayFactory::create('c', {arrays.size(), 5, 20}); - context.setOutputArray(0, &z, false); - context.setIArguments(&axis, 1); - - nd4j::ops::concat op; - op.execute(&context); - - for (int e = 0; e < arrays.size(); e++) { - auto row = z.tensorAlongDimension(e, {1, 2}); - - ASSERT_NEAR((float) e, row->meanNumber().e(0), 1e-5f); - - delete row; - } -} - TEST_F(DeclarableOpsTests15, test_lstmBlock_1) { auto x0 = NDArrayFactory::create(5); auto x1 = NDArrayFactory::create('c', {5, 1, 4}, {0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f, 0.50563407f, 0.89252293f, 0.5461209f, 0.92336726f, 0.085571885f, 0.7937801f, 0.65908563f, 0.55552566f, 0.15962744f, 0.30874777f, 0.15476847f, 0.46954823f, 0.9938899f, 0.6112741f}); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp new file mode 100644 index 000000000..5b17b684a --- /dev/null +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests16.cpp @@ -0,0 +1,54 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include +#include + + +using namespace nd4j; + + +class DeclarableOpsTests16 : public testing::Test { +public: + + DeclarableOpsTests16() { + printf("\n"); + fflush(stdout); + } +}; + +TEST_F(DeclarableOpsTests16, test_repeat_119) { + auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 4, 5, 6}); + auto e = NDArrayFactory::create('c', {4, 3}, {1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6}); + + nd4j::ops::repeat op; + auto result = op.execute({&x}, {}, {2, 0}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + ASSERT_EQ(e, *z); + + delete result; +} diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp index c2af3cef4..62172dbf2 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests2.cpp @@ -373,35 +373,6 @@ TEST_F(DeclarableOpsTests2, NLP_Cbow_Test_1) { delete result; } -TEST_F(DeclarableOpsTests2, Test_Concat_3D_1) { - auto x0 = NDArrayFactory::create('c', {1, 100, 150}); - auto x1 = NDArrayFactory::create('c', {1, 100, 150}); - auto x2 = NDArrayFactory::create('c', {1, 100, 150}); - auto x3 = NDArrayFactory::create('c', {1, 100, 150}); - - x0.assign(1.0); - x1.assign(2.0); - x2.assign(3.0); - x3.assign(4.0); - - nd4j::ops::concat op; - auto result = op.execute({&x0, &x1, &x2, &x3}, {}, {0}, {}); - ASSERT_EQ(Status::OK(), result->status()); - - auto z = result->at(0); - - Nd4jLong numOfTads= ShapeUtils::getNumOfSubArrs(z->getShapeInfo(), {0}); - ASSERT_TRUE(4 == numOfTads); - - for (int e = 0; e < numOfTads; e++) { - NDArray tad = (*z)(e, {0}); - auto mean = tad.meanNumber().e(0); - ASSERT_NEAR((double) e+1, mean, 1e-5); - } - - delete result; -} - TEST_F(DeclarableOpsTests2, YetAnotherMatmulTest_1) { auto A = NDArrayFactory::create('c', {3, 3}); auto B = NDArrayFactory::create('c', {3, 1}); @@ -502,6 +473,7 @@ TEST_F(DeclarableOpsTests2, Test_FloorDiv_2) { auto x = NDArrayFactory::create('c', {1, 3}, {3.0, 6.0, -3.0}); auto y = NDArrayFactory::create('c', {1, 3}, {-2.0, 2.0, -2.0}); auto eps = NDArrayFactory::create('c', {1, 3}, {1, 2, 3}); + auto exp1 = NDArrayFactory::create('c', {1, 3}, {0.f, 0.f, 0.f}); auto exp2 = NDArrayFactory::create('c', {1, 3}, {0.f, 0.f, 0.f}); diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp index 2e1833548..b596ebcd5 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp @@ -223,21 +223,221 @@ TEST_F(DeclarableOpsTests5, Test_Boolean_diff_1) { delete result; } +TEST_F(DeclarableOpsTests5, Test_SetSeed_1) { + auto x = NDArrayFactory::create('c', {1, 1}, {120}); + auto y = NDArrayFactory::create(5); + nd4j::ops::set_seed op; + auto result = op.execute({&x, &y}, {}, {120, 5}, {}, false, nd4j::DataType::INT32); + ASSERT_EQ(Status::OK(), result->status()); +// result->at(0)->printIndexedBuffer("RES SEED"); + nd4j::ops::get_seed getOp; + auto getRes = getOp.execute({}, {}, {}); + ASSERT_EQ(Status::OK(), getRes->status()); +// getRes->at(0)->printIndexedBuffer("Output RES GET SEED"); +// ASSERT_EQ(result->at(0)->t(0), true); + delete result; + delete getRes; +} +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, scatterMul_test1) { + auto matrix = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + NDArray idc('c', {1}, {0LL}, nd4j::DataType::INT64); + auto updates = NDArrayFactory::create('c', {1, 2}, {10, 1}); + auto exp = NDArrayFactory::create('c', {2, 2}, {10, 2, 3, 4}); + nd4j::ops::scatter_mul op; + auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result->status()); + auto z = result->at(0); + ASSERT_TRUE(exp.equalsTo(z)); + delete result; +} +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, scatterDiv_test1) { + auto matrix = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + NDArray idc('c', {1}, {0LL}, nd4j::DataType::INT64); + auto updates = NDArrayFactory::create('c', {1, 2}, {10, 1}); + auto exp = NDArrayFactory::create('c', {2, 2}, {0.10, 2, 3, 4}); + nd4j::ops::scatter_div op; + auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result->status()); + auto z = result->at(0); +// z->printIndexedBuffer("Scatter Div"); + ASSERT_TRUE(exp.equalsTo(z)); + delete result; +} +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, scatterSub_test1) { + auto matrix = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + NDArray idc('c', {1}, {0LL}, nd4j::DataType::INT64); + auto updates = NDArrayFactory::create('c', {1, 2}, {10, 1}); + auto exp = NDArrayFactory::create('c', {2, 2}, {-9, 1, 3, 4}); + nd4j::ops::scatter_sub op; + auto result = op.execute({&matrix, &idc, &updates}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result->status()); + auto z = result->at(0); +// z->printIndexedBuffer("Scatter Sub"); + ASSERT_TRUE(exp.equalsTo(z)); + delete result; +} +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, hardsigmoid_test1) { + auto matrix = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {2, 2}, {0.7, 0.9, 1, 1}); + + nd4j::ops::hardsigmoid op; + auto result = op.execute({&matrix}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result->status()); + + auto z = result->at(0); + z->printIndexedBuffer("Hadrdsigmoid 2x2"); + ASSERT_TRUE(exp.equalsTo(z)); + + delete result; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, hardsigmoid_test2) { + auto matrix = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto eps = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {2, 2}, {0.2, 0.4, 0, 0}); + + nd4j::ops::hardsigmoid_bp op; + auto result = op.execute({&matrix, &eps}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result->status()); + + auto z = result->at(0); + z->printIndexedBuffer("Hadrdsigmoid 2x2"); + ASSERT_TRUE(exp.equalsTo(z)); + + delete result; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, hardtanh_test1) { + auto matrix = NDArrayFactory::create('c', {3, 3}, {-4, -3, -2, -1, 0, 1, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3, 3}, {-1, -1, -1, -1, 0, 1, 1, 1, 1}); + + nd4j::ops::hardtanh op; + auto result = op.execute({&matrix}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result->status()); + + auto z = result->at(0); +// z->printIndexedBuffer("Hardtanh 2x2"); + ASSERT_TRUE(exp.equalsTo(z)); + + delete result; +} +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, hardtanh_test2) { + auto matrix = NDArrayFactory::create('c', {3, 3}, {-4, -3, -2, -1, 0, 1, 2, 3, 4}); + auto eps = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto exp = NDArrayFactory::create('c', {3, 3}, {0, 0, 0, 4, 5, 6, 0, 0, 0}); + + nd4j::ops::hardtanh_bp op; + auto result = op.execute({&matrix, &eps}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result->status()); + + auto z = result->at(0); +// z->printIndexedBuffer("Hardtanh_bp 2x2"); + ASSERT_TRUE(exp.equalsTo(z)); + + delete result; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, histogram_test1) { + auto matrix = NDArrayFactory::create('c', {3, 3}, {-4, -3, -2, -1, 0, 1, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {3, 3, 3}); + + nd4j::ops::histogram op; + auto result = op.execute({&matrix}, {}, {3}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result->status()); + + auto z = result->at(0); +// z->printIndexedBuffer("Histogram3"); + ASSERT_TRUE(exp.equalsTo(z)); + + delete result; +} +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, histogram_test2) { + auto matrix = NDArrayFactory::create('c', {3}, {1, 2, 1}); + auto exp = NDArrayFactory::create('c', {4}, {2, 0, 0, 1}); + + nd4j::ops::histogram op; + auto result = op.execute({&matrix}, {}, {4}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result->status()); + + auto z = result->at(0); + z->printIndexedBuffer("Histogram4"); + ASSERT_TRUE(exp.equalsTo(z)); + + delete result; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, Identity_test1) { + auto matrix = NDArrayFactory::create('c', {3, 3}, {-4, -3, -2, -1, 0, 1, 2, 3, 4}); +// auto exp = NDArrayFactory::create('c', {3, 3}, {3, 3, 3}); + + nd4j::ops::identity op; + auto result = op.execute({&matrix}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result->status()); + + auto z = result->at(0); +// z->printIndexedBuffer("Histogram3"); + ASSERT_TRUE(matrix.equalsTo(z)); + + delete result; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, Identity_test2) { + auto matrix = NDArrayFactory::create('c', {3, 3}, {-4, -3, -2, -1, 0, 1, 2, 3, 4}); + auto eps = NDArrayFactory::create('c', {3, 3}, {1,2,3,4,5,6,7,8,9}); +// auto exp = NDArrayFactory::create('c', {3,3}); + nd4j::ops::identity_bp op; + auto result = op.execute({&matrix, &eps}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result->status()); + + auto z = result->at(0); + z->printIndexedBuffer("Identity_BP"); + ASSERT_TRUE(z->equalsTo(eps)); + + delete result; +} +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, Log1p_test1) { + auto matrix = NDArrayFactory::create('c', {3, 3}, {4, 3, 2, 1, 0, 1, 2, 3, 4}); + auto y = NDArrayFactory::create('c', {3,3}, {5,4,3,2,1,2,3,4,5}); + // auto eps = NDArrayFactory::create('c', {3, 3}, {1,2,3,4,5,6,7,8,9}); +// auto exp = NDArrayFactory::create('c', {3,3}); + nd4j::ops::Log1p op; + y.applyTransform(nd4j::transform::Log, nullptr, nullptr); + auto result = op.execute({&matrix}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result->status()); + + auto z = result->at(0); + z->printIndexedBuffer("Log1p"); + ASSERT_TRUE(z->equalsTo(y)); + + delete result; +} TEST_F(DeclarableOpsTests5, Test_SpaceToBatch_1) { diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp index 24701f70f..48996f2a5 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests6.cpp @@ -737,6 +737,44 @@ TEST_F(DeclarableOpsTests6, cumSum_20) { delete result; } +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, TestMergeMaxIndex_1) { + + auto x = NDArrayFactory::create('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + auto y = NDArrayFactory::create('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f}); + auto z = NDArrayFactory::create('c', {2, 2, 2}, {1.f, 20.f, 3.f, 40.f, 5.f, 60.f, 7.f, 80.f}); + auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 1, 2, 1, 2, 1, 2}); + nd4j::ops::mergemaxindex op; + + auto ress = op.execute({&x, &y, &z}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, ress->status()); +// ress->at(0)->printIndexedBuffer("MergeMaxIndex Result is "); +// ress->at(0)->printShapeInfo("Shape info for MergeMaxIdex"); +// x.printIndexedBuffer("Input is"); + ASSERT_TRUE(ress->at(0)->equalsTo(exp)); + delete ress; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, TestMergeMaxIndex_2) { + + auto x = NDArrayFactory::create('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + auto y = NDArrayFactory::create('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f}); + auto z = NDArrayFactory::create('c', {2, 2, 2}, {1.f, 20.f, 3.f, 40.f, 5.f, 60.f, 7.f, 80.f}); + auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 1, 2, 1, 2, 1, 2}); + nd4j::ops::mergemaxindex op; + + auto ress = op.execute({&x, &y, &z}, {}, {nd4j::DataType::INT64}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, ress->status()); +// ress->at(0)->printIndexedBuffer("MergeMaxIndex2 Result is "); +// ress->at(0)->printShapeInfo("Shape info for MergeMaxIdex2"); +// x.printIndexedBuffer("Input is"); + ASSERT_TRUE(ress->at(0)->equalsTo(exp)); + delete ress; +} + //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests6, TestDropout_1) { @@ -752,8 +790,60 @@ TEST_F(DeclarableOpsTests6, TestDropout_1) { delete ress; } +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, TestMod_1) { + auto x = NDArrayFactory::create('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + auto y = NDArrayFactory::create('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f}); + auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1, 0, 3, 0, 5, 0, 7, 0}); + nd4j::ops::mod op; + auto ress = op.execute({&x, &y}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, ress->status()); +// ress->at(0)->printIndexedBuffer("MOD Result is "); +// x.printIndexedBuffer("Input is"); + ASSERT_TRUE(ress->at(0)->equalsTo(exp)); + delete ress; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, TestMod_BP_1) { + + auto x = NDArrayFactory::create('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + auto y = NDArrayFactory::create('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f}); + auto eps = NDArrayFactory::create('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f}); + auto exp = NDArrayFactory::create('c', {2, 2, 2}); + nd4j::ops::mod_bp op; + + auto ress = op.execute({&x, &y, &eps}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, ress->status()); +// ress->at(0)->printIndexedBuffer("MOD_BP Result is "); + + // x.printIndexedBuffer("Input is"); + ASSERT_TRUE(ress->at(0)->equalsTo(exp)); + delete ress; +} + +/////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, TestRank_1) { + + auto x = NDArrayFactory::create('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + auto y = NDArrayFactory::create('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f}); + auto eps = NDArrayFactory::create('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f}); + auto exp = NDArrayFactory::create(3); + nd4j::ops::rank op; + + auto ress = op.execute({&x}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, ress->status()); + ress->at(0)->printIndexedBuffer("RANK Result is "); + + // x.printIndexedBuffer("Input is"); + ASSERT_TRUE(ress->at(0)->equalsTo(exp)); + delete ress; +} TEST_F(DeclarableOpsTests6, TestDropout_2) { // auto x0 = NDArrayFactory::create('c', {10, 10}); // auto x1 = NDArrayFactory::create('c', {10, 10}); @@ -1480,8 +1570,8 @@ TEST_F(DeclarableOpsTests6, LogDet_1) { ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto z = result->at(0); - z->printIndexedBuffer("LogDet Output1 "); - exp.printIndexedBuffer("LogDet Expected1 "); +// z->printIndexedBuffer("LogDet Output1 "); +// exp.printIndexedBuffer("LogDet Expected1 "); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); @@ -1502,9 +1592,9 @@ TEST_F(DeclarableOpsTests6, LogDet_2) { ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto z = result->at(0); - z->printIndexedBuffer("LogDet Output2 "); +// z->printIndexedBuffer("LogDet Output2 "); // z->printShapeInfo("Shape"); - exp.printIndexedBuffer("LogDet Expected2 "); +// exp.printIndexedBuffer("LogDet Expected2 "); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); @@ -1525,9 +1615,9 @@ TEST_F(DeclarableOpsTests6, LogDet_3) { ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto z = result->at(0); - z->printIndexedBuffer("LogDet Output3 "); +// z->printIndexedBuffer("LogDet Output3 "); // z->printShapeInfo("Shape"); - exp.printIndexedBuffer("LogDet Expected3 "); +// exp.printIndexedBuffer("LogDet Expected3 "); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); @@ -1572,8 +1662,8 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_1) { ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto z = result->at(0); - z->printIndexedBuffer("Output "); - exp.printIndexedBuffer("Expected "); +// z->printIndexedBuffer("Output "); +// exp.printIndexedBuffer("Expected "); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); @@ -1608,8 +1698,8 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_01) { ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto z = result->at(0); - z->printIndexedBuffer("Output "); - exp.printIndexedBuffer("Expected "); +// z->printIndexedBuffer("Output "); +// exp.printIndexedBuffer("Expected "); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); @@ -1642,8 +1732,8 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_02) { ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto z = result->at(0); - z->printIndexedBuffer("Output "); - exp.printIndexedBuffer("Expected "); +// z->printIndexedBuffer("Output "); +// exp.printIndexedBuffer("Expected "); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); @@ -1722,8 +1812,8 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_03) { ASSERT_EQ(ND4J_STATUS_OK, result->status()); auto z = result->at(0); - z->printIndexedBuffer("Output "); - exp.printIndexedBuffer("Expected "); +// z->printIndexedBuffer("Output "); +// exp.printIndexedBuffer("Expected "); ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); @@ -2755,31 +2845,4 @@ TEST_F(DeclarableOpsTests6, Test_Diag_119_3) { delete result; } -TEST_F(DeclarableOpsTests6, concat_test14) { - - NDArray x0('c', {1, 55, 40}, nd4j::DataType::DOUBLE); - NDArray x1('c', {1, 55, 40}, nd4j::DataType::DOUBLE); - - x0 = 1.; - x1 = 2.; - - nd4j::ops::concat op; - auto result = op.execute({&x0, &x1}, {}, {0}, {}); - ASSERT_EQ(Status::OK(), result->status()); - - auto z = result->at(0); - // z->printShapeInfo(); - // z->printIndexedBuffer(); - - Nd4jLong numOfTads= ShapeUtils::getNumOfSubArrs(z->getShapeInfo(), {0}); - ASSERT_TRUE(2 == numOfTads); - - for (int e = 0; e < numOfTads; ++e) { - NDArray tad = (*z)(e, {0}); - auto mean = tad.meanNumber().e(0); - ASSERT_NEAR((e+1)*1., mean, 1e-5); - } - - delete result; -} diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp index febb65c21..2e1dab1a3 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp @@ -24,6 +24,7 @@ #include #include #include +#include using namespace nd4j; @@ -3310,6 +3311,130 @@ auto exp = NDArrayFactory::create('c', {2, 3, 3}, { // delete result; } +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestRoll_10) { + auto x = NDArrayFactory::create('c', {2, 3, 4}, { + 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. + }); + + auto exp = NDArrayFactory::create('c', {2, 3, 4}, { + 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. + }); +// ---------------------------------------------------------------- + nd4j::ops::roll op; + auto result = op.execute({&x}, {}, {3, 1}, {}, false, nd4j::DataType::DOUBLE); + ASSERT_EQ(result->status(), Status::OK()); + auto out = result->at(0); + +// out->printIndexedBuffer("Output"); + //exp.printIndexedBuffer("Expect"); + + ASSERT_TRUE(exp.equalsTo(out)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestRoll_11) { + auto x = NDArrayFactory::create('c', {2, 3, 4}, { + 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. + }); + auto shift = NDArrayFactory::create({1,2}); + auto axis = NDArrayFactory::create({0, 1}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}, { + 17., 18., 19., 20., 21., 22., 23., 24., 13., 14., 15., 16., 5., 6., 7, 8, 9, 10, 11, 12, 1, 2, 3, 4 + }); +// ---------------------------------------------------------------- + nd4j::ops::roll op; + NDArray* y = nullptr; + auto result = op.execute({&x, &shift, &axis}, {}, {}, {}, false, nd4j::DataType::DOUBLE); + ASSERT_EQ(result->status(), Status::OK()); + auto out = result->at(0); + +// out->printIndexedBuffer("Output"); + //exp.printIndexedBuffer("Expect"); + + ASSERT_TRUE(exp.equalsTo(out)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestRoll_12) { + auto x = NDArrayFactory::create('c', {2, 3, 4}, { + 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. + }); + auto shift = NDArrayFactory::create({1,1,1}); + auto axis = NDArrayFactory::create({0, 1, 2}); + + auto exp = NDArrayFactory::create('c', {2, 3, 4}, { + 24, 21, 22, 23, 16, 13, 14, 15, 20, 17, 18, 19, 12, 9, 10, 11, 4, 1, 2, 3, 8, 5, 6, 7 + }); +// ---------------------------------------------------------------- + nd4j::ops::roll op; + NDArray* y = nullptr; + auto result = op.execute({&x, &shift, &axis}, {}, {}, {}, false, nd4j::DataType::DOUBLE); + ASSERT_EQ(result->status(), Status::OK()); + auto out = result->at(0); + out->printIndexedBuffer("Output"); + //exp.printIndexedBuffer("Expect"); + + ASSERT_TRUE(exp.equalsTo(out)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestRoll_13) { + auto x = NDArrayFactory::create('c', {2, 3, 4}, { + 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. + }); + auto shift = NDArrayFactory::create(3); + auto axis = NDArrayFactory::create(2); + + auto exp = NDArrayFactory::create('c', {2, 3, 4}, { + 2,3,4,1,6,7,8,5,10,11,12,9,14, 15, 16, 13, 18, 19, 20, 17, 22, 23, 24, 21 + }); +// ---------------------------------------------------------------- + nd4j::ops::roll op; + NDArray* y = nullptr; + auto result = op.execute({&x}, {}, {3,2}, {}, false, nd4j::DataType::DOUBLE); + ASSERT_EQ(result->status(), Status::OK()); + auto out = result->at(0); + +// out->printIndexedBuffer("Output"); + //exp.printIndexedBuffer("Expect"); + + ASSERT_TRUE(exp.equalsTo(out)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestRoll_14) { + auto x = NDArrayFactory::create('c', {2, 3, 4}, { + 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. + }); + auto shift = NDArrayFactory::create({1,1,1}); + auto axis = NDArrayFactory::create({0, 1, 2}); + + auto exp = NDArrayFactory::create('c', {2, 3, 4}, { + 24, 21, 22, 23, 16, 13, 14, 15, 20, 17, 18, 19, 12, 9, 10, 11, 4, 1, 2, 3, 8, 5, 6, 7 + }); +// ---------------------------------------------------------------- + nd4j::ops::roll op; + + auto result = op.execute({&x, &shift, &axis}, {}, {}, {}, false, nd4j::DataType::DOUBLE); + ASSERT_EQ(result->status(), Status::OK()); + auto out = result->at(0); +// out->printIndexedBuffer("Output"); + //exp.printIndexedBuffer("Expect"); + + ASSERT_TRUE(exp.equalsTo(out)); + + delete result; +} + //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, percentile_test1) { @@ -3605,6 +3730,289 @@ TEST_F(DeclarableOpsTests7, transpose_test3) { delete result; } +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, rationaltanh_test1) { + + auto input = NDArrayFactory::create('c', {8}, {0, 1, 2, 3, 4, 5, 6, 7}); + NDArray exp = NDArrayFactory::create({0.000000, 0.998222, 1.516093, 1.658054, 1.695077, 1.706884, 1.711427, 1.713446}); + + nd4j::ops::rationaltanh op; + auto result = op.execute({&input}, {}, {}); + auto output = result->at(0); +// output->printIndexedBuffer("Output rationaltanh"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, rationaltanh_test2) { + + auto input = NDArrayFactory::create('c', {2,2,2}, {0, 1, 2, 3, 4, 5, 6, 7}); + NDArray exp = NDArrayFactory::create('c', {2,2,2}, {0.000000, 0.998222, 1.516093, 1.658054, 1.695077, 1.706884, 1.711427, 1.713446}); + + nd4j::ops::rationaltanh op; + auto result = op.execute({&input}, {}, {}); + auto output = result->at(0); +// output->printIndexedBuffer("Output rationaltanh"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, rationaltanh_test3) { + + auto input = NDArrayFactory::create('c', {2,2,2}, {0, 1, 2, 3, 4, 5, 6, 7}); + auto eps = NDArrayFactory::create('c', {2,2,2}, {1, 2, 3, 4, 5, 6, 7, 8}); + NDArray exp = NDArrayFactory::create('c', {2,2,2}, {1.143933, 1.605747, 0.795557, 0.261710, 0.095832, 0.041218, 0.020221, 0.010971}); + + nd4j::ops::rationaltanh_bp op; + auto result = op.execute({&input, &eps}, {}, {}); + auto output = result->at(0); +// output->printBuffer("Output rationaltanh BP"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, rectifiedtanh_test1) { + + auto input = NDArrayFactory::create('c', {2,2,2}, {0, 1, 2, 3, 4, 5, 6, 7}); + NDArray exp = NDArrayFactory::create('c', {2,2,2}, {0.000000, 0.761594, 0.964028, 0.995055, 0.999329, 0.999909, 0.999988, 0.999998}); + + nd4j::ops::rectifiedtanh op; + auto result = op.execute({&input}, {}, {}); + auto output = result->at(0); +// output->printIndexedBuffer("Output rectifiedtanh"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, rectifiedtanh_test2) { + + auto input = NDArrayFactory::create('c', {2,2,2}, {0, 1, 2, 3, 4, 5, 6, 7}); + auto eps = NDArrayFactory::create('c', {2,2,2}, {1, 2, 3, 4, 5, 6, 7, 8}); + NDArray exp = NDArrayFactory::create('c', {2,2,2}, {0.000000, 0.839949, 0.211952, 0.039464, 0.006705, 0.001089, 0.000172, 0.000027}); + + nd4j::ops::rectifiedtanh_bp op; + auto result = op.execute({&input, &eps}, {}, {}); + auto output = result->at(0); +// output->printBuffer("Output rectifiedtanh BP"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + delete result; +} + +TEST_F(DeclarableOpsTests7, RealDiv_1) { + + NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2, 4}); + NDArray y = NDArrayFactory::create('c', {1, 2}, {1,2}); + NDArray e = NDArrayFactory::create('c', {1, 2, 2}, {2, 1, 4, 2}); + + nd4j::ops::realdiv op; + auto result = op.execute({&x, &y}, {}, {}); + + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); +// z->printIndexedBuffer("OUtput RealDiv"); + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(*z)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, RealDiv_BP_1) { + + NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2, 4}); + NDArray y = NDArrayFactory::create('c', {1, 2}, {1,2}); + NDArray e0 = NDArrayFactory::create('c', {1, 2, 1}, {2, 5}); + NDArray e1 = NDArrayFactory::create('c', {1, 2}, {-14, -5}); + NDArray eps = NDArrayFactory::create('c', {1, 2, 2}, {1, 2, 3, 4}); + + nd4j::ops::realdiv_bp op; + auto result = op.execute({&x, &y, &eps}, {}, {}); + + ASSERT_EQ(Status::OK(), result->status()); + + auto z0 = result->at(0); + auto z1 = result->at(1); +// z0->printShapeInfo("OUtput RealDiv BP0 shape"); +// z1->printShapeInfo("OUtput RealDiv BP1 shape"); +// z0->printIndexedBuffer("OUtput RealDiv BP0"); +// z1->printIndexedBuffer("OUtput RealDiv BP1"); +// ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e0.equalsTo(z0)); + ASSERT_TRUE(e1.equalsTo(z1)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, ShapesOf_1) { + + NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2, 4}); +// NDArray y = NDArrayFactory::create('c', {1, 2}, {1,2}); + NDArray e = NDArrayFactory::create({1, 2, 1}); + + nd4j::ops::shapes_of op; + auto result = op.execute({&x}, {}, {}); + + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); +// z->printIndexedBuffer("OUtput RealDiv"); +// ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(*z)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, ShapesOf_2) { + + NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2, 4}); + NDArray y = NDArrayFactory::create('c', {1, 2}, {1,2}); + NDArray e0 = NDArrayFactory::create({1, 2, 1}); + NDArray e1 = NDArrayFactory::create({1, 2}); + + nd4j::ops::shapes_of op; + auto result = op.execute({&x, &y}, {}, {}); + + ASSERT_EQ(Status::OK(), result->status()); + + auto z0 = result->at(0); + auto z1 = result->at(1); +// z0->printIndexedBuffer("OUtput shapes2"); +// z1->printIndexedBuffer("OUtput shapes2"); +// ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e0.equalsTo(z0)); + ASSERT_TRUE(e1.equalsTo(z1)); + + delete result; +} + +TEST_F(DeclarableOpsTests7, Size_1) { + + NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2, 4}); + NDArray y = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); + NDArray e = NDArrayFactory::create(2); + + nd4j::ops::size op; + auto result = op.execute({&x}, {}, {}); + + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); +// z->printIndexedBuffer("OUtput SIZE"); +/// ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(*z)); + + delete result; +} + +TEST_F(DeclarableOpsTests7, Size_2) { + + NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2, 4}); + NDArray y = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); + NDArray e = NDArrayFactory::create(10); + + nd4j::ops::size op; + auto result = op.execute({&y}, {}, {}); + + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); +// z->printIndexedBuffer("OUtput SIZE"); +/// ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(*z)); + + delete result; +} + +TEST_F(DeclarableOpsTests7, Softplus_1) { + + NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); + NDArray e = NDArrayFactory::create('c', {5, 2}, {1.3132616, 2.126928, 3.0485873, 4.01815, 5.0067153, 7.0009117, 9.000123, 10.000046, 10.000046, 11.000016}); + + nd4j::ops::softplus op; + auto result = op.execute({&x}, {}, {}); + + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); +// z->printIndexedBuffer("OUtput Softplus"); +/// ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(*z)); + + delete result; +} + +TEST_F(DeclarableOpsTests7, Softplus_BP_1) { + + NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); +// NDArray e = NDArrayFactory::create('c', {5, 2}, {1.3132616, 2.126928, 3.0485873, 4.01815, 5.0067153, 7.0009117, 9.000123, 10.000046, 10.000046, 11.000016}); + NDArray eps = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,6,7,8, 9, 10}); + nd4j::ops::softplus ffOP; + nd4j::ops::softplus_bp bpOp; + const OpArgsHolder argsHolderFF({&x}, {}, {}); + const OpArgsHolder argsHolderBP({&x, &eps}, {}, {}); + + bool gradOK = GradCheck::checkGrad(ffOP, bpOp, argsHolderFF, argsHolderBP); + + ASSERT_TRUE(gradOK); +// +// auto z = result->at(0); +// z->printIndexedBuffer("OUtput Softplus"); +///// ASSERT_TRUE(e.isSameShape(z)); +// ASSERT_TRUE(e.equalsTo(*z)); +// +// delete result; +} + +TEST_F(DeclarableOpsTests7, Softsign_1) { + + NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); + NDArray e = NDArrayFactory::create('c', {5, 2}, {0.5, 0.6666667, 0.75, 0.8, 0.8333333, 0.875, 0.9, 0.90909094, 0.90909094, 0.9166667}); + + nd4j::ops::softsign op; + auto result = op.execute({&x}, {}, {}); + + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); +// z->printIndexedBuffer("OUtput Softsign"); +/// ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(*z)); + + delete result; +} + +TEST_F(DeclarableOpsTests7, Softsign_BP_1) { + + NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); +// NDArray e = NDArrayFactory::create('c', {5, 2}, {1.3132616, 2.126928, 3.0485873, 4.01815, 5.0067153, 7.0009117, 9.000123, 10.000046, 10.000046, 11.000016}); + NDArray eps = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,6,7,8, 9, 10}); + nd4j::ops::softsign ffOP; + nd4j::ops::softsign_bp bpOp; + const OpArgsHolder argsHolderFF({&x}, {}, {}); + const OpArgsHolder argsHolderBP({&x, &eps}, {}, {}); + + bool gradOK = GradCheck::checkGrad(ffOP, bpOp, argsHolderFF, argsHolderBP); + + ASSERT_TRUE(gradOK); +} + //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, fill_test2) { @@ -3644,6 +4052,185 @@ TEST_F(DeclarableOpsTests7, fill_test3) { delete result; } +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, ToggleBits_test1) { + + auto x = NDArrayFactory::create('c', {2}, {2, 2}); + auto exp = NDArrayFactory::create('c', {2}, {-3, -3}); + + nd4j::ops::toggle_bits op; + auto result = op.execute({&x}, {}, {}, {}, false, nd4j::DataType::INT32); + auto output = result->at(0); + + ASSERT_EQ(ND4J_STATUS_OK, result->status()); +// output->printIndexedBuffer("Toggled"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, ToggleBits_test2) { + + auto x = NDArrayFactory::create('c', {2}, {2, 2}); + auto y = NDArrayFactory::create('c', {2}, {1, 1}); + auto exp0 = NDArrayFactory::create('c', {2}, {-3, -3}); + auto exp1 = NDArrayFactory::create('c', {2}, {-2, -2}); + + nd4j::ops::toggle_bits op; + auto result = op.execute({&x, &y}, {}, {}, {}, false, nd4j::DataType::INT32); + auto output = result->at(0); + auto z = result->at(1); + + ASSERT_EQ(ND4J_STATUS_OK, result->status()); +// output->printIndexedBuffer("Toggled"); + ASSERT_TRUE(exp0.isSameShape(output)); + ASSERT_TRUE(exp0.equalsTo(output)); + ASSERT_TRUE(exp1.isSameShape(z)); + ASSERT_TRUE(exp1.equalsTo(z)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Truncatediv_test1) { + NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); + NDArray y = NDArrayFactory::create('c', {5, 2}, {2,2,2,2,2,2,2,2, 2, 2}); + NDArray exp = NDArrayFactory::create('c', {5, 2}, {0.5, 1., 1.5, 2., 2.5, 3.5, 4.5, 5., 5., 5.5}); + + nd4j::ops::truncatediv op; + auto result = op.execute({&x, &y}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result->status()); + auto output = result->at(0); +// output->printIndexedBuffer("Toggled"); + ASSERT_TRUE(exp.isSameShape(output)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Truncatediv_test2) { + NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); + NDArray y = NDArrayFactory::create('c', {1, 2}, {2,2}); + NDArray exp = NDArrayFactory::create('c', {5, 2}, {0.5, 1., 1.5, 2., 2.5, 3.5, 4.5, 5., 5., 5.5}); + + nd4j::ops::truncatediv op; + auto result = op.execute({&x, &y}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result->status()); + auto output = result->at(0); +// output->printIndexedBuffer("Toggled"); + ASSERT_TRUE(exp.isSameShape(output)); + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TypesConversion_test1) { + NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); + NDArray expI = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); + NDArray expL = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); + NDArray expF = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); + NDArray expF16 = NDArrayFactory::create('c', {5, 2}, {1.f,2.f,3.f,4.f,5.f,7.f,9.f,10.f, 10.f, 11.f}); + + nd4j::ops::to_int32 op32; + nd4j::ops::to_int64 op64; + auto result32 = op32.execute({&x}, {}, {}); + auto result64 = op64.execute({&x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result32->status()); + ASSERT_EQ(ND4J_STATUS_OK, result64->status()); + auto out1 = result32->at(0); +// out1->printIndexedBuffer("OUT_I"); + auto out2 = result64->at(0); +// out2->printIndexedBuffer("OUT_L"); + +// output->printIndexedBuffer("Toggled"); + ASSERT_TRUE(expI.equalsTo(out1)); + ASSERT_TRUE(expL.equalsTo(out2)); + + delete result32; + delete result64; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TypesConversion_test2) { + NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); + NDArray expF = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); + NDArray expH = NDArrayFactory::create('c', {5, 2}, {1.f,2.f,3.f,4.f,5.f,7.f,9.f,10.f, 10.f, 11.f}); + + nd4j::ops::to_float32 op32; + nd4j::ops::to_float16 op16; + auto result32 = op32.execute({&x}, {}, {}); + auto result16 = op16.execute({&x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result32->status()); + ASSERT_EQ(ND4J_STATUS_OK, result16->status()); + auto out1 = result32->at(0); +// out1->printIndexedBuffer("OUT_F"); + auto out2 = result16->at(0); +// out2->printIndexedBuffer("OUT_H"); + +// output->printIndexedBuffer("Toggled"); + ASSERT_TRUE(expF.equalsTo(out1)); + ASSERT_TRUE(expH.equalsTo(out2)); + + delete result32; + delete result16; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TypesConversion_test3) { + NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); + NDArray exp32 = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); + NDArray exp64 = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); + + nd4j::ops::to_uint32 op32; + nd4j::ops::to_uint64 op64; + auto result32 = op32.execute({&x}, {}, {}); + auto result64 = op64.execute({&x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result32->status()); + ASSERT_EQ(ND4J_STATUS_OK, result64->status()); + auto out1 = result32->at(0); +// out1->printIndexedBuffer("OUT_U32"); + auto out2 = result64->at(0); +// out2->printIndexedBuffer("OUT_U64"); + +// output->printIndexedBuffer("Toggled"); + ASSERT_TRUE(exp32.equalsTo(out1)); + ASSERT_TRUE(exp64.equalsTo(out2)); + + delete result32; + delete result64; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TypesConversion_test4) { + NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); + NDArray exp32 = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); + NDArray exp64 = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); + + nd4j::ops::to_float32 op32; + nd4j::ops::to_double op64; + auto result32 = op32.execute({&x}, {}, {}); + auto result64 = op64.execute({&x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result32->status()); + ASSERT_EQ(ND4J_STATUS_OK, result64->status()); + auto out1 = result32->at(0); + out1->printIndexedBuffer("OUT_F"); + auto out2 = result64->at(0); + out2->printIndexedBuffer("OUT_D"); + +// output->printIndexedBuffer("Toggled"); + ASSERT_TRUE(exp32.equalsTo(out1)); + ASSERT_TRUE(exp64.equalsTo(out2)); + + delete result32; + delete result64; +} + //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, mirrorPad_test1) { diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp index 0d9ea7c24..d27aa4e46 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp @@ -584,6 +584,180 @@ TEST_F(DeclarableOpsTests9, concat_test16) { delete result; } +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, concat_test17) { + + NDArray x0('c', {1, 55, 40}, nd4j::DataType::DOUBLE); + NDArray x1('c', {1, 55, 40}, nd4j::DataType::DOUBLE); + + x0 = 1.; + x1 = 2.; + + nd4j::ops::concat op; + auto result = op.execute({&x0, &x1}, {}, {0}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + // z->printShapeInfo(); + // z->printIndexedBuffer(); + + Nd4jLong numOfTads= ShapeUtils::getNumOfSubArrs(z->getShapeInfo(), {0}); + ASSERT_TRUE(2 == numOfTads); + + for (int e = 0; e < numOfTads; ++e) { + NDArray tad = (*z)(e, {0}); + auto mean = tad.meanNumber().e(0); + ASSERT_NEAR((e+1)*1., mean, 1e-5); + } + + delete result; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, concat_test18) { + std::array arrays; + Context context(1); + Nd4jLong axis = 0; + + // we crate bunch of arrays, filled with specific values + for (int e = 0; e < arrays.size(); e++) { + auto array = NDArrayFactory::create_('c', {1, 300}); + array->assign(e); + context.setInputArray(e, array, true); + } + + auto z = NDArrayFactory::create('c', {2000, 300}); + context.setOutputArray(0, &z, false); + context.setIArguments(&axis, 1); + + nd4j::ops::concat op; + op.execute(&context); + + for (int e = 0; e < arrays.size(); e++) { + auto row = z.tensorAlongDimension(e, {1}); + + ASSERT_NEAR((float) e, row->e(0), 1e-5f); + + delete row; + } +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, concat_test19) { + + std::array arrays; + Context context(1); + Nd4jLong axis = 0; + + // we crate bunch of arrays, filled with specific values + for (int e = 0; e < arrays.size(); e++) { + auto array = NDArrayFactory::create_('c', {1, 5, 20}); + array->assign(e); + context.setInputArray(e, array, true); + } + + auto z = NDArrayFactory::create('c', {arrays.size(), 5, 20}); + context.setOutputArray(0, &z, false); + context.setIArguments(&axis, 1); + + nd4j::ops::concat op; + op.execute(&context); + + for (int e = 0; e < arrays.size(); e++) + ASSERT_NEAR((float) e, z(e, {0}).meanNumber().e(0), 1e-5f); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, concat_test20) { + auto x0 = NDArrayFactory::create('c', {1, 100, 150}); + auto x1 = NDArrayFactory::create('c', {1, 100, 150}); + auto x2 = NDArrayFactory::create('c', {1, 100, 150}); + auto x3 = NDArrayFactory::create('c', {1, 100, 150}); + + x0.assign(1.0); + x1.assign(2.0); + x2.assign(3.0); + x3.assign(4.0); + + nd4j::ops::concat op; + auto result = op.execute({&x0, &x1, &x2, &x3}, {}, {0}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); + + Nd4jLong numOfTads= ShapeUtils::getNumOfSubArrs(z->getShapeInfo(), {0}); + ASSERT_TRUE(4 == numOfTads); + + for (int e = 0; e < numOfTads; e++) { + NDArray tad = (*z)(e, {0}); + auto mean = tad.meanNumber().e(0); + ASSERT_NEAR((double) e+1, mean, 1e-5); + } + + delete result; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, concat_test21) { + + NDArray x0('c', {1,4,5}, nd4j::DataType::FLOAT32); + NDArray x1('c', {2,4,5}, nd4j::DataType::FLOAT32); + NDArray z('f', {3,4,5}, nd4j::DataType::FLOAT32); + + x0 = 0.; + x1 = 1.; + + nd4j::ops::concat op; + auto status = op.execute({&x0, &x1}, {&z}, {}, {0}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, concat_test22) { + + NDArray x0('c', {1,6}, {1,2,3,4,5,6}); + NDArray x1('c', {1,6}, {7,8,9,10,11,12}); + NDArray output('f', {2,6}, nd4j::DataType::DOUBLE); + NDArray exp('c', {2,6}, {1,2,3,4,5,6,7,8,9,10,11,12}); + + nd4j::ops::concat op; + + auto status = op.execute({&x0, &x1}, {&output}, {}, {0}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + ASSERT_TRUE(exp.equalsTo(output)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, concat_test23) { + + NDArray x0('c', {1,4}, {1,2,3,4}); + NDArray x1('c', {1,4}, {5,6,7,8}); + NDArray output('c', {2,4}, nd4j::DataType::DOUBLE); + NDArray exp('c', {2,4}, {1,2,3,4,5,6,7,8}); + + nd4j::ops::concat op; + + auto status = op.execute({&x0, &x1}, {&output}, {}, {0}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + ASSERT_TRUE(exp.equalsTo(output)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, concat_test24) { + auto x = NDArrayFactory::create('c', {2, 1}, {1, 1}); + auto y = NDArrayFactory::create('c', {2, 1}, {0, 0}); + auto e = NDArrayFactory::create('c', {2, 2}, {1, 0, 1, 0}); + auto z = NDArrayFactory::create('c', {2, 2}); + + nd4j::ops::concat op; + auto status = op.execute({&x, &y}, {&z}, {}, {1}, {}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(e, z); +} + ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests9, tile_bp_test1) { diff --git a/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp b/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp index b646bacab..43c6c45df 100644 --- a/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/JavaInteropTests.cpp @@ -975,69 +975,6 @@ TEST_F(JavaInteropTests, zeta_test10) { ASSERT_EQ(e, z); } -TEST_F(JavaInteropTests, Test_Is_Max_1) { - auto arrayX = NDArrayFactory::create({1, 2, 1, 1}); - auto arrayZ = NDArrayFactory::create({0, 0, 0, 0}); - auto arrayE = NDArrayFactory::create({0, 1, 0, 0}); - - nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext(); - - Nd4jPointer* extraPointers = nullptr; - #ifdef __CUDABLAS__ - extraPointers = new Nd4jPointer[7] {nullptr, context->getCudaStream(), context->getScalarPointer(), nullptr, context->getCudaSpecialStream(), context->getReductionPointer(), context->getAllocationPointer()}; - #endif - - NDArray::prepareSpecialUse({&arrayZ}, {&arrayX}); - execTransformAny(extraPointers, transform::IsMax, - arrayX.buffer(), arrayX.shapeInfo(), arrayX.getSpecialBuffer(), arrayX.getSpecialShapeInfo(), - arrayZ.buffer(), arrayZ.shapeInfo(), arrayZ.getSpecialBuffer(), arrayZ.getSpecialShapeInfo(), - nullptr); - - NDArray::registerSpecialUse({&arrayZ}, {&arrayX}); - ASSERT_EQ(arrayE, arrayZ); - - delete []extraPointers; -} - -TEST_F(JavaInteropTests, Test_Is_Max_1_2) { - auto arrayX = NDArrayFactory::create({1, 2, 1, 1}); - auto arrayZ = NDArrayFactory::create({0, 0, 0, 0}); - auto arrayE = NDArrayFactory::create({0, 1, 0, 0}); - - nd4j::LaunchContext* context = nd4j::LaunchContext::defaultContext(); - - Nd4jPointer* extraPointers = nullptr; - #ifdef __CUDABLAS__ - extraPointers = new Nd4jPointer[7] {nullptr, context->getCudaStream(), context->getScalarPointer(), nullptr, context->getCudaSpecialStream(), context->getReductionPointer(), context->getAllocationPointer()}; - #endif - - NDArray::prepareSpecialUse({&arrayZ}, {&arrayX}); - execTransformAny(extraPointers, transform::IsMax, - arrayX.buffer(), arrayX.shapeInfo(), arrayX.getSpecialBuffer(), arrayX.getSpecialShapeInfo(), - arrayZ.buffer(), arrayZ.shapeInfo(), arrayZ.getSpecialBuffer(), arrayZ.getSpecialShapeInfo(), - nullptr); - //arrayZ.printIndexedBuffer("JAVA ISMAX1"); - NDArray::registerSpecialUse({&arrayZ}, {&arrayX}); - ASSERT_EQ(arrayE, arrayZ); - delete []extraPointers; -} - -TEST_F(JavaInteropTests, Test_Is_Max_2) { - auto arrayX = NDArrayFactory::create('c', {3, 2, 3}, {1, 10, 2, 3, 4, 5, -10, -9, -8, -7, -6, -5, 4, 3, 2, 1, 0, -1}); - auto arrayZ = NDArrayFactory::create('c', {3, 2, 3}); - Nd4jLong tad[] = {2, 2, 3, 3, 1, 524288, -1, 99}; - Nd4jLong off[] = {0, 6, 12}; - Nd4jLong *ex[] = {tad, off}; - float ea[] = {2, 1, 2}; - - NDArray::prepareSpecialUse({&arrayZ}, {&arrayX}); - execTransformBool(reinterpret_cast(ex), transform::IsMax, - arrayX.buffer(), arrayX.shapeInfo(), arrayX.getSpecialBuffer(), arrayX.getSpecialShapeInfo(), - arrayZ.buffer(), arrayZ.shapeInfo(), arrayZ.getSpecialBuffer(), arrayZ.getSpecialShapeInfo(), - ea); - NDArray::registerSpecialUse({&arrayZ}, {&arrayX}); -} - TEST_F(JavaInteropTests, Test_IAMax_1) { auto arrayX = NDArrayFactory::create({-0.24f, -0.26f, -0.07f, -0.01f}); auto arrayZ = arrayX.indexReduceNumber(indexreduce::IndexAbsoluteMax, nullptr); diff --git a/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp b/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp index aeac06ccb..5308ee99d 100644 --- a/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/LegacyOpsTests.cpp @@ -367,49 +367,6 @@ TEST_F(LegacyOpsTests, IndexReduceTests_2) { delete result; } -TEST_F(LegacyOpsTests, Test_IsMax_1) { - if (!Environment::getInstance()->isCPU()) - return; - - auto x = NDArrayFactory::create('c', {2, 2, 2, 2, 2, 2}); - auto z = NDArrayFactory::create('c', {2, 2, 2, 2, 2, 2}); - x.linspace(1.0); - z.assign(-589); - - double extra[] = {1.0, 0.0}; - - NativeOpExecutioner::execTransformAny(nullptr, transform::IsMax, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), extra, nullptr, nullptr); - - // z.printIndexedBuffer("z"); - for (Nd4jLong e = 0; e < z.lengthOf(); e++) { - ASSERT_TRUE(z.e(e) >= 0); - } -} - -TEST_F(LegacyOpsTests, Test_IsMax_2) { - if (!Environment::getInstance()->isCPU()) - return; - - auto x = NDArrayFactory::create('c', {2, 2, 2, 2, 2, 2}); - auto z = NDArrayFactory::create('c', {2, 2, 2, 2, 2, 2}); - x.linspace(1.0); - z.assign(false); - - double extra[] = {1.0, 0.0}; - - NativeOpExecutioner::execTransformAny(nullptr, transform::IsMax, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), - z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), extra, nullptr, nullptr); - - // z.printIndexedBuffer("z"); - for (Nd4jLong e = 0; e < z.lengthOf(); e++) { - if (e >= z.lengthOf() / 2) - ASSERT_TRUE(z.e(e)); - else - ASSERT_FALSE(z.e(e)); - } -} - TEST_F(LegacyOpsTests, BroadcastingTests_1) { auto x = NDArrayFactory::create('c', {5, 5}); x.assign(0.0f); diff --git a/libnd4j/tests_cpu/layers_tests/ListOperationsTests.cpp b/libnd4j/tests_cpu/layers_tests/ListOperationsTests.cpp index ac778d971..e506839df 100644 --- a/libnd4j/tests_cpu/layers_tests/ListOperationsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ListOperationsTests.cpp @@ -78,6 +78,72 @@ TEST_F(ListOperationsTests, BasicTest_Stack_1) { delete tads; } +TEST_F(ListOperationsTests, BasicTest_UnStackList_1) { + NDArrayList list(0, true); + auto x = NDArrayFactory::create('c', {10, 100}); + auto tads = x.allTensorsAlongDimension({1}); + for (int e = 0; e < 10; e++) { + auto row = NDArrayFactory::create_('c', {100}); + row->assign((double) e); + //list.write(e, row); + tads->at(e)->assign(row); + delete row; + } + + nd4j::ops::unstack_list op; + + auto result = op.execute(&list, {&x}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, result->status()); + ASSERT_EQ(list.elements(), 10); + +// auto z = result->at(0); +// z->printShapeInfo("The first of"); +// ASSERT_TRUE(exp.isSameShape(z)); +// ASSERT_TRUE(exp.equalsTo(z)); + for (int e = 0; e < 10; e++) { + auto row = list.read(e); + ASSERT_TRUE(row->equalsTo(tads->at(e))); + //list.write(e, row); + delete row; + } + + delete result; + delete tads; +} + +//TEST_F(ListOperationsTests, BasicTest_UnStackList_2) { +//// NDArrayList list(0, true); +// auto x = NDArrayFactory::create('c', {10, 100}); +// auto tads = x.allTensorsAlongDimension({1}); +// for (int e = 0; e < 10; e++) { +// auto row = NDArrayFactory::create_('c', {100}); +// row->assign((double) e); +// //list.write(e, row); +// tads->at(e)->assign(row); +// delete row; +// } +// +// nd4j::ops::unstack_list op; +// +// auto result = op.execute(nullptr, {&x}, {}, {0}); +// +// ASSERT_EQ(ND4J_STATUS_OK, result->status()); +// ASSERT_EQ(result->size(), 10); +// +// // auto z = result->at(0); +//// z->printShapeInfo("The first of"); +//// ASSERT_TRUE(exp.isSameShape(z)); +//// ASSERT_TRUE(exp.equalsTo(z)); +// for (int e = 0; e < 10; e++) { +// auto row = result->at(e); +// ASSERT_TRUE(row->equalsTo(tads->at(e))); +// //list.write(e, row); +// } +// +// delete result; +// delete tads; +//} TEST_F(ListOperationsTests, BasicTest_Read_1) { NDArrayList list(10); diff --git a/libnd4j/tests_cpu/layers_tests/MultiDataTypeTests.cpp b/libnd4j/tests_cpu/layers_tests/MultiDataTypeTests.cpp index 7560e0c9d..18b849d53 100644 --- a/libnd4j/tests_cpu/layers_tests/MultiDataTypeTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/MultiDataTypeTests.cpp @@ -193,8 +193,8 @@ TEST_F(MultiDataTypeTests, ndarray_assign_number_test3) { //////////////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, ndarray_repeat_test1) { NDArray x('c', {2, 2}, {0.5, 1.5, 2.5, 3.5}, nd4j::DataType::HALF); - NDArray y('c', {2, 4}, nd4j::DataType::UINT8); - NDArray exp('c', {2, 4}, {0, 0, 1, 1, 2, 2, 3, 3}, nd4j::DataType::UINT8); + NDArray y('c', {2, 4}, nd4j::DataType::HALF); + NDArray exp('c', {2, 4}, {0.5, 0.5, 1.5, 1.5, 2.5, 2.5, 3.5, 3.5}, nd4j::DataType::HALF); x.repeat(1, y); @@ -1790,6 +1790,7 @@ TEST_F(MultiDataTypeTests, RowCol_test2) { } ////////////////////////////////////////////////////////////////////// +/* TEST_F(MultiDataTypeTests, tile_test1) { NDArray x1('c', {2,1}, {0,1}, nd4j::DataType::INT32); @@ -1823,6 +1824,7 @@ TEST_F(MultiDataTypeTests, tile_test1) { x1.tile(x7); ASSERT_EQ(x7, exp4); } +*/ ////////////////////////////////////////////////////////////////////// TEST_F(MultiDataTypeTests, broadcast_test1) { diff --git a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp index 314a83aad..e1a23ee3f 100644 --- a/libnd4j/tests_cpu/layers_tests/RNGTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/RNGTests.cpp @@ -248,8 +248,8 @@ TEST_F(RNGTests, Test_Gaussian_21) { RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngA, &x0, 0.0f, 1.0f); RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngB, &x1, 0.0f, 1.0f); - x0.printIndexedBuffer("x0"); - x1.printIndexedBuffer("x1"); +// x0.printIndexedBuffer("x0"); +// x1.printIndexedBuffer("x1"); ASSERT_TRUE(x0.equalsTo(&x1)); ASSERT_FALSE(x0.equalsTo(nexp0)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java index 34800ca07..0ad489cbf 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java @@ -229,44 +229,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByNorm; import org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue; import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndReplace; import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet; -import org.nd4j.linalg.api.ops.impl.transforms.custom.ATan2; -import org.nd4j.linalg.api.ops.impl.transforms.custom.Assign; -import org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpace; -import org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd; -import org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum; -import org.nd4j.linalg.api.ops.impl.transforms.custom.Dilation2D; -import org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttention; -import org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttentionBp; -import org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicPartition; -import org.nd4j.linalg.api.ops.impl.transforms.custom.DynamicStitch; -import org.nd4j.linalg.api.ops.impl.transforms.custom.EqualTo; -import org.nd4j.linalg.api.ops.impl.transforms.custom.Fill; -import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThan; -import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual; -import org.nd4j.linalg.api.ops.impl.transforms.custom.InvertPermutation; -import org.nd4j.linalg.api.ops.impl.transforms.custom.IsNonDecreasing; -import org.nd4j.linalg.api.ops.impl.transforms.custom.IsNumericTensor; -import org.nd4j.linalg.api.ops.impl.transforms.custom.IsStrictlyIncreasing; -import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm; -import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNormBp; -import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThan; -import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual; -import org.nd4j.linalg.api.ops.impl.transforms.custom.ListDiff; -import org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax; -import org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixDeterminant; -import org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixInverse; -import org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixSetDiag; -import org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention; -import org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttentionBp; -import org.nd4j.linalg.api.ops.impl.transforms.custom.NotEqualTo; -import org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse; -import org.nd4j.linalg.api.ops.impl.transforms.custom.ReverseSequence; -import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; -import org.nd4j.linalg.api.ops.impl.transforms.custom.SpaceToBatch; -import org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize; -import org.nd4j.linalg.api.ops.impl.transforms.custom.StandardizeBp; -import org.nd4j.linalg.api.ops.impl.transforms.custom.Trace; -import org.nd4j.linalg.api.ops.impl.transforms.custom.XwPlusB; +import org.nd4j.linalg.api.ops.impl.transforms.custom.*; import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMax; import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMean; import org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentMin; @@ -289,25 +252,8 @@ import org.nd4j.linalg.api.ops.impl.transforms.gradient.SELUDerivative; import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative; import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative; import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.DivOp; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorDivOp; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.FloorModOp; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RDivOp; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.RSubOp; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SquaredDifferenceOp; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.SubOp; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.TruncateDivOp; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.AddBpOp; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.DivBpOp; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorDivBpOp; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorModBpOp; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.MulBpOp; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.RDivBpOp; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.RSubBpOp; -import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.SubBpOp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.*; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.*; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.And; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Xor; @@ -1290,7 +1236,7 @@ public class DifferentialFunctionFactory { } public SDVariable isMax(SDVariable ix) { - return new IsMax(sameDiff(), ix, false).outputVariable(); + return new IsMax(sameDiff(), ix).outputVariable(); } public SDVariable replaceWhere(SDVariable to, SDVariable from, Condition condition) { @@ -1317,6 +1263,21 @@ public class DifferentialFunctionFactory { return new Xor(sameDiff(), ix, iy).outputVariable(); } + public SDVariable shift(SDVariable ix, int shift) { + return new ShiftBits(sameDiff(), ix, shift).outputVariable(); + } + + public SDVariable rshift(SDVariable ix, int shift) { + return new RShiftBits(sameDiff(), ix, shift).outputVariable(); + } + + public SDVariable rotl(SDVariable ix, int shift) { + return new CyclicShiftBits(sameDiff(), ix, shift).outputVariable(); + } + + public SDVariable rotr(SDVariable ix, int shift) { + return new CyclicRShiftBits(sameDiff(), ix, shift).outputVariable(); + } public SDVariable eq(SDVariable iX, SDVariable i_y) { return new EqualTo(sameDiff(), new SDVariable[]{iX, i_y}, false).outputVariable(); @@ -2231,6 +2192,10 @@ public class DifferentialFunctionFactory { return Arrays.asList(new MulBpOp(sameDiff(), x, y, grad).outputVariables()); } + public List modBp(SDVariable x, SDVariable y, SDVariable grad) { + return Arrays.asList(new ModBpOp(sameDiff(), x, y, grad).outputVariables()); + } + public SDVariable muli(SDVariable differentialFunction, SDVariable i_v) { validateDifferentialFunctionsameDiff(differentialFunction); @@ -2238,6 +2203,10 @@ public class DifferentialFunctionFactory { } + public SDVariable mod(SDVariable differentialFunction, SDVariable i_v) { + validateDifferentialFunctionsameDiff(differentialFunction); + return new ModOp(sameDiff(), new SDVariable[]{differentialFunction, i_v}, false).outputVariable(); + } public SDVariable div(SDVariable differentialFunction, SDVariable i_v) { validateDifferentialFunctionsameDiff(differentialFunction); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java index f618b1186..64749da1e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/SDVariable.java @@ -804,6 +804,34 @@ public class SDVariable extends DifferentialFunction implements Serializable { return sameDiff.updateVariableNameAndReference(result, name); } + /** + * Floor division operation: elementwise {@code this // x}
+ * If this and x variables have equal shape, the output shape is the same as the inputs.
+ * Supports broadcasting: if this and x have different shapes and are broadcastable, the output shape is broadcast. + * + * @param name Name of the output variable + * @param x Variable to perform operation with + * @return Output (result) SDVariable + */ + public SDVariable fdiv(String name, SDVariable x) { + val result = sameDiff.f().floorDiv(this, x); + return sameDiff.updateVariableNameAndReference(result, name); + } + + /** + * Modulo operation: elementwise {@code this / x}
+ * If this and x variables have equal shape, the output shape is the same as the inputs.
+ * Supports broadcasting: if this and x have different shapes and are broadcastable, the output shape is broadcast. + * + * @param name Name of the output variable + * @param x Variable to perform operation with + * @return Output (result) SDVariable + */ + public SDVariable mod(String name, SDVariable x) { + val result = sameDiff.f().mod(this, x); + return sameDiff.updateVariableNameAndReference(result, name); + } + /** * See {@link #mul(String, double)} */ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java index a98b03566..3d40e205a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java @@ -682,6 +682,8 @@ public class InferenceSession extends AbstractSession outShape = customOp.calculateOutputShape(); Preconditions.checkState(outShape != null && outShape.size() > 0, "Failed to calculate output shapes for op %s (%s) - no shapes were returned by calculateOutputShape()", customOp.opName(), customOp.getOwnName()); String[] outNames = df.outputVariablesNames(); + Preconditions.checkState(outNames.length == outShape.size(), "Error in operation shape calculation for op \"%s\": Got %s op output shapes for an operation" + + " with %s outputs (number of shapes and outputs must be equal)", df.opName(), outShape.size(), outNames.length); for( int i=0; i> 4 + * + * @param name Name of the output variable + * @param x Input 1 + * @return Output SDVariable with shifted bits + */ + public SDVariable bitShiftRight(String name, SDVariable x, int shift) { + validateInteger("rshift_bits", x); + SDVariable result = f().rshift(x, shift); + return updateVariableNameAndReference(result, name); + } + + /** + * Roll integer bits to the left, i.e. var << 4 | var >> (32 - 4) + * + * @param name Name of the output variable + * @param x Input 1 + * @return Output SDVariable with shifted bits + */ + public SDVariable bitRotl(String name, SDVariable x, int shift) { + validateInteger("cyclic_shift_bits", x); + SDVariable result = f().rotl(x, shift); + return updateVariableNameAndReference(result, name); + } + + /** + * Roll integer bits to the right, i.e. var >> 4 | var << (32 - 4) + * + * @param name Name of the output variable + * @param x Input 1 + * @return Output SDVariable with shifted bits + */ + public SDVariable bitRotr(String name, SDVariable x, int shift) { + validateInteger("cyclic_rshift_bits", x); + SDVariable result = f().rotr(x, shift); + return updateVariableNameAndReference(result, name); + } + /** * Full array zero fraction array reduction operation, optionally along specified dimensions: out = (count(x == 0) / length(x)) * diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/EvaluationCalibration.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/EvaluationCalibration.java index 08528be10..769fcf109 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/EvaluationCalibration.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/evaluation/classification/EvaluationCalibration.java @@ -262,7 +262,7 @@ public class EvaluationCalibration extends BaseEvaluation labelCountsEachClass.addi(labels2d.sum(0).castTo(labelCountsEachClass.dataType())); //For prediction counts: do an IsMax op, but we need to take masking into account... - INDArray isPredictedClass = Nd4j.getExecutioner().exec(new IsMax(p.dup(), 1)); + INDArray isPredictedClass = Nd4j.getExecutioner().exec(new IsMax(p, p.ulike(), 1))[0]; if (maskArray != null) { LossUtil.applyMask(isPredictedClass, maskArray); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/concurrency/AffinityManager.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/concurrency/AffinityManager.java index 5625db5a5..06cffb24f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/concurrency/AffinityManager.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/concurrency/AffinityManager.java @@ -34,6 +34,12 @@ public interface AffinityManager { */ Integer getDeviceForCurrentThread(); + /** + * This method returns deviceId for a given thread + * @return + */ + Integer getDeviceForThread(long threadId); + /** * This method returns id of current device for a given INDArray diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/concurrency/BasicAffinityManager.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/concurrency/BasicAffinityManager.java index ad0320825..664362308 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/concurrency/BasicAffinityManager.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/concurrency/BasicAffinityManager.java @@ -28,6 +28,11 @@ public abstract class BasicAffinityManager implements AffinityManager { return 0; } + @Override + public Integer getDeviceForThread(long threadId) { + return 0; + } + @Override public Integer getDeviceForArray(INDArray array) { return 0; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/any/IsMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/any/IsMax.java index 702150207..c06fe8026 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/any/IsMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/any/IsMax.java @@ -24,6 +24,7 @@ import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseTransformAnyOp; +import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.factory.Nd4j; @@ -34,47 +35,29 @@ import java.util.List; * [1, 2, 3, 1] -> [0, 0, 1, 0] * @author Adam Gibson */ -public class IsMax extends BaseTransformAnyOp { - public IsMax(SameDiff sameDiff, SDVariable i_v, boolean inPlace) { - super(sameDiff, i_v, inPlace); +public class IsMax extends DynamicCustomOp { + public IsMax(SameDiff sameDiff, SDVariable i_v) { + super(sameDiff, i_v); } - public IsMax(SameDiff sameDiff, SDVariable i_v, int[] shape, boolean inPlace, Object[] extraArgs) { - super(sameDiff, i_v, shape, inPlace, extraArgs); - } - - public IsMax(SameDiff sameDiff, SDVariable i_v, Object[] extraArgs) { - super(sameDiff, i_v, extraArgs); - } public IsMax(INDArray x, INDArray z) { - super(x, z); + super(new INDArray[]{x}, new INDArray[]{z}); } public IsMax() {} + public IsMax(INDArray x) { - super(x, Nd4j.createUninitialized(DataType.BOOL, x.shape(), x.ordering())); + this(x, Nd4j.createUninitialized(DataType.BOOL, x.shape(), x.ordering())); } public IsMax(INDArray x, INDArray z, int... dimensions) { - super(x, z); - this.extraArgs = new Object[dimensions.length + 1]; - this.extraArgs[0] = dimensions.length; - for (int i = 0; i < dimensions.length; i++) - this.extraArgs[i + 1] = dimensions[i]; + this(x, z); + this.addIArgument(dimensions); } public IsMax(INDArray x, int... dimensions) { - super(x, Nd4j.createUninitialized(x.dataType(), x.shape(), x.ordering())); - this.extraArgs = new Object[dimensions.length + 1]; - this.extraArgs[0] = dimensions.length; - for (int i = 0; i < dimensions.length; i++) - this.extraArgs[i + 1] = dimensions[i]; - } - - @Override - public int opNum() { - return 1; + this(x, Nd4j.createUninitialized(DataType.BOOL, x.shape(), x.ordering()), dimensions); } @Override @@ -82,7 +65,6 @@ public class IsMax extends BaseTransformAnyOp { return "ismax"; } - @Override public String onnxName() { throw new NoOpNameFoundException("No onnx op opName found for " + opName()); @@ -93,14 +75,6 @@ public class IsMax extends BaseTransformAnyOp { throw new NoOpNameFoundException("No tensorflow op opName found for " + opName()); } - @Override - public DataBuffer extraArgsDataBuff(DataType dtype) { - if (Nd4j.getExecutioner().type() == OpExecutioner.ExecutionerType.CUDA) - return this.extraArgs == null ? null : Nd4j.createBuffer(DataType.LONG, 1, false); - else - return super.extraArgsDataBuff(dtype); - } - @Override public List doDiff(List f1) { return Collections.singletonList(f().zerosLike(arg())); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BatchToSpace.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BatchToSpace.java index 05e898a1a..e61d488aa 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BatchToSpace.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BatchToSpace.java @@ -77,7 +77,7 @@ public class BatchToSpace extends DynamicCustomOp { @Override public String tensorflowName() { - return "BatchToSpaceND"; + return "BatchToSpace"; } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BatchToSpaceND.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BatchToSpaceND.java new file mode 100644 index 000000000..ef07c7cc6 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/BatchToSpaceND.java @@ -0,0 +1,93 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.transforms.custom; + + +import lombok.val; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +/** + * N-dimensional batch to space operation. Transforms data from a tensor from batch dimension into M spatial dimensions + * according to the "blocks" specified (a vector of length M). Afterwards the spatial dimensions are optionally cropped, + * as specified in "crops", a tensor of dim (M, 2), denoting the crop range. + *

+ * Example: + * input: [[[[1]]], [[[2]]], [[[3]]], [[[4]]]] + * input shape: [4, 1, 1, 1] + * blocks: [2, 2] + * crops: [[0, 0], [0, 0]] + *

+ * output: [[[[1], [2]], [[3], [4]]]] + * output shape: [1, 2, 2, 1] + * + * @author Max Pumperla + */ +public class BatchToSpaceND extends DynamicCustomOp { + + private int[] blocks; + private int[][] crops; + + public BatchToSpaceND() { + } + + public BatchToSpaceND(SameDiff sameDiff, SDVariable[] args, int[] blocks, int[][] crops, boolean inPlace) { + super(null, sameDiff, args, inPlace); + + this.blocks = blocks; + this.crops = crops; + + for (val b : blocks) + addIArgument(b); + + for (int e = 0; e < crops.length; e++) + addIArgument(crops[e][0], crops[e][1]); + } + + @Override + public String opName() { + return "batch_to_space_nd"; + } + + @Override + public String onnxName() { + return "batch_to_space_nd"; + } + + @Override + public String tensorflowName() { + return "BatchToSpaceND"; + } + + @Override + public List doDiff(List i_v) { + // Inverse of batch to space is space to batch with same blocks and padding as crops + SDVariable gradient = sameDiff.setupFunction(i_v.get(0)); + return Arrays.asList(sameDiff.cnn().spaceToBatch(gradient, blocks, crops)); + } + + @Override + public List calculateOutputDataTypes(List dataTypes){ + return Collections.singletonList(dataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicRShiftBits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicRShiftBits.java new file mode 100644 index 000000000..318a7dc02 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicRShiftBits.java @@ -0,0 +1,80 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.imports.NoOpNameFoundException; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; + +import java.util.Collections; +import java.util.List; + +/** + * Element-wise roll operation, rolls bits to the left, << + * + * @author raver119@gmail.com + */ +public class CyclicRShiftBits extends BaseDynamicTransformOp { + + public CyclicRShiftBits(SameDiff sameDiff, SDVariable x, int shift) { + super(sameDiff, new SDVariable[] {x} ,false); + this.addIArgument(shift); + } + + public CyclicRShiftBits(INDArray input, int shift, INDArray output) { + super(new INDArray[]{input}, new INDArray[]{output}); + this.addIArgument(shift); + } + + public CyclicRShiftBits(INDArray input, int shift) { + this(input, shift,null); + } + + public CyclicRShiftBits() {} + + @Override + public String opName() { + return "cyclic_rshift_bits"; + } + + + @Override + public String onnxName() { + throw new NoOpNameFoundException("No onnx op opName found for " + opName()); + } + + @Override + public String tensorflowName() { + throw new NoOpNameFoundException("No onnx op opName found for " + opName()); + } + + + @Override + public List doDiff(List i_v) { + throw new UnsupportedOperationException("Not yet implemented: " + opName()); + } + + @Override + public List calculateOutputDataTypes(List dataTypes){ + Preconditions.checkState(dataTypes.get(0).isIntType(), "Input 0 datatype must be a integer type, got %s", dataTypes.get(0)); + return Collections.singletonList(dataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicShiftBits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicShiftBits.java new file mode 100644 index 000000000..b4291c5df --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CyclicShiftBits.java @@ -0,0 +1,80 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.imports.NoOpNameFoundException; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; + +import java.util.Collections; +import java.util.List; + +/** + * Element-wise roll operation, rolls bits to the left, << + * + * @author raver119@gmail.com + */ +public class CyclicShiftBits extends BaseDynamicTransformOp { + + public CyclicShiftBits(SameDiff sameDiff, SDVariable x, int shift) { + super(sameDiff, new SDVariable[] {x} ,false); + this.addIArgument(shift); + } + + public CyclicShiftBits(INDArray input, int shift, INDArray output) { + super(new INDArray[]{input}, new INDArray[]{output}); + this.addIArgument(shift); + } + + public CyclicShiftBits(INDArray input, int shift) { + this(input, shift,null); + } + + public CyclicShiftBits() {} + + @Override + public String opName() { + return "cyclic_shift_bits"; + } + + + @Override + public String onnxName() { + throw new NoOpNameFoundException("No onnx op opName found for " + opName()); + } + + @Override + public String tensorflowName() { + throw new NoOpNameFoundException("No onnx op opName found for " + opName()); + } + + + @Override + public List doDiff(List i_v) { + throw new UnsupportedOperationException("Not yet implemented: " + opName()); + } + + @Override + public List calculateOutputDataTypes(List dataTypes){ + Preconditions.checkState(dataTypes.get(0).isIntType(), "Input 0 datatype must be a integer type, got %s", dataTypes.get(0)); + return Collections.singletonList(dataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/RShiftBits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/RShiftBits.java new file mode 100644 index 000000000..80697efa3 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/RShiftBits.java @@ -0,0 +1,80 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.imports.NoOpNameFoundException; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; + +import java.util.Collections; +import java.util.List; + +/** + * Element-wise shift operation, shift bits to the right, >> + * + * @author raver119@gmail.com + */ +public class RShiftBits extends BaseDynamicTransformOp { + + public RShiftBits(SameDiff sameDiff, SDVariable x, int shift) { + super(sameDiff, new SDVariable[] {x} ,false); + this.addIArgument(shift); + } + + public RShiftBits(INDArray input, int shift, INDArray output) { + super(new INDArray[]{input}, new INDArray[]{output}); + this.addIArgument(shift); + } + + public RShiftBits(INDArray input, int shift) { + this(input, shift,null); + } + + public RShiftBits() {} + + @Override + public String opName() { + return "rshift_bits"; + } + + + @Override + public String onnxName() { + throw new NoOpNameFoundException("No onnx op opName found for " + opName()); + } + + @Override + public String tensorflowName() { + throw new NoOpNameFoundException("No onnx op opName found for " + opName()); + } + + + @Override + public List doDiff(List i_v) { + throw new UnsupportedOperationException("Not yet implemented: " + opName()); + } + + @Override + public List calculateOutputDataTypes(List dataTypes){ + Preconditions.checkState(dataTypes.get(0).isIntType(), "Input 0 datatype must be a integer type, got %s", dataTypes.get(0)); + return Collections.singletonList(dataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ShiftBits.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ShiftBits.java new file mode 100644 index 000000000..8c652f72d --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/ShiftBits.java @@ -0,0 +1,80 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.imports.NoOpNameFoundException; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; + +import java.util.Collections; +import java.util.List; + +/** + * Element-wise shift operation, shift bits to the left, << + * + * @author raver119@gmail.com + */ +public class ShiftBits extends BaseDynamicTransformOp { + + public ShiftBits(SameDiff sameDiff, SDVariable x, int shift) { + super(sameDiff, new SDVariable[] {x} ,false); + this.addIArgument(shift); + } + + public ShiftBits(INDArray input, int shift, INDArray output) { + super(new INDArray[]{input}, new INDArray[]{output}); + this.addIArgument(shift); + } + + public ShiftBits(INDArray input, int shift) { + this(input, shift,null); + } + + public ShiftBits() {} + + @Override + public String opName() { + return "shift_bits"; + } + + + @Override + public String onnxName() { + throw new NoOpNameFoundException("No onnx op opName found for " + opName()); + } + + @Override + public String tensorflowName() { + throw new NoOpNameFoundException("No onnx op opName found for " + opName()); + } + + + @Override + public List doDiff(List i_v) { + throw new UnsupportedOperationException("Not yet implemented: " + opName()); + } + + @Override + public List calculateOutputDataTypes(List dataTypes){ + Preconditions.checkState(dataTypes.get(0).isIntType(), "Input 0 datatype must be a integer type, got %s", dataTypes.get(0)); + return Collections.singletonList(dataTypes.get(0)); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SpaceToBatch.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SpaceToBatch.java index 8ad1936ca..12fe52854 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SpaceToBatch.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SpaceToBatch.java @@ -77,7 +77,7 @@ public class SpaceToBatch extends DynamicCustomOp { @Override public String tensorflowName() { - return "SpaceToBatchND"; + return "SpaceToBatch"; } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SpaceToBatchND.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SpaceToBatchND.java new file mode 100644 index 000000000..9eb72e54f --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/SpaceToBatchND.java @@ -0,0 +1,95 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.transforms.custom; + + +import lombok.val; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +/** + * N-dimensional space to batch operation. Transforms data from a tensor from M spatial dimensions into batch dimension + * according to the "blocks" specified (a vector of length M). Afterwards the spatial dimensions are optionally padded, + * as specified in "padding", a tensor of dim (M, 2), denoting the padding range. + *

+ * Example: + * input: [[[[1], [2]], [[3], [4]]]] + * input shape: [1, 2, 2, 1] + * blocks: [2, 2] + * padding: [[0, 0], [0, 0]] + *

+ * output: [[[[1]]], [[[2]]], [[[3]]], [[[4]]]] + * output shape: [4, 1, 1, 1] + * * + * + * @author Max Pumperla + */ +public class SpaceToBatchND extends DynamicCustomOp { + + protected int[] blocks; + protected int[][] padding; + + public SpaceToBatchND() { + } + + public SpaceToBatchND(SameDiff sameDiff, SDVariable[] args, int[] blocks, int[][] padding, boolean inPlace) { + super(null, sameDiff, args, inPlace); + + this.blocks = blocks; + this.padding = padding; + + for (val b : blocks) + addIArgument(b); + + for (int e = 0; e < padding.length; e++) + addIArgument(padding[e][0], padding[e][1]); + } + + @Override + public String opName() { + return "space_to_batch_nd"; + } + + @Override + public String onnxName() { + return "space_to_batch_nd"; + } + + @Override + public String tensorflowName() { + return "SpaceToBatchND"; + } + + @Override + public List doDiff(List i_v) { + // Inverse of space to batch is batch to space with same blocks and crops as padding + SDVariable gradient = sameDiff.setupFunction(i_v.get(0)); + return Arrays.asList(sameDiff.cnn().batchToSpace(gradient, blocks, padding)); + } + + @Override + public List calculateOutputDataTypes(List dataTypes){ + return Collections.singletonList(dataTypes.get(0)); + } + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/ModOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/ModOp.java new file mode 100644 index 000000000..289333f96 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/ModOp.java @@ -0,0 +1,69 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic; + +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; + +import java.util.List; + +/** + * Modulo operation + * + * @author raver119@gmail.com + */ +public class ModOp extends BaseDynamicTransformOp { + public static final String OP_NAME = "mod"; + + public ModOp() {} + + public ModOp(SameDiff sameDiff, SDVariable[] args, boolean inPlace) { + super(sameDiff, args, inPlace); + } + + public ModOp(INDArray first, INDArray second, INDArray result){ + this(new INDArray[]{first, second}, result == null ? null : new INDArray[]{result}); + } + + public ModOp(INDArray[] inputs, INDArray[] outputs) { + super(inputs, outputs); + } + + @Override + public String opName() { + return OP_NAME; + } + + @Override + public String onnxName() { + return "Mod"; + } + + @Override + public String tensorflowName() { + return "mod"; + } + + @Override + public List doDiff(List i_v) { + return f().modBp(larg(), rarg(), i_v.get(0)); + } + + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/bp/ModBpOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/bp/ModBpOp.java new file mode 100644 index 000000000..a9c401dd7 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/bp/ModBpOp.java @@ -0,0 +1,39 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp; + +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; + +/** + * Modulo backprop operation. Supports 'undoing' of auto broadcast as applied in div op forward pass + * + * @author raver119@gmail.com + */ +public class ModBpOp extends BaseArithmeticBackpropOp { + + public ModBpOp() {} + + public ModBpOp(SameDiff sameDiff, SDVariable x, SDVariable y, SDVariable eps) { + super(sameDiff, x,y,eps); + } + + @Override + public String opName() { + return "mod_bp"; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java index 442dd0f5f..bbe133dbb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/shape/Shape.java @@ -3676,7 +3676,7 @@ public class Shape { } public static boolean isR(@NonNull DataType x) { - return x == DataType.FLOAT || x == DataType.HALF || x == DataType.DOUBLE; + return x == DataType.FLOAT || x == DataType.HALF || x == DataType.DOUBLE || x == DataType.BFLOAT16; } private static DataType max(@NonNull DataType typeX, @NonNull DataType typeY) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index 429782c3e..1f55de3dc 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -1,4 +1,4 @@ -/******************************************************************************* +/* ***************************************************************************** * Copyright (c) 2015-2018 Skymind, Inc. * * This program and the accompanying materials are made available under the @@ -16,7 +16,6 @@ package org.nd4j.linalg.factory; -import com.google.common.base.Function; import com.google.common.primitives.Ints; import com.google.common.primitives.Longs; import lombok.NonNull; @@ -85,6 +84,7 @@ import org.nd4j.linalg.memory.deallocation.DeallocatorService; import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.string.NDArrayStrings; import org.nd4j.linalg.util.ArrayUtil; +import org.nd4j.linalg.util.LongUtils; import org.nd4j.tools.PropertyParser; import org.nd4j.versioncheck.VersionCheck; @@ -1756,14 +1756,12 @@ public class Nd4j { index[j] = (double) j; } - /** + /* * Inject a comparator that sorts indices relative to * the actual values in the data. * This allows us to retain the indices * and how they were rearranged. */ - - Arrays.sort(index, new Comparator() { @Override public int compare(Double o1, Double o2) { @@ -2155,6 +2153,7 @@ public class Nd4j { * Defaults to scientific notation with 18 digits after the decimal * Use {@link #writeTxt(INDArray, String)} */ + @SuppressWarnings("unused") //backward compatibility. public static void writeTxt(INDArray write, String filePath, String split, int precision) { writeTxt(write,filePath); } @@ -2169,6 +2168,7 @@ public class Nd4j { * Defaults to scientific notation with 18 digits after the decimal * Use {@link #writeTxt(INDArray, String)} */ + @SuppressWarnings("unused") //backward compatibility. public static void writeTxt(INDArray write, String filePath, int precision) { writeTxt(write, filePath); } @@ -2182,6 +2182,7 @@ public class Nd4j { * @deprecated custom col and higher dimension separators are no longer supported; uses "," * Use {@link #writeTxt(INDArray, String)} */ + @SuppressWarnings("unused") public static void writeTxt(INDArray write, String filePath, String split) { writeTxt(write,filePath); } @@ -2206,15 +2207,13 @@ public class Nd4j { write = write.dup(); String format = "0.000000000000000000E0"; - return new StringBuilder() - .append("{\n") - .append("\"filefrom\": \"dl4j\",\n") - .append( "\"ordering\": \"").append(write.ordering()).append("\",\n") - .append("\"shape\":\t").append( java.util.Arrays.toString(write.shape())).append(",\n") - .append("\"data\":\n") - .append(new NDArrayStrings(",", format).format(write, false)) - .append("\n}\n") - .toString(); + return "{\n" + + "\"filefrom\": \"dl4j\",\n" + + "\"ordering\": \"" + write.ordering() + "\",\n" + + "\"shape\":\t" + Arrays.toString(write.shape()) + ",\n" + + "\"data\":\n" + + new NDArrayStrings(",", format).format(write, false) + + "\n}\n"; } @@ -2242,8 +2241,7 @@ public class Nd4j { ByteArrayOutputStream bos = new ByteArrayOutputStream((int) (arr.length() * arr.data().getElementSize())); DataOutputStream dos = new DataOutputStream(bos); write(arr, dos); - byte[] ret = bos.toByteArray(); - return ret; + return bos.toByteArray(); } /** @@ -2251,10 +2249,9 @@ public class Nd4j { * @param arr the array to read from * @return the deserialized ndarray */ - public static INDArray fromByteArray(@NonNull byte[] arr) throws IOException { + public static INDArray fromByteArray(@NonNull byte[] arr) { ByteArrayInputStream bis = new ByteArrayInputStream(arr); - INDArray ret = read(bis); - return ret; + return read(bis); } /** @@ -2277,6 +2274,7 @@ public class Nd4j { * @param charset the charset * @return the deserialized array. */ + @SuppressWarnings("WeakerAccess") //really should add testing for the method. public static INDArray readNumpy(@NonNull DataType dataType, @NonNull InputStream filePath, @NonNull String split, @NonNull Charset charset) throws IOException { BufferedReader reader = new BufferedReader(new InputStreamReader(filePath, charset)); String line; @@ -2369,7 +2367,7 @@ public class Nd4j { * * See {@link #read(DataInputStream)} */ - public static INDArray read(InputStream reader) throws IOException { + public static INDArray read(InputStream reader) { return read(new DataInputStream(reader)); } @@ -2379,6 +2377,7 @@ public class Nd4j { * @param ndarray the input stream ndarray * @return NDArray */ + @SuppressWarnings("WeakerAccess") public static INDArray readTxtString(InputStream ndarray) { String sep = ","; /* @@ -2447,6 +2446,7 @@ public class Nd4j { String[] entries = line.replace("\\],", "").replaceAll("]", "").replaceAll("\\[", "").split(sep); if (rank == 0) { try { + //noinspection ConstantConditions newArr.addi((format.parse(entries[0])).doubleValue()); } catch (ParseException e) { e.printStackTrace(); @@ -2558,17 +2558,17 @@ public class Nd4j { public static INDArray read(DataInputStream dis) { val headerShape = BaseDataBuffer.readHeader(dis); + //noinspection UnnecessaryUnboxing var shapeInformation = Nd4j.createBufferDetached(new long[]{headerShape.getMiddle().longValue()}, headerShape.getRight()); shapeInformation.read(dis, headerShape.getLeft(), headerShape.getMiddle(), headerShape.getThird()); - val length = Shape.length(shapeInformation); - DataType type = null; + DataType type; DataBuffer data = null; val headerData = BaseDataBuffer.readHeader(dis); try { // current version contains dtype in extras data = CompressedDataBuffer.readUnknown(dis, headerData.getFirst(), headerData.getMiddle(), headerData.getRight()); - type = ArrayOptionsHelper.dataType(shapeInformation.asLong()); + ArrayOptionsHelper.dataType(shapeInformation.asLong()); } catch (ND4JUnknownDataTypeException e) { // manually setting data type type = headerData.getRight(); @@ -2731,16 +2731,7 @@ public class Nd4j { public static INDArray choice(INDArray source, INDArray probs, INDArray target) { return choice(source, probs, target, Nd4j.getRandom()); } - - /** - * - * - * @param source - * @param probs - * @param numSamples - * @return - */ - + // @see tag works well here. /** * This method returns new INDArray instance, sampled from Source array with probabilities given in Probs. @@ -3749,14 +3740,13 @@ public class Nd4j { * This method creates new 0D INDArray, aka scalar. * * PLEASE NOTE: Temporary method, added to ensure backward compatibility - * @param scalar - * @return - * @deprecated Use Nd4j.scalar methods, such as {@link #scalar(double)} or {@link #scalar(DataType, Number)} + * @param scalar data for INDArray. + * @return new INDArray + * * @deprecated Use Nd4j.scalar methods, such as {@link #scalar(double)} or {@link #scalar(DataType, Number)} */ @Deprecated public static INDArray trueScalar(Number scalar) { - val ret = INSTANCE.trueScalar(scalar); - return ret; + return INSTANCE.trueScalar(scalar); } /** @@ -3764,8 +3754,7 @@ public class Nd4j { */ @Deprecated public static INDArray trueVector(boolean[] data) { - val ret = INSTANCE.trueVector(data); - return ret; + return INSTANCE.trueVector(data); } /** @@ -3773,8 +3762,7 @@ public class Nd4j { */ @Deprecated public static INDArray trueVector(long[] data) { - val ret = INSTANCE.trueVector(data); - return ret; + return INSTANCE.trueVector(data); } /** @@ -3782,8 +3770,7 @@ public class Nd4j { */ @Deprecated public static INDArray trueVector(int[] data) { - val ret = INSTANCE.trueVector(data); - return ret; + return INSTANCE.trueVector(data); } /** @@ -3791,8 +3778,7 @@ public class Nd4j { */ @Deprecated public static INDArray trueVector(float[] data) { - val ret = INSTANCE.trueVector(data); - return ret; + return INSTANCE.trueVector(data); } /** @@ -3800,8 +3786,7 @@ public class Nd4j { */ @Deprecated public static INDArray trueVector(double[] data) { - val ret = INSTANCE.trueVector(data); - return ret; + return INSTANCE.trueVector(data); } /** @@ -3820,7 +3805,7 @@ public class Nd4j { */ public static INDArray empty(DataType type) { if(EMPTY_ARRAYS[type.ordinal()] == null){ - try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()){ + try(MemoryWorkspace ignored = Nd4j.getMemoryManager().scopeOutOfWorkspaces()){ val ret = INSTANCE.empty(type); EMPTY_ARRAYS[type.ordinal()] = ret; } @@ -3844,11 +3829,8 @@ public class Nd4j { if (shape[0] != data.length) throw new ND4JIllegalStateException("Shape of the new array doesn't match data length"); } - - checkShapeValues(data.length, shape); - - INDArray ret = INSTANCE.create(data, shape); - return ret; + checkShapeValues(data.length, LongUtils.toLongs(shape)); + return INSTANCE.create(data, shape); } /** @@ -3858,16 +3840,8 @@ public class Nd4j { if (shape.length == 0 && data.length == 1) { return scalar(data[0]); } - - if (shape.length == 1) { - if (shape[0] != data.length) - throw new ND4JIllegalStateException("Shape of the new array doesn't match data length"); - } - - checkShapeValues(data.length, shape); - - INDArray ret = INSTANCE.create(data, shape, Nd4j.getStrides(shape, Nd4j.order()), DataType.FLOAT, Nd4j.getMemoryManager().getCurrentWorkspace()); - return ret; + commonCheckCreate(data.length, shape); + return INSTANCE.create(data, shape, Nd4j.getStrides(shape, Nd4j.order()), DataType.FLOAT, Nd4j.getMemoryManager().getCurrentWorkspace()); } /** @@ -3877,16 +3851,8 @@ public class Nd4j { if (shape.length == 0 && data.length == 1) { return scalar(data[0]); } - - if (shape.length == 1) { - if (shape[0] != data.length) - throw new ND4JIllegalStateException("Shape of the new array doesn't match data length"); - } - - checkShapeValues(data.length, shape); - - INDArray ret = INSTANCE.create(data, shape, Nd4j.getStrides(shape, Nd4j.order()), DataType.DOUBLE, Nd4j.getMemoryManager().getCurrentWorkspace()); - return ret; + commonCheckCreate(data.length, shape); + return INSTANCE.create(data, shape, Nd4j.getStrides(shape, Nd4j.order()), DataType.DOUBLE, Nd4j.getMemoryManager().getCurrentWorkspace()); } /** @@ -3897,16 +3863,9 @@ public class Nd4j { * @return the created ndarray */ public static INDArray create(double[] data, int... shape) { - if (shape.length == 1) { - if (shape[0] != data.length) - throw new ND4JIllegalStateException("Shape of the new array " + Arrays.toString(shape) + " doesn't match data length: " + data.length); - } - - checkShapeValues(data.length, shape); - + commonCheckCreate(data.length, LongUtils.toLongs(shape)); val lshape = ArrayUtil.toLongArray(shape); - INDArray ret = INSTANCE.create(data, lshape, Nd4j.getStrides(lshape, Nd4j.order()), DataType.DOUBLE, Nd4j.getMemoryManager().getCurrentWorkspace()); - return ret; + return INSTANCE.create(data, lshape, Nd4j.getStrides(lshape, Nd4j.order()), DataType.DOUBLE, Nd4j.getMemoryManager().getCurrentWorkspace()); } /** @@ -3920,32 +3879,26 @@ public class Nd4j { * @return the created ndarray. */ public static INDArray create(double[] data, int[] shape, long offset, char ordering) { - if (shape.length == 1) { - if (shape[0] != data.length) + commonCheckCreate(data.length, LongUtils.toLongs(shape)); + return INSTANCE.create(data, shape, Nd4j.getStrides(shape, ordering), offset, ordering); + } + + private static void commonCheckCreate( int dataLength, long[] shape){ + if (shape.length== 1) { + if (shape[0] != dataLength) throw new ND4JIllegalStateException("Shape of the new array " + Arrays.toString(shape) - + " doesn't match data length: " + data.length); + + " doesn't match data length: " + dataLength); } - checkShapeValues(data.length, shape); - - INDArray ret = INSTANCE.create(data, shape, Nd4j.getStrides(shape, ordering), offset, ordering); - return ret; + checkShapeValues(dataLength, shape); } /** * See {@link #create(double[], int[], long, char )} */ public static INDArray create(double[] data, long[] shape, long offset, char ordering) { - if (shape.length == 1) { - if (shape[0] != data.length) - throw new ND4JIllegalStateException("Shape of the new array " + Arrays.toString(shape) - + " doesn't match data length: " + data.length); - } - - checkShapeValues(data.length, shape); - - INDArray ret = INSTANCE.create(data, shape, Nd4j.getStrides(shape, ordering), offset, ordering); - return ret; + commonCheckCreate(data.length, shape); + return INSTANCE.create(data, shape, Nd4j.getStrides(shape, ordering), offset, ordering); } /** @@ -3958,16 +3911,8 @@ public class Nd4j { * @return the instance */ public static INDArray create(float[] data, int[] shape, int[] stride, long offset) { - if (shape.length == 1) { - if (shape[0] != data.length) - throw new ND4JIllegalStateException("Shape of the new array " + Arrays.toString(shape) - + " doesn't match data length: " + data.length); - } - - checkShapeValues(data.length, shape); - - INDArray ret = INSTANCE.create(data, shape, stride, offset); - return ret; + commonCheckCreate(data.length, LongUtils.toLongs(shape)); + return INSTANCE.create(data, shape, stride, offset); } /** @@ -3979,9 +3924,7 @@ public class Nd4j { */ public static INDArray create(List list, int... shape) { checkShapeValues(shape); - - INDArray ret = INSTANCE.create(list, shape); - return ret; + return INSTANCE.create(list, shape); } /** @@ -3989,9 +3932,7 @@ public class Nd4j { */ public static INDArray create(List list, long... shape) { checkShapeValues(shape); - - INDArray ret = INSTANCE.create(list, shape); - return ret; + return INSTANCE.create(list, shape); } /** @@ -4027,10 +3968,7 @@ public class Nd4j { */ public static INDArray create(int[] shape, int[] stride, long offset) { checkShapeValues(shape); - - INDArray ret = INSTANCE.create(shape, stride, offset); - return ret; - + return INSTANCE.create(shape, stride, offset); } /** @@ -4260,34 +4198,21 @@ public class Nd4j { /** * Create an array withgiven shape and ordering based on a java double array. * @param data java array used for initialisation. Must have at least the number of elements required. - * @@param shape desired shape of new array. + * @param shape desired shape of new array. * @param ordering Fortran 'f' or C/C++ 'c' ordering. * @return the created ndarray. */ public static INDArray create(double[] data, int[] shape, char ordering) { - //TODO: duplicate code and issue #8013 - if (shape.length == 1) { - if (shape[0] != data.length) - throw new ND4JIllegalStateException("Shape of the new array " + Arrays.toString(shape) - + " doesn't match data length: " + data.length); - } - - checkShapeValues(data.length, shape); + commonCheckCreate(data.length, LongUtils.toLongs(shape)); val lshape = ArrayUtil.toLongArray(shape); - return INSTANCE.create(data, lshape, Nd4j.getStrides(lshape, ordering), ordering, DataType.DOUBLE, Nd4j.getMemoryManager().getCurrentWorkspace()); + return INSTANCE.create(data, lshape, Nd4j.getStrides(lshape, ordering), ordering, DataType.DOUBLE, Nd4j.getMemoryManager().getCurrentWorkspace()); } /** * See {@link #create(double[], int[], char)} */ public static INDArray create(float[] data, int[] shape, char ordering) { - if (shape.length == 1) { - if (shape[0] != data.length) - throw new ND4JIllegalStateException("Shape of the new array " + Arrays.toString(shape) - + " doesn't match data length: " + data.length); - } - - checkShapeValues(data.length, shape); + commonCheckCreate(data.length, LongUtils.toLongs(shape)); return INSTANCE.create(data, shape, ordering); } @@ -4307,22 +4232,6 @@ public class Nd4j { return INSTANCE.create(data, shape, Nd4j.getStrides(shape, ordering), ordering, DataType.DOUBLE, Nd4j.getMemoryManager().getCurrentWorkspace()); } - /** - * Creates an ndarray with the specified shape - * TODO: unused method. (only used by the zeros method in this class) - * - * @param rows the rows of the ndarray - * @param columns the columns of the ndarray - * @param stride the stride for the ndarray - * @param offset the offset of the ndarray - * @return the instance - */ - public static INDArray create(int rows, int columns, int[] stride, long offset, char ordering) { - int[] shape = new int[]{rows, columns}; - checkShapeValues(shape); - return INSTANCE.create(shape, stride, offset, ordering); - } - /** * Creates an ndarray with the specified shape * @@ -4469,11 +4378,9 @@ public class Nd4j { return INSTANCE.create(dataType, shape, ordering, Nd4j.getMemoryManager().getCurrentWorkspace()); } - - // TODO: Leaving these until #8028 is fixed. /** - * - * @param shape + * Throws exception on negative shape values. + * @param shape to check */ public static void checkShapeValues(long... shape) { for (long e: shape) { @@ -4483,30 +4390,13 @@ public class Nd4j { } } - // TODO: Leaving these until #8028 is fixed. - /** - * - * @param shape - */ - public static void checkShapeValues(int... shape) { - for (int e: shape) { - if (e < 0) - throw new ND4JIllegalStateException("Invalid shape: Requested INDArray shape " + Arrays.toString(shape) - + " contains dimension size values < 0 (all dimensions must be 0 or more)"); - } + // made private as it is only used for internal checks. + private static void checkShapeValues(int... shape) { + checkShapeValues(LongUtils.toLongs(shape)); } - protected static void checkShapeValues(int length, int... shape) { + private static void checkShapeValues(int length, long... shape) { checkShapeValues(shape); - - if (ArrayUtil.prodLong(shape) != length && !(length == 1 && shape.length == 0)) - throw new ND4JIllegalStateException("Shape of the new array " + Arrays.toString(shape) - + " doesn't match data length: " + length + " - prod(shape) must equal the number of values provided"); - } - - protected static void checkShapeValues(int length, long... shape) { - checkShapeValues(shape); - if (ArrayUtil.prodLong(shape) != length && !(length == 1 && shape.length == 0)) throw new ND4JIllegalStateException("Shape of the new array " + Arrays.toString(shape) + " doesn't match data length: " + length + " - prod(shape) must equal the number of values provided"); @@ -4604,8 +4494,8 @@ public class Nd4j { * * PLEASE NOTE: Do not use this method unless you're 100% sure why you use it. * - * @param length - * @return + * @param length length of array to create + * @return the created INDArray */ public static INDArray createUninitialized(long length) { long[] shape = new long[] {length}; @@ -4619,6 +4509,7 @@ public class Nd4j { * @param shape the shape of the array. * @return the created detached array. */ + @SuppressWarnings("WeakerAccess") // For now. If part of public API it will need testing. public static INDArray createUninitializedDetached(DataType dataType, char ordering, long... shape){ return INSTANCE.createUninitializedDetached(dataType, ordering, shape); } @@ -4738,6 +4629,7 @@ public class Nd4j { * @param type data type * @return the created ndarray */ + @SuppressWarnings("Duplicates") public static INDArray valueArrayOf(long[] shape, double value, DataType type) { if (shape.length == 0) return scalar(type, value); @@ -4752,6 +4644,7 @@ public class Nd4j { /** * See {@link #valueArrayOf(long[], double, DataType)} */ + @SuppressWarnings("Duplicates") public static INDArray valueArrayOf(long[] shape, long value, DataType type) { if (shape.length == 0) return scalar(type, value); @@ -4774,8 +4667,7 @@ public class Nd4j { * @return the created ndarray */ public static INDArray valueArrayOf(long num, double value) { - INDArray ret = INSTANCE.valueArrayOf(new long[] {num}, value); - return ret; + return INSTANCE.valueArrayOf(new long[] {num}, value); } /** @@ -4789,8 +4681,7 @@ public class Nd4j { * @return the created ndarray */ public static INDArray valueArrayOf(long rows, long columns, double value) { - INDArray ret = INSTANCE.valueArrayOf(rows, columns, value); - return ret; + return INSTANCE.valueArrayOf(rows, columns, value); } /** @@ -4801,8 +4692,7 @@ public class Nd4j { * @return the created ndarray */ public static INDArray ones(int rows, int columns) { - INDArray ret = INSTANCE.ones(rows, columns); - return ret; + return INSTANCE.ones(rows, columns); } /** @@ -4860,8 +4750,7 @@ public class Nd4j { * @param arrs the first matrix to concat */ public static INDArray hstack(@NonNull INDArray... arrs) { - INDArray ret = INSTANCE.hstack(arrs); - return ret; + return INSTANCE.hstack(arrs); } /** @@ -4883,6 +4772,7 @@ public class Nd4j { */ public static INDArray vstack(@NonNull INDArray... arrs) { Preconditions.checkState(arrs != null && arrs.length > 0, "No input specified to vstack (null or length 0)"); + //noinspection ConstantConditions if(arrs[0].rank() == 1){ //Edge case: vstack rank 1 arrays - gives rank 2... vstack([3],[3]) -> [2,3] return pile(arrs); @@ -4901,24 +4791,12 @@ public class Nd4j { return vstack(arrays); } - // TODO: unused method /** * This method averages input arrays, and returns averaged array. * On top of that, averaged array is propagated to all input arrays * - * @param arrays - * @return - */ - public static INDArray averageAndPropagate(INDArray target, INDArray[] arrays) { - return INSTANCE.average(target, arrays); - } - - /** - * This method averages input arrays, and returns averaged array. - * On top of that, averaged array is propagated to all input arrays - * - * @param arrays - * @return + * @param arrays arrays to average + * @return averaged arrays */ public static INDArray averageAndPropagate(INDArray[] arrays) { return INSTANCE.average(arrays); @@ -4929,8 +4807,8 @@ public class Nd4j { * This method averages input arrays, and returns averaged array. * On top of that, averaged array is propagated to all input arrays * - * @param arrays - * @return + * @param arrays arrays to average + * @return averaged arrays */ public static INDArray averageAndPropagate(Collection arrays) { return INSTANCE.average(arrays); @@ -4940,20 +4818,19 @@ public class Nd4j { * This method averages input arrays, and returns averaged array. * On top of that, averaged array is propagated to all input arrays * - * @param arrays - * @return + * @param arrays arrays to average + * @return averaged arrays */ public static INDArray averageAndPropagate(INDArray target, Collection arrays) { return INSTANCE.average(target, arrays); } - - /** * Reshapes an ndarray to remove leading 1s * @param toStrip the ndarray to newShapeNoCopy * @return the reshaped ndarray */ + @SuppressWarnings("WeakerAccess") // Needs tests if part of public API. public static INDArray stripOnes(INDArray toStrip) { if (toStrip.isVector()) return toStrip; @@ -4966,8 +4843,8 @@ public class Nd4j { /** * This method sums given arrays and stores them to a new array * - * @param arrays - * @return + * @param arrays array to accumulate + * @return accumulated array. */ public static INDArray accumulate(@NonNull INDArray... arrays) { if (arrays == null|| arrays.length == 0) @@ -4979,9 +4856,9 @@ public class Nd4j { /** * This method sums given arrays and stores them to a given target array * - * @param target - * @param arrays - * @return + * @param target result array + * @param arrays arrays to sum + * @return result array */ public static INDArray accumulate(INDArray target, Collection arrays) { return accumulate(target, arrays.toArray(new INDArray[0])); @@ -4990,14 +4867,13 @@ public class Nd4j { /** * This method sums given arrays and stores them to a given target array * - * @param target - * @param arrays - * @return + * @param target result array + * @param arrays arrays to sum + * @return result array */ public static INDArray accumulate(INDArray target, INDArray[] arrays) { if (arrays == null|| arrays.length == 0) return target; - return factory().accumulate(target, arrays); } @@ -5007,7 +4883,7 @@ public class Nd4j { * @param source source tensor * @param sourceDimension dimension of source tensor * @param indexes indexes from source array - * @return + * @return result array */ public static INDArray pullRows(INDArray source, int sourceDimension, @NonNull int... indexes) { return pullRows(source, sourceDimension, indexes, Nd4j.order()); @@ -5022,8 +4898,9 @@ public class Nd4j { * @param source source tensor * @param sourceDimension dimension of source tensor * @param indexes indexes from source array - * @return + * @return concatenated array */ + @SuppressWarnings("Duplicates") public static INDArray pullRows(INDArray source, int sourceDimension, int[] indexes, char order) { if (sourceDimension >= source.rank()) throw new IllegalStateException("Source dimension can't be higher the rank of source tensor"); @@ -5042,9 +4919,7 @@ public class Nd4j { } Preconditions.checkArgument(source.rank() > 1, "pullRows() can't operate on 0D/1D arrays"); - - INDArray ret = INSTANCE.pullRows(source, sourceDimension, indexes, order); - return ret; + return INSTANCE.pullRows(source, sourceDimension, indexes, order); } /** @@ -5058,6 +4933,7 @@ public class Nd4j { * @param indexes indexes from source array * @return Destination array with specified tensors */ + @SuppressWarnings("Duplicates") public static INDArray pullRows(INDArray source, INDArray destination, int sourceDimension, @NonNull int... indexes){ if (sourceDimension >= source.rank()) throw new IllegalStateException("Source dimension can't be higher the rank of source tensor"); @@ -5074,8 +4950,7 @@ public class Nd4j { Preconditions.checkArgument(source.rank() > 1, "pullRows() can't operate on 0D/1D arrays"); - INDArray ret = INSTANCE.pullRows(source, destination, sourceDimension, indexes); - return ret; + return INSTANCE.pullRows(source, destination, sourceDimension, indexes); } /** @@ -5091,8 +4966,9 @@ public class Nd4j { * @return Output array * @see #concat(int, INDArray...) */ + @SuppressWarnings("ConstantConditions") public static INDArray stack(int axis, @NonNull INDArray... values){ - Preconditions.checkArgument(values != null && values.length > 0, "No inputs: %s", values); + Preconditions.checkArgument(values != null && values.length > 0, "No inputs: %s", (Object[]) values); Preconditions.checkState(axis >= -(values[0].rank()+1) && axis < values[0].rank()+1, "Invalid axis: must be between " + "%s (inclusive) and %s (exclusive) for rank %s input, got %s", -(values[0].rank()+1), values[0].rank()+1, values[0].rank(), axis); @@ -5126,13 +5002,12 @@ public class Nd4j { * * PLEASE NOTE: This method is special for GPU backend, it works on HOST side only. * - * @param dimension - * @param toConcat - * @return + * @param dimension dimension + * @param toConcat arrayts to concatenate + * @return concatenated arrays. */ public static INDArray specialConcat(int dimension, @NonNull INDArray... toConcat) { - INDArray ret = INSTANCE.specialConcat(dimension, toConcat); - return ret; + return INSTANCE.specialConcat(dimension, toConcat); } /** @@ -5143,9 +5018,7 @@ public class Nd4j { */ public static INDArray zeros(int[] shape, char order) { checkShapeValues(shape); - - INDArray ret = INSTANCE.create(shape, order); - return ret; + return INSTANCE.create(shape, order); } /** @@ -5184,9 +5057,7 @@ public class Nd4j { * @return an ndarray with ones filled in */ public static INDArray ones(@NonNull int... shape) { - if(shape.length == 0) - return Nd4j.scalar(dataType(), 1.0); - return INSTANCE.ones(shape); + return (shape.length == 0) ? Nd4j.scalar(dataType(), 1.0) : INSTANCE.ones(shape); } @@ -5216,6 +5087,7 @@ public class Nd4j { * @param value the value to initialize the scalar with * @return the created ndarray */ + @SuppressWarnings("deprecation") public static INDArray scalar(DataType dataType, Number value) { return INSTANCE.trueScalar(dataType, value); } @@ -5247,7 +5119,7 @@ public class Nd4j { * @return the scalar nd array */ public static INDArray scalar(boolean value) { - return INSTANCE.trueScalar(DataType.BOOL, value ? 1 : 0); + return scalar(DataType.BOOL, value ? 1 : 0); } /** @@ -5324,45 +5196,10 @@ public class Nd4j { return Nd4j.exec(new Tile(new INDArray[]{tile}, new INDArray[]{}, repeat))[0]; } - /** - * Get the strides for the given order and shape - * - * @param shape the shape of the array - * @param order the order to getScalar the strides for - * @return the strides for the given shape and order - */ - public static int[] getComplexStrides(int[] shape, char order) { - if (order == NDArrayFactory.FORTRAN) - return ArrayUtil.calcStridesFortran(shape, 2); - return ArrayUtil.calcStrides(shape, 2); - } - - public static long[] getComplexStrides(long[] shape, char order) { - if (order == NDArrayFactory.FORTRAN) - return ArrayUtil.calcStridesFortran(shape, 2); - return ArrayUtil.calcStrides(shape, 2); - } - - /** - * Get the strides based on the shape - * and NDArrays.order() - * - * @param shape the shape of the array - * @return the strides for the given shape - * and order specified by NDArrays.order() - */ - public static int[] getComplexStrides(@NonNull int... shape) { - return getComplexStrides(shape, Nd4j.order()); - } - - public static long[] getComplexStrides(@NonNull long... shape) { - return getComplexStrides(shape, Nd4j.order()); - } - /** * Initializes nd4j */ - public synchronized void initContext() { + private synchronized void initContext() { try { defaultFloatingPointDataType = new AtomicReference<>(); defaultFloatingPointDataType.set(DataType.FLOAT); @@ -5377,6 +5214,7 @@ public class Nd4j { * Initialize with the specific backend * @param backend the backend to initialize with */ + @SuppressWarnings({"unchecked", "Duplicates"}) public void initWithBackend(Nd4jBackend backend) { VersionCheck.checkVersions(); @@ -5557,7 +5395,7 @@ public class Nd4j { /** * - * @return + * @return Shape info provider */ public static ShapeInfoProvider getShapeInfoProvider() { return shapeInfoProvider; @@ -5565,7 +5403,7 @@ public class Nd4j { /** * - * @return + * @return Sparse shape info provider */ public static SparseInfoProvider getSparseInfoProvider() { return sparseInfoProvider; @@ -5573,7 +5411,7 @@ public class Nd4j { /** * - * @return + * @return constant handler */ public static ConstantHandler getConstantHandler() { return constantHandler; @@ -5581,7 +5419,7 @@ public class Nd4j { /** * - * @return + * @return affinity manager */ public static AffinityManager getAffinityManager() { return affinityManager; @@ -5589,7 +5427,7 @@ public class Nd4j { /** * - * @return + * @return NDArrayFactory */ public static NDArrayFactory getNDArrayFactory() { return INSTANCE; @@ -5600,7 +5438,7 @@ public class Nd4j { * suitable for NDArray compression/decompression * at runtime * - * @return + * @return BasicNDArrayCompressor instance */ public static BasicNDArrayCompressor getCompressor() { return BasicNDArrayCompressor.getInstance(); @@ -5608,16 +5446,12 @@ public class Nd4j { /** * This method returns backend-specific MemoryManager implementation, for low-level memory management - * @return + * @return MemoryManager */ public static MemoryManager getMemoryManager() { return memoryManager; } - public static INDArray typeConversion(INDArray array, DataTypeEx targetType) { - return null; - } - /** * This method returns sizeOf(currentDataType), in bytes * @@ -5633,7 +5467,7 @@ public class Nd4j { * This method returns size of element for specified dataType, in bytes * * @param dtype number of bytes per element - * @return + * @return element size */ public static int sizeOfDataType(DataType dtype) { switch (dtype) { @@ -5666,7 +5500,7 @@ public class Nd4j { * * PLEASE NOTE: Do not use this method, unless you have too. * - * @param reallyEnable + * @param reallyEnable fallback mode */ public static void enableFallbackMode(boolean reallyEnable) { fallbackMode.set(reallyEnable); @@ -5675,8 +5509,9 @@ public class Nd4j { /** * This method checks, if fallback mode was enabled. * - * @return + * @return fallback mode */ + @SuppressWarnings("BooleanMethodIsAlwaysInverted") public static boolean isFallbackModeEnabled() { return fallbackMode.get(); } @@ -5684,31 +5519,29 @@ public class Nd4j { /** * This method returns WorkspaceManager implementation to be used within this JVM process * - * @return + * @return WorkspaceManager */ public static MemoryWorkspaceManager getWorkspaceManager() { return workspaceManager; } /** - * This method stacks vertically examples with the same shape, increasing result dimensionality. I.e. if you provide bunch of 3D tensors, output will be 4D tensor. Alignment is always applied to axis 0. + * This method stacks vertically examples with the same shape, increasing result dimensionality. + * I.e. if you provide bunch of 3D tensors, output will be 4D tensor. Alignment is always applied to axis 0. * - * @return + * @param arrays arrays to stack + * @return stacked arrays */ public static INDArray pile(@NonNull INDArray... arrays) { // if we have vectors as input, it's just vstack use case long[] shape = arrays[0].shape(); + //noinspection deprecation long[] newShape = ArrayUtils.add(shape, 0, 1); - boolean shouldReshape = true; - List reshaped = new ArrayList<>(); for(INDArray array: arrays) { - if (!shouldReshape) - reshaped.add(array); - else - reshaped.add(array.reshape(array.ordering(), newShape)); + reshaped.add(array.reshape(array.ordering(), newShape)); } return Nd4j.vstack(reshaped); @@ -5717,7 +5550,8 @@ public class Nd4j { /** * This method stacks vertically examples with the same shape, increasing result dimensionality. I.e. if you provide bunch of 3D tensors, output will be 4D tensor. Alignment is always applied to axis 0. * - * @return + * @param arrays arrays to stack + * @return stacked array */ public static INDArray pile(@NonNull Collection arrays) { return pile(arrays.toArray(new INDArray[0])); @@ -5726,22 +5560,20 @@ public class Nd4j { /** * This method does the opposite to pile/vstack/hstack - it returns independent TAD copies along given dimensions * - * @param tensor - * @param dimensions - * @return + * @param tensor Array to tear + * @param dimensions dimensions + * @return Array copies */ public static INDArray[] tear(INDArray tensor, @NonNull int... dimensions) { if (dimensions.length >= tensor.rank()) throw new ND4JIllegalStateException("Target dimensions number should be less tensor rank"); - for (int e = 0; e < dimensions.length; e++) - if (dimensions[e] < 0) - throw new ND4JIllegalStateException("Target dimensions can't have negative values"); + for (int dimension : dimensions) + if (dimension < 0) throw new ND4JIllegalStateException("Target dimensions can't have negative values"); return factory().tear(tensor, dimensions); } - /** * Upper triangle of an array. @@ -5750,50 +5582,34 @@ public class Nd4j { Please refer to the documentation for `tril` for further details. - * @param m - * @param k - * @return + See Also + -------- + tril : lower triangle of an array + + Examples + -------- + >>> np.triu([[1,2,3],[4,5,6],[7,8,9],[10,11,12]], -1) + array([[ 1, 2, 3], + [ 4, 5, 6], + [ 0, 8, 9], + [ 0, 0, 12]]) + + """ + m = asanyarray(m) + mask = tri(*m.shape[-2:], k=k-1, dtype=bool) + + return where(mask, zeros(1, m.dtype), m) + + * @param m source array + * @param k to zero below the k-th diagonal + * @return copy with elements below the `k`-th diagonal zeroed. */ public static INDArray triu(INDArray m,int k) { - /** - * """ - Upper triangle of an array. - Return a copy of a matrix with the elements below the `k`-th diagonal - zeroed. - - Please refer to the documentation for `tril` for further details. - - See Also - -------- - tril : lower triangle of an array - - Examples - -------- - >>> np.triu([[1,2,3],[4,5,6],[7,8,9],[10,11,12]], -1) - array([[ 1, 2, 3], - [ 4, 5, 6], - [ 0, 8, 9], - [ 0, 0, 12]]) - - """ - m = asanyarray(m) - mask = tri(*m.shape[-2:], k=k-1, dtype=bool) - - return where(mask, zeros(1, m.dtype), m) - */ - - //INDArray mask = tri(m.size(-2),1); - /** + /* * Find a way to apply choose with an existing condition array. * (This appears to be the select op in libnd4j) */ - /* - Select select = new Select(new INDArray[]{mask,Nd4j.zeros(1),m},new INDArray[]{Nd4j.zerosLike(m)}); - Nd4j.getExecutioner().exec(select); - return select.getOutputArgument(0); - */ - INDArray result = Nd4j.createUninitialized(m.shape()); val op = DynamicCustomOp.builder("triu") @@ -5803,25 +5619,18 @@ public class Nd4j { .build(); Nd4j.getExecutioner().execAndReturn(op); - return result; } - /** - * - * @param n - * @return + * See {@link #tri(int,int,int)} with m = n, k=0. */ public static INDArray tri(int n) { return tri(n,n,0); } /** - * - * @param n - * @param k - * @return + * See {@link #tri(int,int,int)} with m = n. */ public static INDArray tri(int n,int k) { return tri(n,n,k); @@ -5836,24 +5645,16 @@ public class Nd4j { * @param k The sub-diagonal at and below which the array is filled. `k` = 0 is the main diagonal, while `k` < 0 is below it, and `k` > 0 is above. The default is 0. - * @return + * @return array with ones at and below the given diagonal and zeros elsewhere */ public static INDArray tri(int n,int m,int k) { - /* - INDArray mRet = Transforms.greaterThanOrEqual(arange(n),arange(-k,m - k)); - - return mRet; - */ - INDArray ret = Nd4j.createUninitialized(n, m); - val op = DynamicCustomOp.builder("tri") .addIntegerArguments(n, m, k) .addOutputs(ret) .build(); Nd4j.getExecutioner().execAndReturn(op); - return ret; } @@ -5872,7 +5673,6 @@ public class Nd4j { public static INDArray[] where(INDArray condition, INDArray x, INDArray y){ Preconditions.checkState((x == null && y == null) || (x != null && y != null), "Both X and Y must be" + "null, or neither must be null"); - INDArray out; DynamicCustomOp.DynamicCustomOpsBuilder op = DynamicCustomOp.builder("where_np"); List outShapes; if(x == null){ @@ -5880,6 +5680,7 @@ public class Nd4j { op.addInputs(condition); } else { if(!x.equalShapes(y) || !x.equalShapes(condition)){ + //noinspection ConstantConditions Preconditions.throwStateEx("Shapes must be equal: condition=%s, x=%s, y=%s", condition.shape(), x.shape(), y.shape()); } op.addInputs(condition, x, y); @@ -5912,6 +5713,7 @@ public class Nd4j { * @param file the file to write to * @throws IOException if an error occurs when writing the file */ + @SuppressWarnings("WeakerAccess") public static void writeAsNumpy(INDArray arr, File file) throws IOException { writeAsNumpy(arr, new FileOutputStream(file)); } @@ -5922,6 +5724,7 @@ public class Nd4j { * @param arr the array to convert * @return a pointer to the numpy struct */ + @SuppressWarnings("WeakerAccess") public static Pointer convertToNumpy(INDArray arr) { return INSTANCE.convertToNumpy(arr); } @@ -5931,8 +5734,8 @@ public class Nd4j { * Writes an array to an output stream * @param arr the array to write * @param writeTo the output stream to write to - * @throws IOException */ + @SuppressWarnings("WeakerAccess") public static void writeAsNumpy(INDArray arr, OutputStream writeTo) throws IOException { try(BufferedOutputStream bufferedOutputStream = new BufferedOutputStream(writeTo)) { Pointer asNumpy = convertToNumpy(arr); @@ -5945,7 +5748,6 @@ public class Nd4j { bufferedOutputStream.flush(); } - } @@ -5958,6 +5760,7 @@ public class Nd4j { * numpy pointer */ + @SuppressWarnings("WeakerAccess") public static INDArray createFromNpyPointer(Pointer pointer) { return INSTANCE.createFromNpyPointer(pointer); } @@ -5983,8 +5786,8 @@ public class Nd4j { * Create a numpy array based on the passed in input stream * @param is the input stream to read * @return the loaded ndarray - * @throws IOException */ + @SuppressWarnings("unused") public static INDArray createNpyFromInputStream(InputStream is) throws IOException { byte[] content = IOUtils.toByteArray(is); return createNpyFromByteArray(content); @@ -6013,7 +5816,6 @@ public class Nd4j { * @return the {@link INDArray} as a byte array * with the numpy format. * For more on the format, see: https://docs.scipy.org/doc/numpy-1.14.0/neps/npy-format.html - * @throws IOException */ public static byte[] toNpyByteArray(INDArray input) { try { @@ -6106,7 +5908,7 @@ public class Nd4j { val bytes = new byte[prod]; val sb = bb.order(_order).asReadOnlyBuffer(); for (int e = 0; e < prod; e++) - bytes[e] = (byte) sb.get(e + sb.position()); + bytes[e] = sb.get(e + sb.position()); return Nd4j.create(bytes, shapeOf, stridesOf, ordering, DataType.BYTE); } @@ -6150,7 +5952,7 @@ public class Nd4j { /** * This method returns maximal allowed number of threads for Nd4j. * If value wasn't set in advance, max(1, availableProcessor) will be returned - * @return + * @return maximal allowed number of threads */ public static int numThreads() { val v = numThreads.get(); @@ -6162,7 +5964,7 @@ public class Nd4j { /** * This method sets maximal allowed number of threads for Nd4j - * @param numthreads + * @param numthreads maximal allowed number of threads */ public static void setNumThreads(int numthreads) { numThreads.set(numthreads); @@ -6178,6 +5980,7 @@ public class Nd4j { public static INDArray scalar(@NonNull String string) { + //noinspection RedundantArrayCreation return create(Collections.singletonList(string), new long[0]); } @@ -6197,7 +6000,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 1D INDArray with DOUBLE data type */ public static INDArray createFromArray(double... array) { @@ -6210,7 +6013,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 1D INDArray with FLOAT data type */ public static INDArray createFromArray(float... array) { @@ -6223,7 +6026,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 1D INDArray with INT32 data type */ public static INDArray createFromArray(int... array) { @@ -6236,7 +6039,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 1D INDArray with INT16 data type */ public static INDArray createFromArray(short... array) { @@ -6249,7 +6052,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 1D INDArray with INT8 data type */ public static INDArray createFromArray(byte... array) { @@ -6262,7 +6065,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 1D INDArray with INT64 data type */ public static INDArray createFromArray(long... array) { @@ -6275,7 +6078,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 1D INDArray with BOOL data type */ public static INDArray createFromArray(boolean... array) { @@ -6290,7 +6093,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 2D INDArray with DOUBLE data type */ public static INDArray createFromArray(double[][] array) { @@ -6304,7 +6107,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 2D INDArray with FLOAT data type */ public static INDArray createFromArray(float[][] array) { @@ -6318,7 +6121,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 2D INDArray with INT64 data type */ public static INDArray createFromArray(long[][] array) { @@ -6332,7 +6135,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 2D INDArray with INT32 data type */ public static INDArray createFromArray(int[][] array) { @@ -6346,7 +6149,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 2D INDArray with INT16 data type */ public static INDArray createFromArray(short[][] array) { @@ -6360,7 +6163,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 2D INDArray with INT8 data type */ public static INDArray createFromArray(byte[][] array) { @@ -6374,7 +6177,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 2D INDArray with BOOL data type */ public static INDArray createFromArray(boolean[][] array) { @@ -6391,7 +6194,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 3D INDArray with DOUBLE data type */ public static INDArray createFromArray(double[][][] array) { @@ -6405,7 +6208,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 3D INDArray with FLOAT data type */ public static INDArray createFromArray(float[][][] array) { @@ -6419,7 +6222,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 3D INDArray with INT64 data type */ public static INDArray createFromArray(long[][][] array) { @@ -6434,7 +6237,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 3D INDArray with INT32 data type */ public static INDArray createFromArray(int[][][] array) { @@ -6449,7 +6252,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 3D INDArray with INT16 data type */ public static INDArray createFromArray(short[][][] array) { @@ -6463,7 +6266,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 3D INDArray with INT8 data type */ public static INDArray createFromArray(byte[][][] array) { @@ -6477,7 +6280,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 3D INDArray with BOOL data type */ public static INDArray createFromArray(boolean[][][] array) { @@ -6493,7 +6296,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 4D INDArray with DOUBLE data type */ public static INDArray createFromArray(double[][][][] array) { @@ -6507,7 +6310,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 4D INDArray with FLOAT data type */ public static INDArray createFromArray(float[][][][] array) { @@ -6521,7 +6324,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 4D INDArray with INT64 data type */ public static INDArray createFromArray(long[][][][] array) { @@ -6535,7 +6338,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 4D INDArray with INT32 data type */ public static INDArray createFromArray(int[][][][] array) { @@ -6549,7 +6352,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 4D INDArray with INT16 data type */ public static INDArray createFromArray(short[][][][] array) { @@ -6563,7 +6366,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 4D INDArray with INT8 data type */ public static INDArray createFromArray(byte[][][][] array) { @@ -6577,7 +6380,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 4D INDArray with BOOL data type */ public static INDArray createFromArray(boolean[][][][] array) { @@ -6589,7 +6392,6 @@ public class Nd4j { return create(ArrayUtil.flatten(array), shape, ArrayUtil.calcStrides(shape), 'c', DataType.BOOL); } - public static synchronized DeallocatorService getDeallocatorService() { if (deallocatorService == null) deallocatorService = new DeallocatorService(); @@ -6601,7 +6403,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 1D INDArray with DOUBLE data type */ public static INDArray createFromArray(Double[] array) { @@ -6610,7 +6412,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 1D INDArray with FLOAT data type */ public static INDArray createFromArray(Float[] array) { @@ -6619,7 +6421,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 1D INDArray with INT32 data type */ public static INDArray createFromArray(Integer[] array) { @@ -6628,7 +6430,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 1D INDArray with INT16 data type */ public static INDArray createFromArray(Short[] array) { @@ -6637,7 +6439,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 1D INDArray with INT8 data type */ public static INDArray createFromArray(Byte[] array) { @@ -6646,7 +6448,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 1D INDArray with INT64 data type */ public static INDArray createFromArray(Long[] array) { @@ -6655,7 +6457,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 1D INDArray with BOOL data type */ public static INDArray createFromArray(Boolean[] array) { @@ -6666,7 +6468,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 2D INDArray with DOUBLE data type */ public static INDArray createFromArray(Double[][] array) { @@ -6675,7 +6477,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 2D INDArray with FLOAT data type */ public static INDArray createFromArray(Float[][] array) { @@ -6684,7 +6486,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 2D INDArray with INT32 data type */ public static INDArray createFromArray(Integer[][] array) { @@ -6693,7 +6495,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 2D INDArray with INT16 data type */ public static INDArray createFromArray(Short[][] array) { @@ -6702,7 +6504,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 2D INDArray with INT8 data type */ public static INDArray createFromArray(Byte[][] array) { @@ -6711,7 +6513,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 2D INDArray with INT64 data type */ public static INDArray createFromArray(Long[][] array) { @@ -6720,7 +6522,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 2D INDArray with BOOL data type */ public static INDArray createFromArray(Boolean[][] array) { @@ -6731,7 +6533,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 3D INDArray with DOUBLE data type */ public static INDArray createFromArray(Double[][][] array) { @@ -6740,7 +6542,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 3D INDArray with FLOAT data type */ public static INDArray createFromArray(Float[][][] array) { @@ -6749,7 +6551,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 3D INDArray with INT32 data type */ public static INDArray createFromArray(Integer[][][] array) { @@ -6758,7 +6560,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 3D INDArray with INT16 data type */ public static INDArray createFromArray(Short[][][] array) { @@ -6767,7 +6569,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 3D INDArray with INT8 data type */ public static INDArray createFromArray(Byte[][][] array) { @@ -6776,7 +6578,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 3D INDArray with INT64 data type */ public static INDArray createFromArray(Long[][][] array) { @@ -6785,7 +6587,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 3D INDArray with BOOL data type */ public static INDArray createFromArray(Boolean[][][] array) { @@ -6796,7 +6598,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 4D INDArray with DOUBLE data type */ public static INDArray createFromArray(Double[][][][] array) { @@ -6805,7 +6607,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 4D INDArray with FLOAT data type */ public static INDArray createFromArray(Float[][][][] array) { @@ -6814,7 +6616,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 4D INDArray with INT32 data type */ public static INDArray createFromArray(Integer[][][][] array) { @@ -6823,7 +6625,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 4D INDArray with INT16 data type */ public static INDArray createFromArray(Short[][][][] array) { @@ -6832,7 +6634,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 4D INDArray with INT8 data type */ public static INDArray createFromArray(Byte[][][][] array) { @@ -6841,7 +6643,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 4D INDArray with INT64 data type */ public static INDArray createFromArray(Long[][][][] array) { @@ -6850,7 +6652,7 @@ public class Nd4j { /** * This method creates INDArray from provided jvm array - * @param array + * @param array jvm array * @return 4D INDArray with BOOL data type */ public static INDArray createFromArray(Boolean[][][][] array) { @@ -6891,12 +6693,6 @@ public class Nd4j { /** * This method applies ScatterUpdate op - * - * @param op - * @param array - * @param indices - * @param updates - * @param axis */ @Deprecated public static void scatterUpdate(ScatterUpdate.UpdateOp op, @NonNull INDArray array, @NonNull INDArray indices, @NonNull INDArray updates, int... axis) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java index 57660b8d7..de997ec82 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java @@ -378,7 +378,7 @@ public class Transforms { public static INDArray asin(INDArray in, boolean copy) { - return Nd4j.getExecutioner().exec(new ASin(((copy ? in.dup() : in)))); + return Nd4j.getExecutioner().exec(new ASin(in, (copy ? in.ulike() : in))); } public static INDArray atan(INDArray arr) { @@ -999,7 +999,8 @@ public class Transforms { } public static INDArray isMax(INDArray input, INDArray output) { - return Nd4j.getExecutioner().exec(new IsMax(input, output)); + Nd4j.getExecutioner().exec(new IsMax(input, output)); + return output; } @@ -1035,7 +1036,7 @@ public class Transforms { * @return */ public static INDArray sqrt(INDArray ndArray, boolean dup) { - return exec(dup ? new Sqrt(ndArray, ndArray.ulike()) : new Sqrt(ndArray)); + return exec(dup ? new Sqrt(ndArray, ndArray.ulike()) : new Sqrt(ndArray, ndArray)); } /** diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/CudaAffinityManager.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/CudaAffinityManager.java index 77c709487..30bd2dd93 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/CudaAffinityManager.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/CudaAffinityManager.java @@ -68,7 +68,24 @@ public class CudaAffinityManager extends BasicAffinityManager { */ @Override public Integer getDeviceForCurrentThread() { - return NativeOpsHolder.getInstance().getDeviceNativeOps().getDevice(); + val id = NativeOpsHolder.getInstance().getDeviceNativeOps().getDevice(); + if (!affinityMap.containsKey(Thread.currentThread().getId())) + affinityMap.put(Thread.currentThread().getId(), id); + + return id; + } + + /** + * This method returns deviceId for a given thread + * @return + */ + @Override + public Integer getDeviceForThread(long threadId) { + val id = affinityMap.get(threadId); + if (id == null) + throw new RuntimeException("Affinity for thread [" + threadId + "] wasn't defined yet"); + + return id; } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java index 26f228430..1a964bef5 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java @@ -1308,40 +1308,6 @@ public class CudaExecutioner extends DefaultOpExecutioner { val hostXShapeInfo = op.x() == null ? null : AddressRetriever.retrieveHostPointer(op.x().shapeInfoDataBuffer()); var hostYShapeInfo = op.y() == null ? null : AddressRetriever.retrieveHostPointer(op.y().shapeInfoDataBuffer()); - // IsMax - if (op.getOpType() == Op.Type.TRANSFORM_ANY && op.opNum() == 1 && op.extraArgs() != null && op.extraArgs().length > 0) { - // for IsMax along dimension we need special temporary buffer - dimension = new int[(int) op.extraArgs()[0]]; - - for (int i = 0; i < dimension.length; i++) { - dimension[i] = (int) op.extraArgs()[i + 1]; - } - - - for (int i = 0; i < dimension.length; i++) { - if (dimension[i] < 0) - dimension[i] += op.x().rank(); - } - //do op along all dimensions - if (dimension.length == op.x().rank()) - dimension = new int[] {Integer.MAX_VALUE}; - - long[] retShape = Shape.wholeArrayDimension(dimension) ? new long[] {} - : ArrayUtil.removeIndex(op.x().shape(), dimension); - - ret = Nd4j.createUninitialized(DataType.LONG, retShape); - - // FIXME: this maybe misleading use of this particular pointer - hostYShapeInfo = allocator.getPointer(ret.shapeInfoDataBuffer(), context); - retHostShape = allocator.getHostPointer(ret.shapeInfoDataBuffer()); - - //dimensionPointer = AtomicAllocator.getInstance().getPointer(Nd4j.createBuffer(dimension), context); - DataBuffer dimensionBuffer = allocator.getConstantBuffer(dimension); - dimensionDevPointer = allocator.getPointer(dimensionBuffer, context); - dimensionHostPointer = allocator.getHostPointer(dimensionBuffer); - - retPointer = allocator.getPointer(ret, context); - } if (op.z() == null) { ret = Nd4j.createUninitialized(op.resultType(), op.x().shape(), op.x().ordering()); @@ -1365,37 +1331,6 @@ public class CudaExecutioner extends DefaultOpExecutioner { op.validateDataTypes(experimentalMode.get()); - // SoftMax, LogSoftMax, SoftMaxDerivative - if (op.getOpType() == Op.Type.TRANSFORM_STRICT && (op.opNum() >= 0 && op.opNum() <= 2)) { - tadBuffers = tadManager.getTADOnlyShapeInfo(op.x(), new int[] {0}); - tadMaxBuffers = tadManager.getTADOnlyShapeInfo(op.x().rank() == 1 ? op.x().reshape(1, -1) : op.x(), new int[] {1}); - - hostTadShapeInfo = AddressRetriever.retrieveHostPointer(tadBuffers.getFirst()); - devTadShapeInfo = allocator.getPointer(tadBuffers.getFirst(), context); - - hostMaxTadShapeInfo = AddressRetriever.retrieveHostPointer(tadMaxBuffers.getFirst()); - devMaxTadShapeInfo = allocator.getPointer(tadMaxBuffers.getFirst(), context); - - DataBuffer offsets = tadBuffers.getSecond(); - devTadOffsets = offsets == null ? null : allocator.getPointer(offsets, context); - - DataBuffer maxOffsets = tadMaxBuffers.getSecond(); - devMaxTadOffsets = maxOffsets == null ? null : allocator.getPointer(maxOffsets, context); - } else if (op.getOpType() == Op.Type.TRANSFORM_ANY && op.opNum() == 1) { // IsMax - tadBuffers = tadManager.getTADOnlyShapeInfo(op.z(), dimension); - - hostTadShapeInfo = AddressRetriever.retrieveHostPointer(tadBuffers.getFirst()); - devTadShapeInfo = AtomicAllocator.getInstance().getPointer(tadBuffers.getFirst(), context); - - DataBuffer offsets = tadBuffers.getSecond(); - devTadOffsets = offsets == null ? null : allocator.getPointer(offsets, context); - - if (retPointer == null) - retPointer = context.getBufferReduction(); - } - - - Pointer z = allocator.getPointer(op.z(), context); Pointer zShapeInfo = allocator.getPointer(op.z().shapeInfoDataBuffer(), context); @@ -1462,7 +1397,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { case TRANSFORM_FLOAT: nativeOps.execTransformFloat(xShapeInfoHostPointer, op.opNum(), null, (LongPointer) hostXShapeInfo, x, (LongPointer) xShapeInfo, - op.z().data().addressPointer(), (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, + null, (LongPointer) hostZShapeInfo, z, (LongPointer) zShapeInfo, extraArgs); break; case TRANSFORM_BOOL: diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index 52fe5c652..e0d53a66f 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -11131,7 +11131,8 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); // REGISTER_C(NAME) // nd4j::ShapeList* nd4j::ops::NAME::calculateOutputShape(nd4j::ShapeList* inputShape, nd4j::graph::Context& block) { // auto shapeList = SHAPELIST(); -// for (int e = 0; e < this->getOpDescriptor()->getNumberOfOutputs(); e++) { +// auto opLimit = this->getOpDescriptor()->getNumberOfOutputs() < 1 ? block.width() : this->getOpDescriptor()->getNumberOfOutputs(); +// for (int e = 0; e < opLimit; e++) { // auto newshape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShape->at(e)), shape::order(inputShape->at(e)), shape::rank(inputShape->at(e)), shape::shapeOf(inputShape->at(e))); // shapeList->push_back(newshape); // } @@ -11168,7 +11169,8 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); // REGISTER_C(NAME) // nd4j::ShapeList* nd4j::ops::NAME::calculateOutputShape(nd4j::ShapeList* inputShape, nd4j::graph::Context& block) { // auto shapeList = SHAPELIST(); -// for (int e = 0; e < this->getOpDescriptor()->getNumberOfOutputs(); e++) { +// auto opLimit = this->getOpDescriptor()->getNumberOfOutputs() < 1 ? block.width() : this->getOpDescriptor()->getNumberOfOutputs(); +// for (int e = 0; e < opLimit; e++) { // Nd4jLong* newshape; // COPY_SHAPE(inputShape->at(0), newshape); // shapeList->push_back(CONSTANT(newshape)); @@ -11191,7 +11193,8 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); // REGISTER_C(NAME) // nd4j::ShapeList* nd4j::ops::NAME::calculateOutputShape(nd4j::ShapeList* inputShape, nd4j::graph::Context& block) { // auto shapeList = SHAPELIST(); -// for (int e = 0; e < this->getOpDescriptor()->getNumberOfOutputs(); e++) { +// auto opLimit = this->getOpDescriptor()->getNumberOfOutputs() < 1 ? block.width() : this->getOpDescriptor()->getNumberOfOutputs(); +// for (int e = 0; e < opLimit; e++) { // auto newshape = ConstantShapeHelper::getInstance()->createShapeInfo(ArrayOptions::dataType(inputShape->at(e)), shape::order(inputShape->at(e)), shape::rank(inputShape->at(e)), shape::shapeOf(inputShape->at(e))); // shapeList->push_back(newshape); // } @@ -16282,9 +16285,15 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } // #endif - + /* + * Complete tensor with max indices merged from all input tensors list + * + * INPUT: tensors with the same shape + * OUTPUT: integer tensor with the same shape + * INT_ARG: result type (one of int), INT32 by default + */ // #if NOT_EXCLUDED(OP_mergemaxindex) - @Namespace("nd4j::ops") public static class mergemaxindex extends DeclarableOp { + @Namespace("nd4j::ops") public static class mergemaxindex extends DeclarableCustomOp { static { Loader.load(); } /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ public mergemaxindex(Pointer p) { super(p); } @@ -16295,10 +16304,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); return (mergemaxindex)super.position(position); } - public mergemaxindex() { super((Pointer)null); allocate(); } - private native void allocate(); - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } + public mergemaxindex() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } // #endif // #if NOT_EXCLUDED(OP_mergeadd) @@ -21722,7 +21731,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); /** * This operation toggles individual bits of each element in array * - * PLEASE NOTE: This operation is possible only on integer datatypes + * PLEASE NOTE: This operation is possible only on integer data types * * \tparam T */ @@ -21743,6 +21752,107 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); } // #endif + + + /** + * This operation shift individual bits of each element in array to the left: << + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * \tparam T + */ +// #if NOT_EXCLUDED(OP_shift_bits) + @Namespace("nd4j::ops") public static class shift_bits extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public shift_bits(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public shift_bits(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public shift_bits position(long position) { + return (shift_bits)super.position(position); + } + + public shift_bits() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif + + /** + * This operation shift individual bits of each element in array to the right: >> + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * \tparam T + */ +// #if NOT_EXCLUDED(OP_rshift_bits) + @Namespace("nd4j::ops") public static class rshift_bits extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public rshift_bits(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public rshift_bits(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public rshift_bits position(long position) { + return (rshift_bits)super.position(position); + } + + public rshift_bits() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif + + /** + * This operation shift individual bits of each element in array, shifting to the left + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * \tparam T + */ +// #if NOT_EXCLUDED(OP_cyclic_shift_bits) + @Namespace("nd4j::ops") public static class cyclic_shift_bits extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public cyclic_shift_bits(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public cyclic_shift_bits(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public cyclic_shift_bits position(long position) { + return (cyclic_shift_bits)super.position(position); + } + + public cyclic_shift_bits() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif + + /** + * This operation shift individual bits of each element in array, shifting to the right + * + * PLEASE NOTE: This operation is applicable only to integer data types + * + * \tparam T + */ +// #if NOT_EXCLUDED(OP_cyclic_rshift_bits) + @Namespace("nd4j::ops") public static class cyclic_rshift_bits extends DeclarableOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public cyclic_rshift_bits(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public cyclic_rshift_bits(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public cyclic_rshift_bits position(long position) { + return (cyclic_rshift_bits)super.position(position); + } + + public cyclic_rshift_bits() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif @@ -22494,7 +22604,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * PLEASE NOTE: This op is disabled atm, and reserved for future releases. */ // #if NOT_EXCLUDED(OP_to_double) - @Namespace("nd4j::ops") public static class to_double extends DeclarableOp { + @Namespace("nd4j::ops") public static class to_double extends DeclarableCustomOp { static { Loader.load(); } /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ public to_double(Pointer p) { super(p); } @@ -22505,10 +22615,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); return (to_double)super.position(position); } - public to_double() { super((Pointer)null); allocate(); } - private native void allocate(); - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } + public to_double() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } // #endif /** @@ -22517,7 +22627,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * PLEASE NOTE: This op is disabled atm, and reserved for future releases. */ // #if NOT_EXCLUDED(OP_to_float16) - @Namespace("nd4j::ops") public static class to_float16 extends DeclarableOp { + @Namespace("nd4j::ops") public static class to_float16 extends DeclarableCustomOp { static { Loader.load(); } /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ public to_float16(Pointer p) { super(p); } @@ -22528,10 +22638,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); return (to_float16)super.position(position); } - public to_float16() { super((Pointer)null); allocate(); } - private native void allocate(); - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } + public to_float16() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } // #endif /** @@ -22540,7 +22650,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * PLEASE NOTE: This op is disabled atm, and reserved for future releases. */ // #if NOT_EXCLUDED(OP_to_float32) - @Namespace("nd4j::ops") public static class to_float32 extends DeclarableOp { + @Namespace("nd4j::ops") public static class to_float32 extends DeclarableCustomOp { static { Loader.load(); } /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ public to_float32(Pointer p) { super(p); } @@ -22551,10 +22661,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); return (to_float32)super.position(position); } - public to_float32() { super((Pointer)null); allocate(); } - private native void allocate(); - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } + public to_float32() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } // #endif /** @@ -22563,7 +22673,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * PLEASE NOTE: This op is disabled atm, and reserved for future releases. */ // #if NOT_EXCLUDED(OP_to_int32) - @Namespace("nd4j::ops") public static class to_int32 extends DeclarableOp { + @Namespace("nd4j::ops") public static class to_int32 extends DeclarableCustomOp { static { Loader.load(); } /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ public to_int32(Pointer p) { super(p); } @@ -22574,10 +22684,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); return (to_int32)super.position(position); } - public to_int32() { super((Pointer)null); allocate(); } - private native void allocate(); - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } + public to_int32() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } // #endif /** @@ -22586,7 +22696,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * PLEASE NOTE: This op is disabled atm, and reserved for future releases. */ // #if NOT_EXCLUDED(OP_to_int64) - @Namespace("nd4j::ops") public static class to_int64 extends DeclarableOp { + @Namespace("nd4j::ops") public static class to_int64 extends DeclarableCustomOp { static { Loader.load(); } /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ public to_int64(Pointer p) { super(p); } @@ -22597,10 +22707,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); return (to_int64)super.position(position); } - public to_int64() { super((Pointer)null); allocate(); } - private native void allocate(); - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } + public to_int64() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } // #endif /** @@ -22609,7 +22719,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * PLEASE NOTE: This op is disabled atm, and reserved for future releases. */ // #if NOT_EXCLUDED(OP_to_uint32) - @Namespace("nd4j::ops") public static class to_uint32 extends DeclarableOp { + @Namespace("nd4j::ops") public static class to_uint32 extends DeclarableCustomOp { static { Loader.load(); } /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ public to_uint32(Pointer p) { super(p); } @@ -22620,10 +22730,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); return (to_uint32)super.position(position); } - public to_uint32() { super((Pointer)null); allocate(); } - private native void allocate(); - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } + public to_uint32() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } // #endif /** @@ -22632,7 +22742,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); * PLEASE NOTE: This op is disabled atm, and reserved for future releases. */ // #if NOT_EXCLUDED(OP_to_uint64) - @Namespace("nd4j::ops") public static class to_uint64 extends DeclarableOp { + @Namespace("nd4j::ops") public static class to_uint64 extends DeclarableCustomOp { static { Loader.load(); } /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ public to_uint64(Pointer p) { super(p); } @@ -22643,10 +22753,10 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); return (to_uint64)super.position(position); } - public to_uint64() { super((Pointer)null); allocate(); } - private native void allocate(); - public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); - } + public to_uint64() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } // #endif /** diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java index 167e490a8..f4984679a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/SameDiffTests.java @@ -1516,7 +1516,7 @@ public class SameDiffTests extends BaseNd4jTest { //then dL/dIn = 1 if in_i == min(in) or 0 otherwise //Note that we don't have an "IsMin" op, so use IsMax(neg(in)) which is equivalent - INDArray exp = Nd4j.getExecutioner().exec(new IsMax(arr.neg())).castTo(Nd4j.defaultFloatingPointType()); + INDArray exp = Nd4j.getExecutioner().exec(new IsMax(arr.neg()))[0].castTo(Nd4j.defaultFloatingPointType()); assertEquals(exp, dLdIn); } @@ -1540,7 +1540,7 @@ public class SameDiffTests extends BaseNd4jTest { //If L = max(in) //then dL/dIn = 1 if in_i == max(in) or 0 otherwise - INDArray exp = Nd4j.getExecutioner().exec(new IsMax(arr.dup())).castTo(DataType.DOUBLE); + INDArray exp = Nd4j.getExecutioner().exec(new IsMax(arr.dup()))[0].castTo(DataType.DOUBLE); assertEquals(exp, dLdIn); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index 638cd8ac3..230fa1337 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -72,6 +72,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform; import org.nd4j.linalg.api.ops.impl.transforms.comparison.CompareAndSet; import org.nd4j.linalg.api.ops.impl.transforms.comparison.Eps; import org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpace; +import org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpaceND; import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.BinaryRelativeError; import org.nd4j.linalg.api.ops.impl.transforms.pairwise.Set; @@ -261,7 +262,7 @@ public class Nd4jTestsC extends BaseNd4jTest { public void testIsMaxVectorCase() { INDArray arr = Nd4j.create(new double[] {1, 2, 4, 3}, new long[] {2, 2}); INDArray assertion = Nd4j.create(new boolean[] {false, false, true, false}, new long[] {2, 2}, DataType.BOOL); - INDArray test = Nd4j.getExecutioner().exec(new IsMax(arr)); + INDArray test = Nd4j.getExecutioner().exec(new IsMax(arr))[0]; assertEquals(assertion, test); } @@ -719,7 +720,7 @@ public class Nd4jTestsC extends BaseNd4jTest { //Tests: full buffer... //1d INDArray arr1 = Nd4j.create(new double[] {1, 2, 3, 1}); - val res1 = Nd4j.getExecutioner().exec(new IsMax(arr1)); + val res1 = Nd4j.getExecutioner().exec(new IsMax(arr1))[0]; INDArray exp1 = Nd4j.create(new boolean[] {false, false, true, false}); assertEquals(exp1, res1); @@ -736,8 +737,8 @@ public class Nd4jTestsC extends BaseNd4jTest { INDArray exp2d = Nd4j.create(new boolean[][] {{false, false, false}, {false, true, false}}); INDArray f = arr2d.dup('f'); - INDArray out2dc = Nd4j.getExecutioner().exec(new IsMax(arr2d.dup('c'))); - INDArray out2df = Nd4j.getExecutioner().exec(new IsMax(arr2d.dup('f'))); + INDArray out2dc = Nd4j.getExecutioner().exec(new IsMax(arr2d.dup('c')))[0]; + INDArray out2df = Nd4j.getExecutioner().exec(new IsMax(arr2d.dup('f')))[0]; assertEquals(exp2d, out2dc); assertEquals(exp2d, out2df); } @@ -803,16 +804,48 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test public void testIsMaxEqualValues_2() { //[0 2] [0 1] - //[2 1] -> [0 0] - INDArray orig = Nd4j.create(new double[][] {{0, 2}, {2, 1}}); + //[2 1] -> [0 0]bg + INDArray orig = Nd4j.create(new double[][] {{0, 3}, {2, 1}}); INDArray exp = Nd4j.create(new double[][] {{0, 1}, {0, 0}}); INDArray outc = Transforms.isMax(orig.dup('c')); assertEquals(exp, outc); - INDArray outf = Transforms.isMax(orig.dup('f')); + log.info("Orig: {}", orig.dup('f').data().asFloat()); + + INDArray outf = Transforms.isMax(orig.dup('f'), orig.dup('f').ulike()); + log.info("OutF: {}", outf.data().asFloat()); assertEquals(exp, outf); } + @Test + public void testIsMaxEqualValues_3() { + //[0 2] [0 1] + //[2 1] -> [0 0] + INDArray orig = Nd4j.create(new double[][] {{0, 2}, {3, 1}}); + INDArray exp = Nd4j.create(new double[][] {{0, 0}, {1, 0}}); + INDArray outc = Transforms.isMax(orig.dup('c')); + assertEquals(exp, outc); + + INDArray outf = Transforms.isMax(orig.dup('f'), orig.dup('f').ulike()); + assertEquals(exp, outf); + } + + @Test + public void testSqrt_1() { + val x = Nd4j.createFromArray(9.0, 9.0, 9.0, 9.0); + val x2 = Nd4j.createFromArray(9.0, 9.0, 9.0, 9.0); + val e = Nd4j.createFromArray(3.0, 3.0, 3.0, 3.0); + + val z1 = Transforms.sqrt(x, true); + val z2 = Transforms.sqrt(x2, false); + + + assertEquals(e, z2); + assertEquals(e, x2); + assertEquals(e, z1); + + } + @Test public void testAssign_CF() { val orig = Nd4j.create(new double[][] {{0, 2}, {2, 1}}); @@ -828,8 +861,8 @@ public class Nd4jTestsC extends BaseNd4jTest { //1d: row vector INDArray orig = Nd4j.create(new double[] {1, 2, 3, 1}).reshape(1,4 ); - INDArray alongDim0 = Nd4j.getExecutioner().exec(new IsMax(orig.dup(), Nd4j.createUninitialized(DataType.BOOL, orig.shape()), 0)); - INDArray alongDim1 = Nd4j.getExecutioner().exec(new IsMax(orig.dup(), Nd4j.createUninitialized(DataType.BOOL, orig.shape()), 1)); + INDArray alongDim0 = Nd4j.getExecutioner().exec(new IsMax(orig.dup(), Nd4j.createUninitialized(DataType.BOOL, orig.shape()), 0))[0]; + INDArray alongDim1 = Nd4j.getExecutioner().exec(new IsMax(orig.dup(), Nd4j.createUninitialized(DataType.BOOL, orig.shape()), 1))[0]; INDArray expAlong0 = Nd4j.create(new boolean[]{true, true, true, true}).reshape(1,4); INDArray expAlong1 = Nd4j.create(new boolean[] {false, false, true, false}).reshape(1,4); @@ -841,8 +874,8 @@ public class Nd4jTestsC extends BaseNd4jTest { //1d: col vector System.out.println("----------------------------------"); INDArray col = Nd4j.create(new double[] {1, 2, 3, 1}, new long[] {4, 1}); - INDArray alongDim0col = Nd4j.getExecutioner().exec(new IsMax(col.dup(), Nd4j.createUninitialized(DataType.BOOL, col.shape()), 0)); - INDArray alongDim1col = Nd4j.getExecutioner().exec(new IsMax(col.dup(), Nd4j.createUninitialized(DataType.BOOL, col.shape()),1)); + INDArray alongDim0col = Nd4j.getExecutioner().exec(new IsMax(col.dup(), Nd4j.createUninitialized(DataType.BOOL, col.shape()), 0))[0]; + INDArray alongDim1col = Nd4j.getExecutioner().exec(new IsMax(col.dup(), Nd4j.createUninitialized(DataType.BOOL, col.shape()),1))[0]; INDArray expAlong0col = Nd4j.create(new boolean[] {false, false, true, false}).reshape(4,1); INDArray expAlong1col = Nd4j.create(new boolean[] {true, true, true, true}).reshape(4,1); @@ -877,10 +910,10 @@ public class Nd4jTestsC extends BaseNd4jTest { //[0 1 0] System.out.println("---------------------"); INDArray orig2d = Nd4j.create(new double[][] {{1, 0, 2}, {2, 3, 1}}); - INDArray alongDim0c_2d = Nd4j.getExecutioner().exec(new IsMax(orig2d.dup('c'), Nd4j.createUninitialized(DataType.BOOL, orig2d.shape()), 0)); - INDArray alongDim0f_2d = Nd4j.getExecutioner().exec(new IsMax(orig2d.dup('f'), Nd4j.createUninitialized(DataType.BOOL, orig2d.shape(), 'f'), 0)); - INDArray alongDim1c_2d = Nd4j.getExecutioner().exec(new IsMax(orig2d.dup('c'), Nd4j.createUninitialized(DataType.BOOL, orig2d.shape()), 1)); - INDArray alongDim1f_2d = Nd4j.getExecutioner().exec(new IsMax(orig2d.dup('f'), Nd4j.createUninitialized(DataType.BOOL, orig2d.shape(), 'f'), 1)); + INDArray alongDim0c_2d = Nd4j.getExecutioner().exec(new IsMax(orig2d.dup('c'), Nd4j.createUninitialized(DataType.BOOL, orig2d.shape()), 0))[0]; + INDArray alongDim0f_2d = Nd4j.getExecutioner().exec(new IsMax(orig2d.dup('f'), Nd4j.createUninitialized(DataType.BOOL, orig2d.shape(), 'f'), 0))[0]; + INDArray alongDim1c_2d = Nd4j.getExecutioner().exec(new IsMax(orig2d.dup('c'), Nd4j.createUninitialized(DataType.BOOL, orig2d.shape()), 1))[0]; + INDArray alongDim1f_2d = Nd4j.getExecutioner().exec(new IsMax(orig2d.dup('f'), Nd4j.createUninitialized(DataType.BOOL, orig2d.shape(), 'f'), 1))[0]; INDArray expAlong0_2d = Nd4j.create(new boolean[][] {{false, false, true}, {true, true, false}}); INDArray expAlong1_2d = Nd4j.create(new boolean[][] {{false, false, true}, {false, true, false}}); @@ -904,7 +937,7 @@ public class Nd4jTestsC extends BaseNd4jTest { @Test public void testIsMaxSingleDim1() { INDArray orig2d = Nd4j.create(new double[][] {{1, 0, 2}, {2, 3, 1}}); - INDArray alongDim0c_2d = Nd4j.getExecutioner().exec(new IsMax(orig2d.dup('c'), Nd4j.createUninitialized(DataType.BOOL, orig2d.shape()), 0)); + INDArray alongDim0c_2d = Nd4j.getExecutioner().exec(new IsMax(orig2d.dup('c'), Nd4j.createUninitialized(DataType.BOOL, orig2d.shape()), 0))[0]; INDArray expAlong0_2d = Nd4j.create(new boolean[][] {{false, false, true}, {true, true, false}}); System.out.println("Original shapeInfo: " + orig2d.dup('c').shapeInfoDataBuffer()); @@ -1056,8 +1089,8 @@ public class Nd4jTestsC extends BaseNd4jTest { + Arrays.toString(shape) + ")"); INDArray arrC = Nd4j.linspace(1, length, length, DataType.DOUBLE).reshape('c', shape); INDArray arrF = arrC.dup('f'); - val resC = Nd4j.getExecutioner().exec(new IsMax(arrC, alongDimension)); - val resF = Nd4j.getExecutioner().exec(new IsMax(arrF, alongDimension)); + val resC = Nd4j.getExecutioner().exec(new IsMax(arrC, alongDimension))[0]; + val resF = Nd4j.getExecutioner().exec(new IsMax(arrF, alongDimension))[0]; double[] cBuffer = resC.data().asDouble(); @@ -3932,7 +3965,7 @@ public class Nd4jTestsC extends BaseNd4jTest { v.assign(t); } - val result = Nd4j.getExecutioner().exec(new IsMax(arr, Nd4j.createUninitialized(DataType.BOOL, arr.shape(), arr.ordering()), 1, 2)); + val result = Nd4j.getExecutioner().exec(new IsMax(arr, Nd4j.createUninitialized(DataType.BOOL, arr.shape(), arr.ordering()), 1, 2))[0]; assertEquals(expected, result); } @@ -3971,8 +4004,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } } - INDArray actC = Nd4j.getExecutioner().exec(new IsMax(arr.dup('c'), Nd4j.createUninitialized(DataType.BOOL, arr.shape()),0, 1)); - INDArray actF = Nd4j.getExecutioner().exec(new IsMax(arr.dup('f'), Nd4j.createUninitialized(DataType.BOOL, arr.shape(), 'f'), 0, 1)); + INDArray actC = Nd4j.getExecutioner().exec(new IsMax(arr.dup('c'), Nd4j.createUninitialized(DataType.BOOL, arr.shape()),0, 1))[0]; + INDArray actF = Nd4j.getExecutioner().exec(new IsMax(arr.dup('f'), Nd4j.createUninitialized(DataType.BOOL, arr.shape(), 'f'), 0, 1))[0]; assertEquals(exp, actC); assertEquals(exp, actF); @@ -4006,8 +4039,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } } - actC = Nd4j.getExecutioner().exec(new IsMax(arr.dup('c'), 2, 3)); - actF = Nd4j.getExecutioner().exec(new IsMax(arr.dup('f'), 2, 3)); + actC = Nd4j.getExecutioner().exec(new IsMax(arr.dup('c'), arr.dup('c').ulike(), 2, 3))[0]; + actF = Nd4j.getExecutioner().exec(new IsMax(arr.dup('f'), arr.dup('f').ulike(), 2, 3))[0]; assertEquals(exp, actC); assertEquals(exp, actF); @@ -6527,7 +6560,7 @@ public class Nd4jTestsC extends BaseNd4jTest { assertTrue(x.sumNumber().floatValue() > 0); x = Nd4j.randn(DataType.BFLOAT16 , 10); - assertTrue(x.sumNumber().floatValue() > 0); + assertTrue(x.sumNumber().floatValue() != 0.0); } @Test @@ -7962,7 +7995,7 @@ public class Nd4jTestsC extends BaseNd4jTest { public void testBatchToSpace(){ INDArray out = Nd4j.create(DataType.FLOAT, 2, 4, 5); - DynamicCustomOp c = new BatchToSpace(); + DynamicCustomOp c = new BatchToSpaceND(); c.addInputArgument( Nd4j.rand(DataType.FLOAT, new int[]{4, 4, 3}), @@ -7979,6 +8012,14 @@ public class Nd4jTestsC extends BaseNd4jTest { //from [4,4,3] to [2,4,6] then crop to [2,4,5] } + @Test + public void testToFromByteArray() throws IOException { + // simple test to get rid of toByteArray and fromByteArray compiler warnings. + INDArray x = Nd4j.arange(10); + byte[] xb = Nd4j.toByteArray(x); + INDArray y = Nd4j.fromByteArray(xb); + assertEquals(x,y); + } private static INDArray fwd(INDArray input, INDArray W, INDArray b){ INDArray ret = Nd4j.createUninitialized(input.size(0), W.size(1)); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/rng/RngTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/rng/RngTests.java index 1bf709a5f..729c18c77 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/rng/RngTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/api/rng/RngTests.java @@ -102,7 +102,7 @@ public class RngTests extends BaseNd4jTest { } @Test - void testRandomBinomial() { + public void testRandomBinomial() { //silly tests. Just increasing the usage for randomBinomial to stop compiler warnings. INDArray x = Nd4j.randomBinomial(10, 0.5, 3,3); assertTrue(x.sum().getDouble(0) > 0.0); //silly test. Just increasing th usage for randomBinomial diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/CudaTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/CudaTests.java index 02375c873..ad9d360c5 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/CudaTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/specials/CudaTests.java @@ -106,115 +106,6 @@ public class CudaTests extends BaseNd4jTest { assertEquals(exp, arrayA); } - @Test(timeout = 40000L) - public void testContextSpam() throws Exception { - if (Nd4j.getExecutioner().type() != OpExecutioner.ExecutionerType.CUDA) - return; - - val success = new AtomicInteger(0); - val iterations = 101; - - val threads = new ArrayList(); - for (int e = 0; e < iterations; e++) { - val f = e; - val t = new Thread(new Runnable() { - @Override - public void run() { - Nd4j.create(1); - if (f % 50 == 0) - log.info("Context {} created", f); - - Nd4j.getMemoryManager().releaseCurrentContext(); - success.incrementAndGet(); - try { - Thread.sleep(1000L); - } catch (InterruptedException ex) { - ex.printStackTrace(); - } - } - }); - - t.start(); - threads.add(t); - } - - for (val t: threads) - t.join(); - - assertEquals(iterations, success.get()); - } - - @Ignore - @Test(timeout = 180000L) - public void testContextSpam_2() throws Exception { - if (Nd4j.getExecutioner().type() != OpExecutioner.ExecutionerType.CUDA) - return; - - val success = new AtomicInteger(0); - val iterations = 101; - - val threads = new ArrayList(); - for (int e = 0; e < iterations; e++) { - val f = e; - val t = new Thread(new Runnable() { - @Override - public void run() { - Nd4j.create(1); - if (f % 50 == 0) - log.info("Context {} created", f); - - //Nd4j.getMemoryManager().releaseCurrentContext(); - success.incrementAndGet(); - try { - Thread.sleep(1000L); - } catch (InterruptedException ex) { - ex.printStackTrace(); - } - } - }); - - t.start(); - threads.add(t); - } - - for (val t: threads) - t.join(); - - assertEquals(iterations, success.get()); - } - - @Test - public void testSequentialReleaseAndReacquire() throws Exception { - if (Nd4j.getExecutioner().type() != OpExecutioner.ExecutionerType.CUDA) - return; - - Nd4j.create(128); - - Nd4j.getMemoryManager().releaseCurrentContext(); - - val array = Nd4j.create(128); - array.addi(1.0f); - } - - @Test - @Ignore - public void test(){ - if (Nd4j.getExecutioner().type() != OpExecutioner.ExecutionerType.CUDA) - return; - - val SD = SameDiff.create(); - val in = SD.one("test", 5, 8, 3, 4); - SDVariable out = in.reshape(-1, 4); - SDVariable out1 = out.reshape(4, 15, -1); - SDVariable out2 = SD.dot(out1, out1, 2); - - SDVariable out3 = out2.reshape(-1, 4); // <---- error here - - System.out.println(Arrays.toString(out3.eval().toFloatMatrix())); - - } - - @Override public char ordering() { return 'c'; diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/pom.xml b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/pom.xml index fd9cc9285..f0d44beb6 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/pom.xml +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/pom.xml @@ -27,6 +27,18 @@ jar nd4j-parameter-server-node + + + + org.apache.maven.plugins + maven-compiler-plugin + + 8 + 8 + + + + diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/VoidParameterServer.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/VoidParameterServer.java index 751fa79c3..26e46949f 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/VoidParameterServer.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/VoidParameterServer.java @@ -19,6 +19,7 @@ package org.nd4j.parameterserver.distributed; import lombok.Getter; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; +import lombok.val; import org.nd4j.config.ND4JEnvironmentVars; import org.nd4j.linalg.primitives.Pair; import org.nd4j.linalg.api.ndarray.INDArray; @@ -276,9 +277,12 @@ public class VoidParameterServer { processingThreads = new Thread[numThreads]; processingRunnables = new Runnable[numThreads]; + val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread(); + for (int x = 0; x < numThreads; x++) { processingThreads[x] = new Thread(() -> { runner.set(true); + Nd4j.getAffinityManager().unsafeSetDevice(deviceId); while (runner.get()) { try { //VoidMessage message = transport.takeMessage(); @@ -296,11 +300,6 @@ public class VoidParameterServer { } }); - //executor.submit(processingRunnables[x); - - // TODO: maybe find the way to guarantee affinity in some other way, to make different devices usable as well? - Nd4j.getAffinityManager().attachThreadToDevice(processingThreads[x], - Nd4j.getAffinityManager().getDeviceForCurrentThread()); processingThreads[x].setDaemon(true); processingThreads[x].setName("VoidParameterServer messages handling thread"); processingThreads[x].start(); diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/transport/BaseTransport.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/transport/BaseTransport.java index a1de47a15..dbd6545d6 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/transport/BaseTransport.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/transport/BaseTransport.java @@ -25,6 +25,7 @@ import io.aeron.logbuffer.Header; import lombok.Getter; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; +import lombok.val; import org.agrona.CloseHelper; import org.agrona.DirectBuffer; import org.agrona.concurrent.IdleStrategy; @@ -349,6 +350,8 @@ public abstract class BaseTransport implements Transport { * only because we want code to be obvious for people */ final AtomicBoolean localRunner = new AtomicBoolean(false); + val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread(); + if (nodeRole == NodeRole.NONE) { throw new ND4JIllegalStateException("No role is set for current node!"); } else if (nodeRole == NodeRole.SHARD || nodeRole == NodeRole.BACKUP || nodeRole == NodeRole.MASTER) { @@ -357,6 +360,7 @@ public abstract class BaseTransport implements Transport { // setting up thread for shard->client communication listener if (messageHandlerForShards != null) { threadB = new Thread(() -> { + Nd4j.getAffinityManager().unsafeSetDevice(deviceId); while (runner.get()) idler.idle(subscriptionForShards.poll(messageHandlerForShards, 512)); @@ -368,13 +372,13 @@ public abstract class BaseTransport implements Transport { // setting up thread for inter-shard communication listener threadA = new Thread(() -> { localRunner.set(true); + Nd4j.getAffinityManager().unsafeSetDevice(deviceId); while (runner.get()) idler.idle(subscriptionForClients.poll(messageHandlerForClients, 512)); }); if (threadB != null) { - Nd4j.getAffinityManager().attachThreadToDevice(threadB, - Nd4j.getAffinityManager().getDeviceForCurrentThread()); + threadB.setDaemon(true); threadB.setName("VoidParamServer subscription threadB [" + nodeRole + "]"); threadB.start(); @@ -389,8 +393,7 @@ public abstract class BaseTransport implements Transport { } // all roles have threadA anyway - Nd4j.getAffinityManager().attachThreadToDevice(threadA, - Nd4j.getAffinityManager().getDeviceForCurrentThread()); + //Nd4j.getAffinityManager().attachThreadToDevice(threadA, Nd4j.getAffinityManager().getDeviceForCurrentThread()); threadA.setDaemon(true); threadA.setName("VoidParamServer subscription threadA [" + nodeRole + "]"); threadA.start(); diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/transport/impl/AeronUdpTransport.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/transport/impl/AeronUdpTransport.java index b609d7378..a34720093 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/transport/impl/AeronUdpTransport.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/transport/impl/AeronUdpTransport.java @@ -150,11 +150,19 @@ public class AeronUdpTransport extends BaseTransport implements AutoCloseable { // this executor service han protected ExecutorService messagesExecutorService = Executors.newFixedThreadPool(SENDER_THREADS + MESSAGE_THREADS + SUBSCRIPTION_THREADS, new ThreadFactory() { @Override - public Thread newThread(@NonNull Runnable r) { - val t = Executors.defaultThreadFactory().newThread(r); + public Thread newThread(@NonNull final Runnable r) { + val t = new Thread(new Runnable() { + + @Override + public void run() { + //TODO implement support for multi-GPU masters + Nd4j.getAffinityManager().unsafeSetDevice(0); //Associate thread with device 0 (no-op for CPU) + r.run(); + } + }); + t.setDaemon(true); - //TODO implement support for multi-GPU masters - Nd4j.getAffinityManager().attachThreadToDevice(t, 0); //Associate thread with device 0 (no-op for CPU) + t.setName("MessagesExecutorService thread"); return t; } }); diff --git a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/transport/impl/BaseTransport.java b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/transport/impl/BaseTransport.java index 07d8ef096..34bceb6c2 100644 --- a/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/transport/impl/BaseTransport.java +++ b/nd4j/nd4j-parameter-server-parent/nd4j-parameter-server-node/src/main/java/org/nd4j/parameterserver/distributed/v2/transport/impl/BaseTransport.java @@ -107,10 +107,16 @@ public abstract class BaseTransport implements Transport { protected final ThreadPoolExecutor executorService = (ThreadPoolExecutor) Executors.newFixedThreadPool(Math.max(2, Runtime.getRuntime().availableProcessors()), new ThreadFactory() { @Override - public Thread newThread(@NonNull Runnable r) { - val t = Executors.defaultThreadFactory().newThread(r); + public Thread newThread(@NonNull final Runnable r) { + val t = new Thread(new Runnable() { + @Override + public void run() { + Nd4j.getAffinityManager().unsafeSetDevice(0); + r.run(); + } + }); + t.setDaemon(true); - Nd4j.getAffinityManager().attachThreadToDevice(t, 0); return t; } }); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java index 8866f0c40..24e6bfcb9 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/AsyncLearning.java @@ -45,7 +45,7 @@ public abstract class AsyncLearning getAsyncGlobal(); @@ -60,9 +60,7 @@ public abstract class AsyncLearning asyncGlobal, int threadNumber) { + public AsyncThread(IAsyncGlobal asyncGlobal, int threadNumber, int deviceNum) { this.threadNumber = threadNumber; + this.deviceNum = deviceNum; } public void setHistoryProcessor(IHistoryProcessor.Configuration conf) { @@ -87,6 +91,7 @@ public abstract class AsyncThread asyncGlobal, int threadNumber) { - super(asyncGlobal, threadNumber); + public AsyncThreadDiscrete(IAsyncGlobal asyncGlobal, int threadNumber, int deviceNum) { + super(asyncGlobal, threadNumber, deviceNum); synchronized (asyncGlobal) { current = (NN)asyncGlobal.getCurrent().clone(); } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java index 5777e2394..7dbec6210 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CDiscrete.java @@ -62,9 +62,9 @@ public abstract class A3CDiscrete extends AsyncLearning extends A3CDiscrete { } @Override - public AsyncThread newThread(int i) { - AsyncThread at = super.newThread(i); + public AsyncThread newThread(int i, int deviceNum) { + AsyncThread at = super.newThread(i, deviceNum); at.setHistoryProcessor(hpconf); return at; } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java index 4c5873b11..3a481b09c 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/a3c/discrete/A3CThreadDiscrete.java @@ -57,8 +57,8 @@ public class A3CThreadDiscrete extends AsyncThreadDiscrete< final private Random random; public A3CThreadDiscrete(MDP mdp, AsyncGlobal asyncGlobal, - A3CDiscrete.A3CConfiguration a3cc, int threadNumber, IDataManager dataManager) { - super(asyncGlobal, threadNumber); + A3CDiscrete.A3CConfiguration a3cc, int threadNumber, IDataManager dataManager, int deviceNum) { + super(asyncGlobal, threadNumber, deviceNum); this.conf = a3cc; this.asyncGlobal = asyncGlobal; this.threadNumber = threadNumber; diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java index f0d7a3349..bab60fec4 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscrete.java @@ -55,9 +55,9 @@ public abstract class AsyncNStepQLearningDiscrete mdp.getActionSpace().setSeed(conf.getSeed()); } - - public AsyncThread newThread(int i) { - return new AsyncNStepQLearningThreadDiscrete(mdp.newInstance(), asyncGlobal, configuration, i, dataManager); + @Override + public AsyncThread newThread(int i, int deviceNum) { + return new AsyncNStepQLearningThreadDiscrete(mdp.newInstance(), asyncGlobal, configuration, i, dataManager, deviceNum); } public IDQN getNeuralNet() { diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteConv.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteConv.java index 4da14012e..257e5fb5d 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteConv.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningDiscreteConv.java @@ -53,8 +53,8 @@ public class AsyncNStepQLearningDiscreteConv extends AsyncN } @Override - public AsyncThread newThread(int i) { - AsyncThread at = super.newThread(i); + public AsyncThread newThread(int i, int deviceNum) { + AsyncThread at = super.newThread(i, deviceNum); at.setHistoryProcessor(hpconf); return at; } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java index 4f6c3ad09..23d6f79ca 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/async/nstep/discrete/AsyncNStepQLearningThreadDiscrete.java @@ -56,8 +56,8 @@ public class AsyncNStepQLearningThreadDiscrete extends Asyn public AsyncNStepQLearningThreadDiscrete(MDP mdp, IAsyncGlobal asyncGlobal, AsyncNStepQLearningDiscrete.AsyncNStepQLConfiguration conf, int threadNumber, - IDataManager dataManager) { - super(asyncGlobal, threadNumber); + IDataManager dataManager, int deviceNum) { + super(asyncGlobal, threadNumber, deviceNum); this.conf = conf; this.asyncGlobal = asyncGlobal; this.threadNumber = threadNumber; diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/ExpReplay.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/ExpReplay.java index ead3f4d00..fd7d9465b 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/ExpReplay.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/ExpReplay.java @@ -16,6 +16,8 @@ package org.deeplearning4j.rl4j.learning.sync; +import it.unimi.dsi.fastutil.ints.IntOpenHashSet; +import it.unimi.dsi.fastutil.ints.IntSet; import lombok.extern.slf4j.Slf4j; import org.apache.commons.collections4.queue.CircularFifoQueue; @@ -50,7 +52,18 @@ public class ExpReplay implements IExpReplay { ArrayList> batch = new ArrayList<>(size); int storageSize = storage.size(); int actualBatchSize = Math.min(storageSize, size); - int[] actualIndex = ThreadLocalRandom.current().ints(0, storageSize).distinct().limit(actualBatchSize).toArray(); + + int[] actualIndex = new int[actualBatchSize]; + ThreadLocalRandom r = ThreadLocalRandom.current(); + IntSet set = new IntOpenHashSet(); + for( int i=0; i trans = storage.get(actualIndex[i]); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java index a2c25a43c..525995455 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QLearning.java @@ -50,7 +50,7 @@ public abstract class QLearning expReplay; @Getter @Setter(AccessLevel.PACKAGE) - private IExpReplay expReplay; + protected IExpReplay expReplay; public QLearning(QLConfiguration conf) { super(conf); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java index 0e53103ef..23be44f01 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/async/AsyncThreadTest.java @@ -194,7 +194,7 @@ public class AsyncThreadTest { private final IDataManager dataManager; public MockAsyncThread(IAsyncGlobal asyncGlobal, int threadNumber, MockNeuralNet neuralNet, MDP mdp, AsyncConfiguration conf, IDataManager dataManager) { - super(asyncGlobal, threadNumber); + super(asyncGlobal, threadNumber, 0); this.asyncGlobal = asyncGlobal; this.neuralNet = neuralNet; diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java index 51bdeaf41..5762875aa 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/QLearningDiscreteTest.java @@ -1,6 +1,7 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete; import org.deeplearning4j.rl4j.learning.IHistoryProcessor; +import org.deeplearning4j.rl4j.learning.sync.IExpReplay; import org.deeplearning4j.rl4j.learning.sync.Transition; import org.deeplearning4j.rl4j.learning.sync.qlearning.QLearning; import org.deeplearning4j.rl4j.mdp.MDP; @@ -138,5 +139,10 @@ public class QLearningDiscreteTest { protected Pair setTarget(ArrayList> transitions) { return new Pair<>(Nd4j.create(new double[] { 123.0 }), Nd4j.create(new double[] { 234.0 })); } + + public void setExpReplay(IExpReplay exp){ + this.expReplay = exp; + } + } }