diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java
index 743e16710..b0cc17376 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java
@@ -57,10 +57,8 @@ import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
-import org.junit.AfterClass;
-import org.junit.Before;
-import org.junit.BeforeClass;
-import org.junit.Test;
+import org.junit.*;
+import org.junit.rules.TemporaryFolder;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationIdentity;
import org.nd4j.linalg.api.buffer.DataType;
@@ -82,6 +80,7 @@ import org.nd4j.common.resources.Resources;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
+import java.io.File;
import java.io.IOException;
import java.util.*;
@@ -91,6 +90,9 @@ import static org.junit.Assert.*;
@Slf4j
public class TestComputationGraphNetwork extends BaseDL4JTest {
+ @Rule
+ public TemporaryFolder testDir = new TemporaryFolder();
+
private static ComputationGraphConfiguration getIrisGraphConfiguration() {
return new NeuralNetConfiguration.Builder().seed(12345)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder()
@@ -2177,4 +2179,40 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
INDArray label = Nd4j.createFromArray(1, 0).reshape(1, 2);
cg.fit(new DataSet(in, label));
}
+
+ @Test
+ public void testMergeNchw() throws Exception {
+ ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
+ .convolutionMode(ConvolutionMode.Same)
+ .graphBuilder()
+ .addInputs("in")
+ .layer("l0", new ConvolutionLayer.Builder()
+ .nOut(16)
+ .kernelSize(2,2).stride(1,1)
+ .build(), "in")
+ .layer("l1", new ConvolutionLayer.Builder()
+ .nOut(8)
+ .kernelSize(2,2).stride(1,1)
+ .build(), "in")
+ .addVertex("merge", new MergeVertex(), "l0", "l1")
+ .layer("out", new CnnLossLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build(), "merge")
+ .setOutputs("out")
+ .setInputTypes(InputType.convolutional(32, 32, 3, CNN2DFormat.NHWC))
+ .build();
+
+ ComputationGraph cg = new ComputationGraph(conf);
+ cg.init();
+
+ INDArray[] in = new INDArray[]{Nd4j.rand(DataType.FLOAT, 1, 32, 32, 3)};
+ INDArray out = cg.outputSingle(in);
+
+ File dir = testDir.newFolder();
+ File f = new File(dir, "net.zip");
+ cg.save(f);
+
+ ComputationGraph c2 = ComputationGraph.load(f, true);
+ INDArray out2 = c2.outputSingle(in);
+
+ assertEquals(out, out2);
+ }
}
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/Dropout.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/Dropout.java
index 46a872fd8..acb6afa2c 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/Dropout.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/Dropout.java
@@ -66,8 +66,8 @@ import org.nd4j.shade.jackson.annotation.JsonProperty;
* @author Alex Black
*/
@Data
-@JsonIgnoreProperties({"mask", "helper", "helperCountFail"})
-@EqualsAndHashCode(exclude = {"mask", "helper", "helperCountFail"})
+@JsonIgnoreProperties({"mask", "helper", "helperCountFail", "initializedHelper"})
+@EqualsAndHashCode(exclude = {"mask", "helper", "helperCountFail", "initializedHelper"})
@Slf4j
public class Dropout implements IDropout {
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/MergeVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/MergeVertex.java
index 726a68403..c7a4fec63 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/MergeVertex.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/MergeVertex.java
@@ -17,6 +17,7 @@
package org.deeplearning4j.nn.conf.graph;
+import lombok.Data;
import lombok.val;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.RNNFormat;
@@ -38,6 +39,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
* -> [numExamples,depth1 + depth2,width,height]}
* @author Alex Black
*/
+@Data
public class MergeVertex extends GraphVertex {
protected int mergeAxis = 1; //default value for backward compatibility (deserialization of old version JSON) - NCHW and NCW format
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/train/GradientSharingTrainingTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/train/GradientSharingTrainingTest.java
index 68a012b72..c1eff1dce 100644
--- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/train/GradientSharingTrainingTest.java
+++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/train/GradientSharingTrainingTest.java
@@ -141,7 +141,7 @@ public class GradientSharingTrainingTest extends BaseSparkTest {
SparkComputationGraph sparkNet = new SparkComputationGraph(sc, conf, tm);
sparkNet.setCollectTrainingStats(tm.getIsCollectTrainingStats());
- System.out.println(Arrays.toString(sparkNet.getNetwork().params().get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
+// System.out.println(Arrays.toString(sparkNet.getNetwork().params().get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
File f = testDir.newFolder();
DataSetIterator iter = new MnistDataSetIterator(16, true, 12345);
int count = 0;
@@ -208,10 +208,10 @@ public class GradientSharingTrainingTest extends BaseSparkTest {
}
INDArray paramsAfter = after.params();
- System.out.println(Arrays.toString(paramsBefore.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
- System.out.println(Arrays.toString(paramsAfter.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
- System.out.println(Arrays.toString(
- Transforms.abs(paramsAfter.sub(paramsBefore)).get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
+// System.out.println(Arrays.toString(paramsBefore.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
+// System.out.println(Arrays.toString(paramsAfter.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
+// System.out.println(Arrays.toString(
+// Transforms.abs(paramsAfter.sub(paramsBefore)).get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
assertNotEquals(paramsBefore, paramsAfter);
@@ -235,7 +235,7 @@ public class GradientSharingTrainingTest extends BaseSparkTest {
}
- @Test
+ @Test @Ignore //AB https://github.com/eclipse/deeplearning4j/issues/8985
public void differentNetsTrainingTest() throws Exception {
int batch = 3;
diff --git a/libnd4j/CMakeLists.txt b/libnd4j/CMakeLists.txt
index 106401b31..0631763c2 100755
--- a/libnd4j/CMakeLists.txt
+++ b/libnd4j/CMakeLists.txt
@@ -131,6 +131,23 @@ if(NOT SD_CUDA)
endif()
endif()
+#arm-compute entry
+if(${HELPERS_armcompute})
+ find_package(ARMCOMPUTE REQUIRED)
+
+ if(ARMCOMPUTE_FOUND)
+ message("Found ARMCOMPUTE: ${ARMCOMPUTE_LIBRARIES}")
+ set(HAVE_ARMCOMPUTE 1)
+ # Add preprocessor definition for ARM Compute NEON
+ add_definitions(-DARMCOMPUTENEON_ENABLED)
+ #build our library with neon support
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mfpu=neon")
+ include_directories(${ARMCOMPUTE_INCLUDE})
+ message("----${ARMCOMPUTE_INCLUDE}---")
+ endif()
+
+endif()
+
# new mkl-dnn entry
if (${HELPERS_mkldnn})
diff --git a/libnd4j/blas/CMakeLists.txt b/libnd4j/blas/CMakeLists.txt
index fb1dc066e..b6bd1f7c0 100755
--- a/libnd4j/blas/CMakeLists.txt
+++ b/libnd4j/blas/CMakeLists.txt
@@ -146,6 +146,10 @@ if (HAVE_MKLDNN)
file(GLOB_RECURSE CUSTOMOPS_MKLDNN_SOURCES false ../include/ops/declarable/platform/mkldnn/*.cpp ../include/ops/declarable/platform/mkldnn/mkldnnUtils.h)
endif()
+if(HAVE_ARMCOMPUTE)
+ file(GLOB_RECURSE CUSTOMOPS_ARMCOMPUTE_SOURCES false ../include/ops/declarable/platform/armcompute/*.cpp ../include/ops/declarable/platform/armcompute/*.h)
+endif()
+
if(SD_CUDA)
message("Build cublas")
find_package(CUDA)
@@ -243,7 +247,7 @@ if(SD_CUDA)
${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES}
${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES} ${CUSTOMOPS_CUDNN_SOURCES} ${CUSTOMOPS_MKLDNN_SOURCES}
- ${CUSTOMOPS_GENERIC_SOURCES}
+ ${CUSTOMOPS_ARMCOMPUTE_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES}
)
if (WIN32)
@@ -351,8 +355,8 @@ elseif(SD_CPU)
add_definitions(-D__CPUBLAS__=true)
add_library(samediff_obj OBJECT ${LEGACY_SOURCES}
${LOOPS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
- ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_MKLDNN_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES}
- ${OPS_SOURCES} ${PERF_SOURCES})
+ ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_MKLDNN_SOURCES}
+ ${CUSTOMOPS_ARMCOMPUTE_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES})
if(IOS)
add_library(${SD_LIBRARY_NAME} STATIC $<TARGET_OBJECTS:samediff_obj>)
else()
@@ -378,12 +382,12 @@ elseif(SD_CPU)
if (NOT BLAS_LIBRARIES)
set(BLAS_LIBRARIES "")
endif()
- target_link_libraries(${SD_LIBRARY_NAME} ${MKLDNN} ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${BLAS_LIBRARIES} ${CPU_FEATURES})
+ target_link_libraries(${SD_LIBRARY_NAME} ${MKLDNN} ${MKLDNN_LIBRARIES} ${ARMCOMPUTE_LIBRARIES} ${OPENBLAS_LIBRARIES} ${BLAS_LIBRARIES} ${CPU_FEATURES})
if ("${SD_ALL_OPS}" AND "${SD_BUILD_MINIFIER}")
message(STATUS "Building minifier...")
add_executable(minifier ../minifier/minifier.cpp ../minifier/graphopt.cpp)
- target_link_libraries(minifier samediff_obj ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES} ${CPU_FEATURES})
+ target_link_libraries(minifier samediff_obj ${MKLDNN_LIBRARIES} ${ARMCOMPUTE_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES} ${CPU_FEATURES})
endif()
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND "${CMAKE_CXX_COMPILER_VERSION}" VERSION_LESS 4.9)
diff --git a/libnd4j/cmake/FindARMCOMPUTE.cmake b/libnd4j/cmake/FindARMCOMPUTE.cmake
new file mode 100644
index 000000000..ae0e1fbba
--- /dev/null
+++ b/libnd4j/cmake/FindARMCOMPUTE.cmake
@@ -0,0 +1,74 @@
+################################################################################
+# Copyright (c) 2020 Konduit K.K.
+#
+# This program and the accompanying materials are made available under the
+# terms of the Apache License, Version 2.0 which is available at
+# https://www.apache.org/licenses/LICENSE-2.0.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+################################################################################
+
+
+
+### Find ARM COMPUTE LIBRARY STATIC libraries
+
+SET (COMPUTE_INCLUDE_DIRS
+ /usr/include
+ ${ARMCOMPUTE_ROOT}
+ ${ARMCOMPUTE_ROOT}/include
+ ${ARMCOMPUTE_ROOT}/applications
+ ${ARMCOMPUTE_ROOT}/applications/arm_compute
+)
+
+
+SET (COMPUTE_LIB_DIRS
+ /lib
+ /usr/lib
+ ${ARMCOMPUTE_ROOT}
+ ${ARMCOMPUTE_ROOT}/lib
+ ${ARMCOMPUTE_ROOT}/build
+)
+
+find_path(ARMCOMPUTE_INCLUDE arm_compute/core/CL/ICLKernel.h
+ PATHS ${COMPUTE_INCLUDE_DIRS}
+ NO_DEFAULT_PATH NO_CMAKE_FIND_ROOT_PATH)
+
+find_path(ARMCOMPUTE_INCLUDE arm_compute/core/CL/ICLKernel.h)
+
+find_path(HALF_INCLUDE half/half.hpp)
+find_path(HALF_INCLUDE half/half.hpp
+ PATHS ${ARMCOMPUTE_ROOT}/include
+ NO_DEFAULT_PATH NO_CMAKE_FIND_ROOT_PATH)
+include_directories(SYSTEM ${HALF_INCLUDE})
+
+# Find the Arm Compute libraries if not already specified
+if (NOT DEFINED ARMCOMPUTE_LIBRARIES)
+
+ find_library(ARMCOMPUTE_LIBRARY NAMES arm_compute-static
+ PATHS ${COMPUTE_LIB_DIRS}
+ PATH_SUFFIXES "Release"
+ NO_DEFAULT_PATH NO_CMAKE_FIND_ROOT_PATH)
+
+ find_library(ARMCOMPUTE_CORE_LIBRARY NAMES arm_compute_core-static
+ PATHS ${COMPUTE_LIB_DIRS}
+ PATH_SUFFIXES "Release"
+ NO_DEFAULT_PATH NO_CMAKE_FIND_ROOT_PATH)
+ # In case it wasn't there, try a default search (will work in cases where
+ # the library has been installed into a standard location)
+ find_library(ARMCOMPUTE_LIBRARY NAMES arm_compute-static)
+ find_library(ARMCOMPUTE_CORE_LIBRARY NAMES arm_compute_core-static)
+
+ set(ARMCOMPUTE_LIBRARIES ${ARMCOMPUTE_LIBRARY} ${ARMCOMPUTE_CORE_LIBRARY} )
+endif()
+
+
+INCLUDE(FindPackageHandleStandardArgs)
+
+FIND_PACKAGE_HANDLE_STANDARD_ARGS(ARMCOMPUTE REQUIRED_VARS ARMCOMPUTE_INCLUDE ARMCOMPUTE_LIBRARIES)
+
diff --git a/libnd4j/include/cnpy/cnpy.h b/libnd4j/include/cnpy/cnpy.h
index ea847c3e7..c84623599 100644
--- a/libnd4j/include/cnpy/cnpy.h
+++ b/libnd4j/include/cnpy/cnpy.h
@@ -69,7 +69,7 @@ namespace cnpy {
}
};
- struct ND4J_EXPORT npz_t : public std::unordered_map<std::string, NpyArray> {
+ struct ND4J_EXPORT npz_t : public std::map<std::string, NpyArray> {
void destruct() {
npz_t::iterator it = this->begin();
for(; it != this->end(); ++it) (*it).second.destruct();
diff --git a/libnd4j/include/config.h.in b/libnd4j/include/config.h.in
index 1e63552d0..c858dd765 100644
--- a/libnd4j/include/config.h.in
+++ b/libnd4j/include/config.h.in
@@ -3,6 +3,8 @@
#cmakedefine HAVE_MKLDNN
+#cmakedefine HAVE_ARMCOMPUTE
+
#cmakedefine MKLDNN_PATH "@MKLDNN_PATH@"
#cmakedefine HAVE_OPENBLAS
diff --git a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool_with_argmax.cpp b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool_with_argmax.cpp
index eced3c2b4..b03d19451 100644
--- a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool_with_argmax.cpp
+++ b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool_with_argmax.cpp
@@ -45,18 +45,18 @@ namespace sd {
DECLARE_TYPES(max_pool_with_argmax) {
getOpDescriptor()
->setAllowedInputTypes(sd::DataType::ANY)
- ->setAllowedOutputTypes(0, DataType::INHERIT)
- ->setAllowedOutputTypes(1, {ALL_INTS});
+ ->setAllowedOutputTypes(0, {ALL_FLOATS, ALL_INTS})
+ ->setAllowedOutputTypes(1, {ALL_INDICES});
}
DECLARE_SHAPE_FN(max_pool_with_argmax) {
+ auto in = inputShape->at(0);
+ auto dtype = block.numD() ? D_ARG(0) : sd::DataType::INT64;
+ auto valuesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in));
+ auto indicesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in, dtype));
- auto in = inputShape->at(0);
- auto valuesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in));
- auto indicesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in, DataType::INT64));
-
- return SHAPELIST(valuesShape, indicesShape);
+ return SHAPELIST(valuesShape, indicesShape);
}
}
}
diff --git a/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp b/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp
index 8f45c696b..7e66d4b11 100644
--- a/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp
+++ b/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp
@@ -215,7 +215,9 @@ namespace helpers {
auto maxValue = T(0); //sd::math::nd4j_abs(compoundBuffer[xInitialIndex]);
auto result = -1;
//auto loop = PRAGMA_THREADS_FOR {
- auto start = column, stop = rowNum, increment = 1;
+ auto start = column;
+ auto stop = rowNum;
+ auto increment = 1;
for (auto rowCounter = start; rowCounter < stop; rowCounter++) {
Nd4jLong xPos[] = {rowCounter, column};
auto xIndex = shape::getOffset(compoundShape, xPos, 0);
diff --git a/libnd4j/include/ops/declarable/helpers/cpu/max_pooling.cpp b/libnd4j/include/ops/declarable/helpers/cpu/max_pooling.cpp
index a458b5eff..ebb9d53fa 100644
--- a/libnd4j/include/ops/declarable/helpers/cpu/max_pooling.cpp
+++ b/libnd4j/include/ops/declarable/helpers/cpu/max_pooling.cpp
@@ -73,7 +73,7 @@ namespace helpers {
}
void maxPoolingFunctor(sd::LaunchContext * context, sd::graph::Context& block, NDArray* input, NDArray* values, std::vector<int> const& params, NDArray* indices) {
- BUILD_SINGLE_SELECTOR(input->dataType(), maxPoolingFunctor_, (block, input, values, params, indices), FLOAT_TYPES);
+ BUILD_SINGLE_SELECTOR(input->dataType(), maxPoolingFunctor_, (block, input, values, params, indices), LIBND4J_TYPES);
}
}
diff --git a/libnd4j/include/ops/declarable/helpers/cpu/randomShuffle.cpp b/libnd4j/include/ops/declarable/helpers/cpu/randomShuffle.cpp
index ea529112d..2ffbfc95f 100644
--- a/libnd4j/include/ops/declarable/helpers/cpu/randomShuffle.cpp
+++ b/libnd4j/include/ops/declarable/helpers/cpu/randomShuffle.cpp
@@ -16,7 +16,8 @@
//
// @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018
-//
+// implementation is based on the following article:
+// "MergeShuffle: A Very Fast, Parallel Random Permutation Algorithm", https://arxiv.org/abs/1508.03167
@@ -31,96 +32,167 @@ namespace ops {
namespace helpers {
//////////////////////////////////////////////////////////////////////////
+// Fisher-Yates shuffle
template <typename T>
-void randomShuffle_(NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {
+static void fisherYates(sd::graph::RandomGenerator& rng, T* buff, const Nd4jLong& len, const Nd4jLong& ews, Nd4jLong ind) {
+
+ for(Nd4jLong i = len-1; i > 0; --i) {
+ const Nd4jLong j = rng.relativeLong(ind++) % (i + 1);
+ if(i != j)
+ math::nd4j_swap(buff[i*ews], buff[j*ews]);
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+// mutual shuffle of two adjacent, already-shuffled ranges of lengths len1 and (totLen - len1) respectively
+template <typename T>
+static void mergeShuffle(sd::graph::RandomGenerator& rng, T* buff, const Nd4jLong& len1, const Nd4jLong& totLen, const Nd4jLong& ews, Nd4jLong ind) {
+
+ Nd4jLong beg = 0; // beginning
+ Nd4jLong mid = len1; // middle
+
+ while (true) {
+ if(rng.relativeLong(ind++) % 2) {
+ if(mid == totLen)
+ break;
+ math::nd4j_swap(buff[ews * beg], buff[ews * mid++]);
+ } else {
+ if(beg == mid)
+ break;
+ }
+ ++beg;
+ }
+
+ // fisherYates
+ while (beg < totLen) {
+ const Nd4jLong j = rng.relativeLong(ind++) % (beg + 1);
+ if(beg != j)
+ math::nd4j_swap(buff[ews * beg], buff[ews * j]);
+ ++beg;
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+template <typename T>
+static void randomShuffle_(NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {
- // check edge cases first
- int temp;
const int firstDim = input.sizeAt(0);
+ int temp;
+
if(input.lengthOf() == 1 || firstDim == 1) {
if(!isInplace)
output.assign(input);
}
- else if (input.isVector() || shape::isLikeVector(input.shapeInfo(), temp)) {
+ else if (shape::isCommonVector(input.shapeInfo(), temp)) {
- // apply Fisher-Yates shuffle
- if(isInplace) {
- //PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance().tadThreshold())
- for(int i = firstDim-1; i > 0; --i) {
- int r = rng.relativeInt(i) % i;
- if(i == r)
- continue;
- T t0 = input.t(i);
- T t1 = input.t(r);
- //math::nd4j_swap(input(i), input(r));
- input.r(i) = t1;
- input.r(r) = t0;
- }
+ NDArray* arr = &input;
+
+ if (!isInplace) {
+ output.assign(input);
+ arr = &output;
}
- else {
- std::vector<int> indices(firstDim);
- std::iota(indices.begin(), indices.end(), 0);
- output.p(Nd4jLong(0), input.e(0));
- // FIXME: parallelism!!
- for(int i = firstDim-1; i > 0; --i) {
- int r = rng.relativeInt(i) % i;
- output.r(i) = input.t(indices[r]);
- if(i == r)
- continue;
+ const Nd4jLong ews = arr->ews();
- output.r(r) = input.t(indices[i]);
- math::nd4j_swap(indices[i], indices[r]);
+ const Nd4jLong len = arr->lengthOf();
+ const Nd4jLong threshold = 1<<22; // this number was deduced from the diagram in the article
+
+ int power = 0;
+ while ((len >> power) > threshold)
+ ++power;
+
+ const Nd4jLong numChunks = 1 << power;
+
+ auto funcFisherYates = PRAGMA_THREADS_FOR {
+
+ for (auto i = start; i < stop; ++i) {
+
+ Nd4jLong offset = (len * i) >> power;
+ Nd4jLong currLen = ((len * (i + 1)) >> power) - offset;
+ fisherYates<T>(rng, arr->bufferAsT<T>() + offset*ews, currLen, ews, offset);
}
- rng.rewindH(firstDim-1);
- }
+ };
+
+ auto funcMerge = PRAGMA_THREADS_FOR {
+
+ for (int64_t i = start, k = 1; i < stop; i += increment, ++k) {
+ Nd4jLong offset = len * i >> power;
+ Nd4jLong len1 = (len * (i + increment/2) >> power) - offset;
+ Nd4jLong totLen = (len * (i + increment) >> power) - offset;
+ mergeShuffle<T>(rng, arr->bufferAsT<T>() + offset*ews, len1, totLen, ews, len * k + offset);
+ }
+ };
+
+ samediff::Threads::parallel_for(funcFisherYates, 0, numChunks);
+
+ for (int j = 1; j < numChunks; j += j)
+ samediff::Threads::parallel_for(funcMerge, 0, numChunks, 2*j);
+
+ // #pragma omp parallel for
+ // for (uint i = 0; i < numChunks; ++i) {
+
+ // Nd4jLong offset = (len * i) >> power;
+ // Nd4jLong currLen = ((len * (i + 1)) >> power) - offset;
+ // fisherYates(rng, arr->bufferAsT() + offset*ews, currLen, ews, offset);
+ // }
+
+ // for (uint j = 1; j < numChunks; j += j) {
+ // #pragma omp parallel for
+ // for (auto i = 0; i < numChunks; i += 2*j) {
+ // Nd4jLong offset = len * i >> power;
+ // Nd4jLong len1 = (len * (i + j) >> power) - offset;
+ // Nd4jLong totLen = (len * (i + 2*j) >> power) - offset;
+ // mergeShuffle(rng, arr->bufferAsT() + offset*ews, len1, totLen, ews, len * j + offset);
+ // }
+ // }
+
+ rng.rewindH((len + 1) * power);
}
else {
- // evaluate sub-arrays list of input array through all dimensions excluding first one
- std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input.rankOf(), {0});
- auto subArrsListIn = input.allTensorsAlongDimension(dimensions);
+ auto dimsToExclude = ShapeUtils::evalDimsToExclude(input.rankOf(), {0});
- // apply Fisher-Yates shuffle
if(isInplace) {
- //PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance().elementwiseThreshold())
- for(int i = firstDim - 1; i > 0; --i) {
- int r = rng.relativeInt(i) % i;
- if(i == r)
- continue;
- subArrsListIn.at(i)->swapUnsafe(*subArrsListIn.at(r));
+ auto subArrsList = input.allTensorsAlongDimension(dimsToExclude);
+
+ // Fisher-Yates shuffle
+ for(int i = firstDim - 1; i > 0; --i) {
+ const int j = rng.relativeInt(i) % (i + 1);
+ if(i != j)
+ subArrsList.at(i)->swapUnsafe(*subArrsList.at(j));
}
}
else {
- // evaluate sub-arrays list of output array through all dimensions excluding first one
- auto subArrsListOut = output.allTensorsAlongDimension(dimensions);
+
+ auto subArrsListIn = input.allTensorsAlongDimension(dimsToExclude);
+ auto subArrsListOut = output.allTensorsAlongDimension(dimsToExclude);
+
std::vector<int> indices(firstDim);
- std::iota(indices.begin(), indices.end(), 0);
- bool isZeroShuffled = false;
- //PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance().tadThreshold())
- for(int i = firstDim - 1; i > 0; --i) {
- int r = rng.relativeInt(i) % i;
- subArrsListOut.at(i)->assign(subArrsListIn.at(indices[r]));
- if(r == 0)
- isZeroShuffled = true;
- if(i == r)
- continue;
- subArrsListOut.at(r)->assign(subArrsListIn.at(indices[i]));
- math::nd4j_swap(indices[i], indices[r]);
- }
- if(!isZeroShuffled)
- subArrsListOut.at(0)->assign(subArrsListIn.at(0));
+ std::iota(indices.begin(), indices.end(), 0); // 0,1,2,3, ... firstDim-1
+
+ // shuffle indices
+ fisherYates(rng, indices.data(), firstDim, 1, 0);
+
+ auto func = PRAGMA_THREADS_FOR {
+
+ for (auto i = start; i < stop; ++i)
+ subArrsListOut.at(i)->assign(subArrsListIn.at(indices[i]));
+ };
+
+ samediff::Threads::parallel_for(func, 0, firstDim);
}
+
rng.rewindH(firstDim-1);
}
-
}
- void randomShuffle(sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {
- BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (input, output, rng, isInplace), LIBND4J_TYPES);
- }
+void randomShuffle(sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {
+ BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (input, output, rng, isInplace), LIBND4J_TYPES);
+}
+
}
}
}
+
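Note on the algorithm above: for reference, the merge step that mergeShuffle implements can be exercised in isolation. The sketch below is illustrative only, not patch code; it assumes two already-shuffled runs stored in one int buffer, uses std::mt19937 in place of sd::graph::RandomGenerator, and the name mergeShuffleSketch is hypothetical.

// Standalone sketch of the MergeShuffle merge step (illustrative, not patch code).
// Precondition: buff[0..len1) and buff[len1..total) are each already shuffled.
#include <cstdint>
#include <random>
#include <utility>

static void mergeShuffleSketch(std::mt19937& rng, int* buff, int64_t len1, int64_t total) {
    int64_t beg = 0, mid = len1;
    while (true) {
        if (rng() & 1) {                 // coin flip: take from the right run
            if (mid == total) break;     // right run exhausted
            std::swap(buff[beg], buff[mid++]);
        } else {                         // keep the current element of the left run
            if (beg == mid) break;       // left run exhausted
        }
        ++beg;
    }
    // insert the leftover tail by Fisher-Yates to keep the permutation uniform
    for (; beg < total; ++beg) {
        const int64_t j = static_cast<int64_t>(rng() % (beg + 1));
        if (j != beg) std::swap(buff[beg], buff[j]);
    }
}

The chunked driver in the patch applies fisherYates to the 2^power chunks in parallel, then invokes this merge pairwise log2(numChunks) times, doubling the merged run length on every pass.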
diff --git a/libnd4j/include/ops/declarable/helpers/cuda/concat.cu b/libnd4j/include/ops/declarable/helpers/cuda/concat.cu
index cbcd35ffe..400c25f88 100644
--- a/libnd4j/include/ops/declarable/helpers/cuda/concat.cu
+++ b/libnd4j/include/ops/declarable/helpers/cuda/concat.cu
@@ -53,7 +53,7 @@ __global__ static void concatCuda(void* pVx, void* pxShapeInfo, void* vz, const
int coords[MAX_RANK];
- for (uint64_t i = tid; i < zLen; i += totalThreads) {
+ for (Nd4jLong i = tid; i < zLen; i += totalThreads) {
shape::index2coords(i, zShapeInfo, coords);
const auto zOffset = shape::getOffset(zShapeInfo, coords);
@@ -162,9 +162,9 @@ void concat(sd::LaunchContext * context, const std::vector& inAr
// }
// else { // general (slower) case
- const int threadsPerBlock = 256;
- const int blocksPerGrid = 512;
- const int sharedMem = 512;
+ const int threadsPerBlock = MAX_NUM_THREADS / 2;
+ const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
+ const int sharedMem = 256;
// prepare arrays of pointers on buffers and shapes
std::vector hInBuffers(numOfInArrs);
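A note on the launch-configuration change above: the fixed 512-block grid is replaced by a grid sized from the output length via ceiling division, so every output element gets its own thread. A minimal sketch of that arithmetic (the helper name is mine, not the patch's):

// Ceiling-division grid sizing, as used for blocksPerGrid above (illustrative).
#include <cstdint>

inline int blocksForLength(int64_t outputLength, int threadsPerBlock) {
    // ceil(outputLength / threadsPerBlock) without floating point
    return static_cast<int>((outputLength + threadsPerBlock - 1) / threadsPerBlock);
}
// e.g. outputLength = 1000, threadsPerBlock = 512 -> 2 blocks; tail threads
// whose index i >= outputLength simply fall out of the grid-stride loop.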
diff --git a/libnd4j/include/ops/declarable/helpers/cuda/max_pooling.cu b/libnd4j/include/ops/declarable/helpers/cuda/max_pooling.cu
index 6e70d4510..8c30e510f 100644
--- a/libnd4j/include/ops/declarable/helpers/cuda/max_pooling.cu
+++ b/libnd4j/include/ops/declarable/helpers/cuda/max_pooling.cu
@@ -88,7 +88,7 @@ namespace helpers {
void maxPoolingFunctor(sd::LaunchContext * context, sd::graph::Context& block, NDArray* input, NDArray* values, std::vector<int> const& params, NDArray* indices) {
NDArray::prepareSpecialUse({values, indices}, {input});
auto yType = indices == nullptr ? sd::DataType::INT64 : indices->dataType();
- BUILD_DOUBLE_SELECTOR(input->dataType(), yType, maxPoolingFunctor_, (block, input, values, params, indices), FLOAT_TYPES, INDEXING_TYPES);
+ BUILD_DOUBLE_SELECTOR(input->dataType(), yType, maxPoolingFunctor_, (block, input, values, params, indices), LIBND4J_TYPES, INDEXING_TYPES);
NDArray::registerSpecialUse({values, indices}, {input});
}
diff --git a/libnd4j/include/ops/declarable/helpers/cuda/randomShuffle.cu b/libnd4j/include/ops/declarable/helpers/cuda/randomShuffle.cu
new file mode 100644
index 000000000..bb7998e60
--- /dev/null
+++ b/libnd4j/include/ops/declarable/helpers/cuda/randomShuffle.cu
@@ -0,0 +1,228 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author Yurii Shyrma (iuriish@yahoo.com)
+// the implemented algorithm is a GPU adaptation of the algorithm described in the following article:
+// "MergeShuffle: A Very Fast, Parallel Random Permutation Algorithm", https://arxiv.org/abs/1508.03167
+//
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace sd {
+namespace ops {
+namespace helpers {
+
+//////////////////////////////////////////////////////////////////////////
+template <typename T>
+static __global__ void fisherYatesCuda(sd::graph::RandomGenerator* rng, void* vx, const Nd4jLong ews, const Nd4jLong len, const int power) {
+
+ T* x = reinterpret_cast<T*>(vx);
+
+ __shared__ T* shmem, temp;
+ __shared__ Nd4jLong ind, blockOffset, lenPerBlock;
+
+ if (threadIdx.x == 0) {
+ extern __shared__ unsigned char sharedMemory[];
+ shmem = reinterpret_cast<T*>(sharedMemory);
+
+ blockOffset = (len * blockIdx.x) >> power;
+ lenPerBlock = ((len * (blockIdx.x + 1)) >> power) - blockOffset;
+ ind = blockOffset;
+ }
+ __syncthreads();
+
+ // copy from global memory to shared memory
+ if(threadIdx.x < lenPerBlock)
+ shmem[threadIdx.x] = x[(blockOffset + threadIdx.x) * ews];
+ __syncthreads();
+
+ // *** apply Fisher-Yates shuffle to lenPerBlock number of elements
+ if (threadIdx.x == 0) {
+ for(Nd4jLong i = lenPerBlock - 1; i > 0; --i) {
+ const Nd4jLong j = rng->relativeLong(ind++) % (i + 1);
+ if(i != j) {
+ temp = shmem[i];
+ shmem[i] = shmem[j];
+ shmem[j] = temp;
+ }
+ }
+ }
+ __syncthreads();
+
+ // copy from shared memory to global memory
+ if(threadIdx.x < lenPerBlock)
+ x[(blockOffset + threadIdx.x) * ews] = shmem[threadIdx.x];
+}
+
+template <typename T>
+static __global__ void mergeShuffleCuda(sd::graph::RandomGenerator* rng, void* vx, const Nd4jLong ews, const Nd4jLong len, const int power, const Nd4jLong iterNum) {
+
+
+ T* x = reinterpret_cast<T*>(vx);
+
+ __shared__ Nd4jLong ind, blockOffset, factor, beg, mid, totLen, iterExp;
+
+ // *** apply mergeShuffle algorithm
+ if(threadIdx.x == 0) {
+
+ factor = blockIdx.x << iterNum;
+ iterExp = 1 << (iterNum - 1);
+ blockOffset = (len * factor) >> power;
+ mid = ((len * (factor + iterExp)) >> power) - blockOffset; // middle
+ totLen = ((len * (factor + 2*iterExp)) >> power) - blockOffset;
+ ind = iterNum * len + blockOffset;
+ beg = 0; // beginning
+
+ // printf("m %lld, blockIdx.x %lld, factor %lld, blockOffset %lld, mid %lld, totLen %lld \n", m,k,factor,blockOffset,mid,totLen);
+
+ while (true) {
+ if(rng->relativeLong(ind++) % 2) {
+ if(mid == totLen)
+ break;
+ math::nd4j_swap(x[(blockOffset + beg) * ews], x[(blockOffset + mid++) * ews]);
+ } else {
+ if(beg == mid)
+ break;
+ }
+ ++beg;
+ }
+
+ // Fisher-Yates
+ while (beg < totLen) {
+ const Nd4jLong e = rng->relativeLong(ind++) % (beg + 1);
+ if(beg != e)
+ math::nd4j_swap(x[(blockOffset + beg) * ews], x[(blockOffset + e) * ews]);
+ ++beg;
+ }
+ }
+}
+
+
+//////////////////////////////////////////////////////////////////////////
+// Fisher-Yates shuffle
+template <typename T>
+static void fisherYates(sd::graph::RandomGenerator& rng, T* buff, const Nd4jLong& len, const Nd4jLong& ews, Nd4jLong ind) {
+
+ for(Nd4jLong i = len-1; i > 0; --i) {
+ const Nd4jLong j = rng.relativeLong(ind++) % (i + 1);
+ if(i != j)
+ math::nd4j_swap(buff[i*ews], buff[j*ews]);
+ }
+}
+
+//////////////////////////////////////////////////////////////////////////
+template <typename T>
+static void randomShuffle_(sd::LaunchContext* context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {
+
+ const int firstDim = input.sizeAt(0);
+ int temp;
+
+ if(input.lengthOf() == 1 || firstDim == 1) {
+
+ if(!isInplace)
+ output.assign(input);
+ }
+ else if (shape::isCommonVector(input.shapeInfo(), temp)) {
+
+ NDArray* arr = &input;
+
+ if (!isInplace) {
+ output.assign(input);
+ arr = &output;
+ }
+
+ const Nd4jLong len = arr->lengthOf();
+
+ const int threadsPerBlock = MAX_NUM_THREADS;
+
+ int power = 0;
+ while ((len >> power) > threadsPerBlock)
+ ++power;
+
+ const int blocksPerGrid = 1 << power;
+ const int sharedMem = threadsPerBlock * input.sizeOfT() + 256;
+
+ PointersManager manager(context, "NDArray::randomShuffle cuda");
+
+ sd::graph::RandomGenerator* pRng = reinterpret_cast<sd::graph::RandomGenerator*>(manager.replicatePointer(&rng, sizeof(sd::graph::RandomGenerator)));
+
+ NDArray::prepareSpecialUse({arr}, {arr});
+ fisherYatesCuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *context->getCudaStream()>>>(pRng, arr->specialBuffer(), arr->ews(), len, power);
+ for (Nd4jLong j = 1, i = 1; j < blocksPerGrid; j += j, ++i)
+ mergeShuffleCuda<T><<<blocksPerGrid / (2*j), threadsPerBlock, 256, *context->getCudaStream()>>>(pRng, arr->specialBuffer(), arr->ews(), len, power, i);
+ NDArray::registerSpecialUse({arr}, {arr});
+
+ manager.synchronize();
+
+ rng.rewindH((len + 1) * power);
+ }
+ else {
+
+ auto dimsToExclude = ShapeUtils::evalDimsToExclude(input.rankOf(), {0});
+
+ if(isInplace) {
+
+ auto subArrsList = input.allTensorsAlongDimension(dimsToExclude);
+
+ // Fisher-Yates shuffle
+ for(int i = firstDim - 1; i > 0; --i) {
+ const int j = rng.relativeInt(i) % (i + 1);
+ if(i != j)
+ subArrsList.at(i)->swapUnsafe(*subArrsList.at(j));
+ }
+ }
+ else {
+
+ auto subArrsListIn = input.allTensorsAlongDimension(dimsToExclude);
+ auto subArrsListOut = output.allTensorsAlongDimension(dimsToExclude);
+
+ std::vector<int> indices(firstDim);
+ std::iota(indices.begin(), indices.end(), 0); // 0,1,2,3, ... firstDim-1
+
+ // shuffle indices
+ fisherYates(rng, indices.data(), firstDim, 1, 0);
+
+ auto func = PRAGMA_THREADS_FOR {
+
+ for (auto i = start; i < stop; ++i)
+ subArrsListOut.at(i)->assign(subArrsListIn.at(indices[i]));
+ };
+
+ samediff::Threads::parallel_for(func, 0, firstDim);
+ }
+
+ rng.rewindH(firstDim-1);
+ }
+}
+
+/////////////////////////////////////////////////////////////////////////
+void randomShuffle(sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {
+ BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (context, input, output, rng, isInplace), LIBND4J_TYPES);
+}
+
+// BUILD_SINGLE_TEMPLATE(template void randomShuffle_, (sd::LaunchContext* context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace), LIBND4J_TYPES);
+
+
+
+}
+}
+}
\ No newline at end of file
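Both the CPU and the CUDA implementations lean on the chunking rule offset_i = (len * i) >> power. A quick standalone check, not part of the patch, that this rule splits [0, len) into 2^power contiguous, gap-free chunks:

// Standalone verification of the chunk-partition arithmetic (illustrative).
#include <cassert>
#include <cstdint>

int main() {
    const int64_t len = 1000003;
    const int power = 4;                  // 2^4 = 16 chunks
    const int64_t numChunks = 1LL << power;
    int64_t covered = 0;
    for (int64_t i = 0; i < numChunks; ++i) {
        const int64_t offset  = (len * i) >> power;
        const int64_t currLen = ((len * (i + 1)) >> power) - offset;
        assert(offset == covered);        // chunks are adjacent, no gaps or overlap
        covered += currLen;
    }
    assert(covered == len);               // (len * 2^power) >> power == len exactly
    return 0;
}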
diff --git a/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu b/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu
index 8d7f700dd..80e0e0858 100644
--- a/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu
+++ b/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu
@@ -300,129 +300,6 @@ void tileBP(sd::LaunchContext * context, const NDArray& gradO /*input*/, NDArray
manager.synchronize();
}
- template <typename T>
- static __global__ void swapShuffleKernel(T* input, Nd4jLong const* shape, Nd4jLong firstDim, sd::graph::RandomGenerator* rng) {
- auto tid = blockIdx.x * blockDim.x;
- auto step = blockDim.x * gridDim.x;
-
- for (int i = firstDim - 1 - tid - threadIdx.x; i > 0; i -= step) {
- int r = rng->relativeInt(i) % i;
- if (i != r) {
- const auto iOffset = shape::getIndexOffset(i, shape);
- const auto rOffset = shape::getIndexOffset(r, shape);
- T e0 = input[iOffset];
- T e1 = input[rOffset];
- //math::nd4j_swap(input(i), input(r));
- input[iOffset] = e1;
- input[rOffset] = e0;
- }
- }
- }
- template <typename T>
- static __global__ void fillShuffleKernel(T* input, Nd4jLong const* inputShape, T* output, Nd4jLong const* outputShape, Nd4jLong firstDim, int* indices, sd::graph::RandomGenerator* rng) {
-
-// PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance().tadThreshold())
- auto tid = blockIdx.x * blockDim.x;
- auto step = blockDim.x * gridDim.x;
-
- for(int i = firstDim - 1 - tid - threadIdx.x; i > 0; i -= step) {
- int r = rng->relativeInt(i) % i;
- output[shape::getIndexOffset(i, outputShape)] = input[shape::getIndexOffset(indices[r], inputShape)];
- if(i != r) {
- output[shape::getIndexOffset(r, outputShape)] = input[shape::getIndexOffset(indices[i], inputShape)];
-// output.p(r, input.e(indices[i]));
-// math::nd4j_swap(indices[i], indices[r]);
- atomicExch(&indices[i], indices[r]);
- }
- }
-
- }
- //////////////////////////////////////////////////////////////////////////
- template <typename T>
- void randomShuffle_(sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {
-
- // check edge cases first
- int temp;
- const int firstDim = input.sizeAt(0);
- auto stream = context->getCudaStream();
- NDArray::prepareSpecialUse({&output}, {&input});
- if(input.lengthOf() == 1 || firstDim == 1) {
- if(!isInplace)
- output.assign(input);
- }
- else if (input.isVector() || shape::isLikeVector(input.shapeInfo(), temp)) {
-
- // apply Fisher-Yates shuffle
- sd::graph::RandomGenerator* dRandom = nullptr;
- cudaMalloc(&dRandom, sizeof(sd::graph::RandomGenerator));
- cudaMemcpy(dRandom, &rng, sizeof(sd::graph::RandomGenerator), cudaMemcpyHostToDevice);
- T* inputBuf = reinterpret_cast<T*>(input.specialBuffer());
- if(isInplace) {
- swapShuffleKernel<<<128, 256, 1024, *stream>>>(inputBuf, input.specialShapeInfo(), firstDim, dRandom);
- }
- else {
- std::vector<int> indices(firstDim);
- std::iota(indices.begin(), indices.end(), 0);
- cudaMemcpy(output.specialBuffer(), input.specialBuffer(), sizeof(T), cudaMemcpyDeviceToDevice);
- //output.p(Nd4jLong(0), input.e(0));
- PointersManager pointersManager(context, "helper::randomShuffle_");
- int* indicesDev = reinterpret_cast<int*>(pointersManager.replicatePointer(indices.data(), indices.size() * sizeof(int)));
- T* outputBuf = reinterpret_cast<T*>(output.specialBuffer());
- fillShuffleKernel<<<128, 256, 1024, *stream>>>(inputBuf, input.specialShapeInfo(), outputBuf, output.specialShapeInfo(), firstDim, indicesDev, dRandom);
- pointersManager.synchronize();
- }
-// rng.rewindH(firstDim - 1);
- cudaFree(dRandom);
- }
- else {
-
- // evaluate sub-arrays list of input array through all dimensions excluding first one
- std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input.rankOf(), {0});
- auto subArrsListIn = input.allTensorsAlongDimension(dimensions);
-
- // apply Fisher-Yates shuffle
- if(isInplace) {
- for(int i = firstDim - 1; i > 0; --i) {
- int r = rng.relativeInt(i) % i;
-
- if(i != r)
- subArrsListIn.at(i)->swapUnsafe(*subArrsListIn.at(r));
- }
- }
- else {
- // evaluate sub-arrays list of output array through all dimensions excluding first one
- auto subArrsListOut = output.allTensorsAlongDimension(dimensions);
- std::vector<int> indices(firstDim);
- std::iota(indices.begin(), indices.end(), 0);
- bool isZeroShuffled = false;
-
- for(int i = firstDim - 1; i > 0; --i) {
- int r = rng.relativeInt(i) % i;
- subArrsListOut.at(i)->assign(subArrsListIn.at(indices[r]));
- if(r == 0)
- isZeroShuffled = true;
-
- if(i != r) {
- subArrsListOut.at(r)->assign(subArrsListIn.at(indices[i]));
- math::nd4j_swap(indices[i], indices[r]);
- }
- }
- if(!isZeroShuffled)
- subArrsListOut.at(0)->assign(subArrsListIn.at(0));
- }
- rng.rewindH(firstDim-1);
- }
- NDArray::registerSpecialUse({&output}, {&input});
-
- }
-
- void randomShuffle(sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {
- BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (context, input, output, rng, isInplace), LIBND4J_TYPES);
- }
-
- BUILD_SINGLE_TEMPLATE(template void randomShuffle_, (sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace), LIBND4J_TYPES);
-
-
//////////////////////////////////////////////////////////////////////////
void eye(sd::LaunchContext * context, NDArray& output) {
diff --git a/libnd4j/include/ops/declarable/platform/armcompute/armcomputeUtils.cpp b/libnd4j/include/ops/declarable/platform/armcompute/armcomputeUtils.cpp
new file mode 100644
index 000000000..66b472252
--- /dev/null
+++ b/libnd4j/include/ops/declarable/platform/armcompute/armcomputeUtils.cpp
@@ -0,0 +1,278 @@
+/*******************************************************************************
+ * Copyright (c) 2019 Konduit K.K.
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+ // Created by Abdelrauf 2020
+
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "armcomputeUtils.h"
+
+
+namespace sd {
+namespace ops {
+namespace platforms {
+
+
+
+Arm_DataType getArmType ( const DataType &dType){
+ Arm_DataType ret;
+ switch (dType){
+ case HALF :
+ ret = Arm_DataType::F16;
+ break;
+ case FLOAT32 :
+ ret = Arm_DataType::F32;
+ break;
+ case DOUBLE :
+ ret = Arm_DataType::F64;
+ break;
+ case INT8 :
+ ret = Arm_DataType::S8;
+ break;
+ case INT16 :
+ ret = Arm_DataType::S16;
+ break;
+ case INT32 :
+ ret = Arm_DataType::S32;
+ break;
+ case INT64 :
+ ret = Arm_DataType::S64;
+ break;
+ case UINT8 :
+ ret = Arm_DataType::U8;
+ break;
+ case UINT16 :
+ ret = Arm_DataType::U16;
+ break;
+ case UINT32 :
+ ret = Arm_DataType::U32;
+ break;
+ case UINT64 :
+ ret = Arm_DataType::U64;
+ break;
+ case BFLOAT16 :
+ ret = Arm_DataType::BFLOAT16;
+ break;
+ default:
+ ret = Arm_DataType::UNKNOWN;
+ };
+
+ return ret;
+}
+bool isArmcomputeFriendly(const NDArray& arr) {
+ auto dType = getArmType(arr.dataType());
+ int rank = (int)(arr.rankOf());
+ return dType != Arm_DataType::UNKNOWN &&
+ rank<=arm_compute::MAX_DIMS &&
+ arr.ordering() == 'c' &&
+ arr.ews()==1 &&
+ shape::strideDescendingCAscendingF(arr.shapeInfo()) == true;
+}
+
+Arm_TensorInfo getArmTensorInfo(int rank, Nd4jLong* bases,sd::DataType ndArrayType, arm_compute::DataLayout layout) {
+ constexpr int numChannels = 1;
+ auto dType = getArmType(ndArrayType);
+
+ Arm_TensorShape shape;
+ shape.set_num_dimensions(rank);
+ for (int i = 0, j = rank - 1; i < rank; i++, j--) {
+ shape[i] = static_cast<size_t>(bases[j]);
+ }
+ // fill the remaining unused dimensions with 1
+ for (int i = rank; i < arm_compute::MAX_DIMS; i++) {
+ shape[i] = 1;
+ }
+
+ return Arm_TensorInfo(shape, numChannels, dType, layout);
+}
+
+Arm_TensorInfo getArmTensorInfo(const NDArray& arr,
+ arm_compute::DataLayout layout) {
+ auto dType = getArmType(arr.dataType());
+
+ //
+ constexpr int numChannels = 1;
+ int rank = (int)(arr.rankOf());
+ auto bases = arr.shapeOf();
+ auto arrStrides = arr.stridesOf();
+
+ // https://arm-software.github.io/ComputeLibrary/v20.05/_dimensions_8h_source.xhtml
+ // note: underhood it is stored as std::array _id;
+ // TensorShape is derived from Dimensions
+ // as well as Strides : public Dimensions
+ Arm_TensorShape shape;
+ Arm_Strides strides;
+ shape.set_num_dimensions(rank);
+ strides.set_num_dimensions(rank);
+ size_t element_size = arm_compute::data_size_from_type(dType);
+ for (int i = 0, j = rank - 1; i < rank; i++, j--) {
+ shape[i] = static_cast<size_t>(bases[j]);
+ strides[i] = static_cast<size_t>(arrStrides[j]) * element_size;
+ }
+ // fill the remaining unused dimensions with 1
+ for (int i = rank; i < arm_compute::MAX_DIMS; i++) {
+ shape[i] = 1;
+ }
+ size_t total_size;
+ size_t size_ind = rank - 1;
+ total_size = shape[size_ind] * strides[size_ind];
+
+ Arm_TensorInfo info;
+ info.init(shape, numChannels, dType, strides, 0, total_size);
+ info.set_data_layout(layout);
+
+ return info;
+}
+
+Arm_Tensor getArmTensor(const NDArray& arr, arm_compute::DataLayout layout) {
+ // - Ownership of the backing memory is not transferred to the tensor itself.
+ // - The tensor mustn't be memory managed.
+ // - Padding requirements should be accounted by the client code.
+ // In other words, if padding is required by the tensor after the function
+ // configuration step, then the imported backing memory should account for it.
+ // Padding can be checked through the TensorInfo::padding() interface.
+
+ // Import existing pointer as backing memory
+ auto info = getArmTensorInfo(arr, layout);
+ Arm_Tensor tensor;
+ tensor.allocator()->init(info);
+ void* buff = (void*)arr.buffer();
+ tensor.allocator()->import_memory(buff);
+ return tensor;
+}
+
+void copyFromTensor(const Arm_Tensor& inTensor, NDArray& output) {
+ //only for C order
+ if (output.ordering() != 'c') return;
+ auto shapeInfo = output.shapeInfo();
+ auto bases = &(shapeInfo[1]);
+ Nd4jLong rank = shapeInfo[0];
+ auto strides = output.stridesOf();
+ int width = bases[rank - 1];
+ uint8_t* outputBuffer = (uint8_t*)output.buffer();
+ size_t offset = 0;
+ arm_compute::Window window;
+ arm_compute::Iterator tensor_it(&inTensor, window);
+
+ int element_size = inTensor.info()->element_size();
+ window.use_tensor_dimensions(inTensor.info()->tensor_shape(), /* first_dimension =*/arm_compute::Window::DimY);
+
+// if (output.ews() == 1) {
+ auto copySize = width * element_size;
+ auto dest = outputBuffer;
+ arm_compute::execute_window_loop(window, [&](const arm_compute::Coordinates& id)
+ {
+ auto src = tensor_it.ptr();
+ memcpy(dest, src, copySize);
+ dest += copySize;
+ },
+ tensor_it);
+ // }
+ // else {
+ // Nd4jLong coords[MAX_RANK] = {};
+ // if(strides[rank-1]!=1){
+ // throw std::runtime_error( "not implemented for subarrays whose last stride is not 1");
+ // //TODO: implement to work with all subarrays properly
+ // }
+ // arm_compute::execute_window_loop(window, [&](const arm_compute::Coordinates& id)
+ // {
+ // auto src = tensor_it.ptr();
+ // auto dest = outputBuffer + offset * element_size;
+ // memcpy(dest, src, width * element_size);
+ // offset = sd::inc_coords(bases, strides, coords, offset, rank, 1);
+ // },
+ // tensor_it);
+ // }
+}
+
+void copyToTensor(const NDArray& input, Arm_Tensor& outTensor) {
+ //only for C order
+ if (input.ordering() != 'c') return;
+ auto shapeInfo = input.shapeInfo();
+ auto bases = &(shapeInfo[1]);
+ Nd4jLong rank = shapeInfo[0];
+ auto strides = input.stridesOf();
+ uint8_t *inputBuffer = (uint8_t*)input.buffer();
+ int width = bases[rank - 1];
+ size_t offset = 0;
+ arm_compute::Window window;
+ arm_compute::Iterator tensor_it(&outTensor, window);
+ int element_size = outTensor.info()->element_size();
+
+ window.use_tensor_dimensions(outTensor.info()->tensor_shape(), /* first_dimension =*/arm_compute::Window::DimY);
+
+// if (input.ews() == 1) {
+
+ auto copySize = width * element_size;
+ auto src = inputBuffer;
+ arm_compute::execute_window_loop(window, [&](const arm_compute::Coordinates& id)
+ {
+ auto dest = tensor_it.ptr();
+ memcpy(dest,src, copySize);
+ src += copySize;
+ },
+ tensor_it);
+// }
+// else {
+// Nd4jLong coords[MAX_RANK] = {};
+// if(strides[rank-1]!=1){
+// throw std::runtime_error( "not implemented for subarrays whose last stride is not 1");
+// //TODO: implement to work with all subarrays properly
+// }
+// arm_compute::execute_window_loop(window, [&](const arm_compute::Coordinates& id)
+// {
+// auto dest = tensor_it.ptr();
+// auto src = inputBuffer + offset * element_size;
+// offset = sd::inc_coords(bases, strides, coords, offset, rank, 1);
+// },
+// tensor_it);
+// }
+}
+
+
+// armcompute should be built with debug option
+void print_tensor(Arm_ITensor& tensor, const char* msg) {
+ auto info = tensor.info();
+ auto padding = info->padding();
+ std::cout << msg << "\ntotal: " << info->total_size() << "\n";
+
+ for (int i = 0; i < arm_compute::MAX_DIMS; i++) {
+ std::cout << info->dimension(i) << ",";
+ }
+ std::cout << std::endl;
+ for (int i = 0; i < arm_compute::MAX_DIMS; i++) {
+ std::cout << info->strides_in_bytes()[i] << ",";
+ }
+ std::cout << "\npadding: l " << padding.left << ", r " << padding.right
+ << ", t " << padding.top << ", b " << padding.bottom << std::endl;
+
+#ifdef ARM_COMPUTE_ASSERTS_ENABLED
+ //note: it did not print correctly for NHWC
+ std::cout << msg << ":\n";
+ tensor.print(std::cout);
+ std::cout << std::endl;
+#endif
+}
+
+}
+}
+}
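The dimension reversal inside getArmTensorInfo deserves a remark: libnd4j 'c'-order shapes list the slowest-varying axis first, whereas arm_compute::TensorShape puts the fastest-varying axis at index 0. A minimal sketch of the reversal, with std::vector standing in for both containers (names are illustrative, not from the patch):

// Illustration of the index-reversal loop in getArmTensorInfo above.
#include <cstdint>
#include <vector>

std::vector<int64_t> toAclOrder(const std::vector<int64_t>& bases) {
    const int rank = static_cast<int>(bases.size());
    std::vector<int64_t> shape(rank);
    for (int i = 0, j = rank - 1; i < rank; i++, j--)
        shape[i] = bases[j];   // e.g. NCHW {2,3,32,32} -> ACL order {32,32,3,2}
    return shape;
}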
diff --git a/libnd4j/include/ops/declarable/platform/armcompute/armcomputeUtils.h b/libnd4j/include/ops/declarable/platform/armcompute/armcomputeUtils.h
new file mode 100644
index 000000000..72a4e6e89
--- /dev/null
+++ b/libnd4j/include/ops/declarable/platform/armcompute/armcomputeUtils.h
@@ -0,0 +1,133 @@
+/*******************************************************************************
+ * Copyright (c) 2019 Konduit K.K.
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+
+#ifndef DEV_TESTSARMCOMPUTEUTILS_H
+#define DEV_TESTSARMCOMPUTEUTILS_H
+
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+using namespace samediff;
+
+
+namespace sd {
+ namespace ops {
+ namespace platforms {
+
+ using Arm_DataType = arm_compute::DataType;
+ using Arm_Tensor = arm_compute::Tensor;
+ using Arm_ITensor = arm_compute::ITensor;
+ using Arm_TensorInfo = arm_compute::TensorInfo;
+ using Arm_TensorShape = arm_compute::TensorShape;
+ using Arm_Strides = arm_compute::Strides;
+ /**
+ * Here we actually declare our platform helpers
+ */
+
+
+ DECLARE_PLATFORM(maxpool2d, ENGINE_CPU);
+
+ DECLARE_PLATFORM(avgpool2d, ENGINE_CPU);
+
+ //utils
+ Arm_DataType getArmType(const sd::DataType& dType);
+
+ Arm_TensorInfo getArmTensorInfo(int rank, Nd4jLong* bases, sd::DataType ndArrayType, arm_compute::DataLayout layout = arm_compute::DataLayout::UNKNOWN);
+
+ Arm_TensorInfo getArmTensorInfo(const NDArray& arr, arm_compute::DataLayout layout = arm_compute::DataLayout::UNKNOWN);
+
+ Arm_Tensor getArmTensor(const NDArray& arr, arm_compute::DataLayout layout = arm_compute::DataLayout::UNKNOWN);
+
+ void copyFromTensor(const Arm_Tensor& inTensor, NDArray& output);
+ void copyToTensor(const NDArray& input, Arm_Tensor& outTensor);
+ void print_tensor(Arm_ITensor& tensor, const char* msg);
+ bool isArmcomputeFriendly(const NDArray& arr);
+
+
+ template <typename F>
+ class ArmFunction {
+ public:
+
+ template <typename ...Args>
+ void configure(NDArray *input , NDArray *output, arm_compute::DataLayout layout, Args&& ...args) {
+
+ auto inInfo = getArmTensorInfo(*input, layout);
+ auto outInfo = getArmTensorInfo(*output, layout);
+ in.allocator()->init(inInfo);
+ out.allocator()->init(outInfo);
+ armFunction.configure(&in,&out,std::forward<Args>(args) ...);
+ if (in.info()->has_padding()) {
+ //allocate and copy
+ in.allocator()->allocate();
+ //copy
+ copyToTensor(*input, in);
+
+ }
+ else {
+ //import buffer
+ void* buff = input->buffer();
+ in.allocator()->import_memory(buff);
+ }
+ if (out.info()->has_padding()) {
+ //store pointer to our array to copy after run
+ out.allocator()->allocate();
+ outNd = output;
+ }
+ else {
+ //import
+ void* buff = output->buffer();
+ out.allocator()->import_memory(buff);
+ }
+
+ }
+
+ void run() {
+ armFunction.run();
+ if (outNd) {
+ copyFromTensor(out, *outNd);
+ }
+ }
+
+ private:
+ Arm_Tensor in;
+ Arm_Tensor out;
+ NDArray *outNd=nullptr;
+ F armFunction{};
+ };
+ }
+ }
+}
+
+
+
+#endif //DEV_TESTSARMCOMPUTEUTILS_H
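A hedged usage sketch of the ArmFunction wrapper declared above, mirroring how the pooling platform ops later in this patch drive it; the helper name runAvgPool3x3 and the fixed 3x3 window are illustrative assumptions, and input/output are assumed to be rank-4, 'c'-ordered float32 NDArrays (what isArmcomputeFriendly accepts):

// Usage sketch only, not patch code.
#include <arm_compute/runtime/NEON/NEFunctions.h>
#include "armcomputeUtils.h"

namespace sd { namespace ops { namespace platforms {

void runAvgPool3x3(NDArray* input, NDArray* output) {
    auto layout = arm_compute::DataLayout::NCHW;
    arm_compute::PadStrideInfo pad(1, 1, 0, 0);           // stride 1, no padding
    arm_compute::PoolingLayerInfo info(arm_compute::PoolingType::AVG,
                                       arm_compute::Size2D(3, 3), layout, pad);
    ArmFunction<arm_compute::NEPoolingLayer> pool;
    pool.configure(input, output, layout, info);  // zero-copy import, or copy if ACL pads
    pool.run();                                   // copies back only if the output was padded
}

}}} // namespace sd::ops::platforms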
diff --git a/libnd4j/include/ops/declarable/platform/armcompute/avgpooling2d.cpp b/libnd4j/include/ops/declarable/platform/armcompute/avgpooling2d.cpp
new file mode 100644
index 000000000..d8413104d
--- /dev/null
+++ b/libnd4j/include/ops/declarable/platform/armcompute/avgpooling2d.cpp
@@ -0,0 +1,106 @@
+/*******************************************************************************
+ * Copyright (c) 2019 Konduit K.K.
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+ // Created by Abdelrauf (rauf@konduit.ai) 2020
+
+#include
+#include
+#include
+#include
+
+
+#include "armcomputeUtils.h"
+
+
+namespace sd {
+namespace ops {
+namespace platforms {
+
+
+//////////////////////////////////////////////////////////////////////////
+PLATFORM_IMPL(avgpool2d, ENGINE_CPU) {
+
+ auto input = INPUT_VARIABLE(0);
+ auto output = OUTPUT_VARIABLE(0);
+
+ // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
+
+ const auto kH = INT_ARG(0);
+ const auto kW = INT_ARG(1);
+ const auto sH = INT_ARG(2);
+ const auto sW = INT_ARG(3);
+ auto pH = INT_ARG(4);
+ auto pW = INT_ARG(5);
+ const auto dH = INT_ARG(6);
+ const auto dW = INT_ARG(7);
+ const auto paddingMode = INT_ARG(8);
+ const auto extraParam0 = INT_ARG(9);
+ const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
+
+ REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D ARMCOMPUTE op: input should have rank of 4, but got %i instead", input->rankOf());
+ REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D ARMCOMPUTE op: dilation must not be zero, but got instead {%i, %i}", dH, dW);
+
+ bool exclude_padding= (extraParam0 == 0) ? true : false;
+
+ auto dataLayout = isNCHW ? arm_compute::DataLayout::NCHW : arm_compute::DataLayout::NHWC;
+
+ // Calculate individual paddings
+ unsigned int pad_left, pad_top, pad_right, pad_bottom;
+ int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
+ int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
+ ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
+
+ if(paddingMode){
+ ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
+ }
+ pad_left = pW;
+ pad_top = pH;
+ pad_right = (oW - 1) * sW - iW + kW - pW ;
+ pad_bottom = (oH - 1) * sH - iH + kH - pH ;
+
+#if 0
+ nd4j_printf("avgpool kH = %d, kW = %d, sH = %d, sW = %d , pH = %d , pW = %d, dH = %d, dW = %d, paddingMode = %d , isNCHW %d exclude pad %d \n" , kH , kW , sH , sW , pH
+ , pW , dH , dW , paddingMode,isNCHW?1:0 ,exclude_padding?1:0);
+#endif
+ auto poolPad = arm_compute::PadStrideInfo(sW, sH, pad_left,pad_right, pad_top, pad_bottom, arm_compute::DimensionRoundingType::FLOOR);
+ auto poolInfo = arm_compute::PoolingLayerInfo(arm_compute::PoolingType::AVG, arm_compute::Size2D(kW, kH), dataLayout, poolPad, exclude_padding);
+ ArmFunction<arm_compute::NEPoolingLayer> pool;
+ pool.configure(input,output, dataLayout, poolInfo);
+
+ pool.run(); // run function
+
+ return Status::OK();
+}
+
+//////////////////////////////////////////////////////////////////////////
+PLATFORM_CHECK(avgpool2d, ENGINE_CPU) {
+ auto input = INPUT_VARIABLE(0);
+ auto output = OUTPUT_VARIABLE(0);
+ const int dH = INT_ARG(6);
+ const int dW = INT_ARG(7);
+ // Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32
+ auto dTypeInput = getArmType(input->dataType());
+ auto dTypeOutput = getArmType(output->dataType());
+ bool is_supported = dH==1 && dW==1 && isArmcomputeFriendly(*input) && isArmcomputeFriendly(*output)
+ && (dTypeInput ==Arm_DataType::F32)
+ && (dTypeOutput ==Arm_DataType::F32);
+ return is_supported;
+}
+
+
+
+}
+}
+}
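The asymmetric padding computed above follows from the usual output-size relation oW = (iW + pad_left + pad_right - kW) / sW + 1, solved for pad_right. A small worked check (illustrative, using same-mode numbers matching the NHWC test earlier in this patch):

// Worked check of pad_right = (oW - 1) * sW - iW + kW - pW (illustrative).
#include <cassert>

int main() {
    const int iW = 32, kW = 2, sW = 1, pW = 0;  // "same" mode: pW is the left pad
    const int oW = 32;                          // same-mode output width
    const int pad_right = (oW - 1) * sW - iW + kW - pW;      // = 1
    assert((iW + pW + pad_right - kW) / sW + 1 == oW);       // relation holds
    return 0;
}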
diff --git a/libnd4j/include/ops/declarable/platform/armcompute/maxpooling2d.cpp b/libnd4j/include/ops/declarable/platform/armcompute/maxpooling2d.cpp
new file mode 100644
index 000000000..cd6779628
--- /dev/null
+++ b/libnd4j/include/ops/declarable/platform/armcompute/maxpooling2d.cpp
@@ -0,0 +1,106 @@
+/*******************************************************************************
+ * Copyright (c) 2019 Konduit K.K.
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+ // Created by Abdelrauf 2020
+
+
+#include <ops/declarable/PlatformHelper.h>
+#include <ops/declarable/OpRegistrator.h>
+#include <system/platform_boilerplate.h>
+#include <ops/declarable/helpers/convolutions.h>
+
+
+#include "armcomputeUtils.h"
+
+
+namespace sd {
+namespace ops {
+namespace platforms {
+
+
+//////////////////////////////////////////////////////////////////////////
+PLATFORM_IMPL(maxpool2d, ENGINE_CPU) {
+
+ auto input = INPUT_VARIABLE(0);
+ auto output = OUTPUT_VARIABLE(0);
+
+ REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D ARMCOMPUTE OP: input array should have rank of 4, but got %i instead", input->rankOf());
+
+ // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
+ const int kH = INT_ARG(0);
+ const int kW = INT_ARG(1);
+ const int sH = INT_ARG(2);
+ const int sW = INT_ARG(3);
+ int pH = INT_ARG(4);
+ int pW = INT_ARG(5);
+ const int dH = INT_ARG(6);
+ const int dW = INT_ARG(7);
+ const int paddingMode = INT_ARG(8);
+ // const int extraParam0 = INT_ARG(9);
+ const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 1-NHWC, 0-NCHW
+
+ REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D ARMCOMPUTE op: dilation must not be zero, but got instead {%i, %i}", dH, dW);
+
+ auto dataLayout = isNCHW ? arm_compute::DataLayout::NCHW : arm_compute::DataLayout::NHWC;
+
+ // Calculate individual paddings
+ unsigned int pad_left, pad_top, pad_right, pad_bottom;
+ int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
+ int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
+ ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
+
+ if(paddingMode){
+ ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
+ }
+ pad_left = pW;
+ pad_top = pH;
+ pad_right = (oW - 1) * sW - iW + kW - pW ;
+ pad_bottom = (oH - 1) * sH - iH + kH - pH ;
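+ // explicit per-side padding, derived exactly as in the avgpool2d helper above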
+#if 0
+ nd4j_printf("avgpool kH = %d, kW = %d, sH = %d, sW = %d , pH = %d , pW = %d, dH = %d, dW = %d, paddingMode = %d , isNCHW %d exclude pad %d \n" , kH , kW , sH , sW , pH
+ , pW , dH , dW , paddingMode,isNCHW?1:0 ,exclude_padding?1:0);
+#endif
+
+ auto poolPad = arm_compute::PadStrideInfo(sW, sH, pad_left,pad_right, pad_top, pad_bottom, arm_compute::DimensionRoundingType::FLOOR);
+ auto poolInfo = arm_compute::PoolingLayerInfo(arm_compute::PoolingType::MAX, arm_compute::Size2D(kW, kH), dataLayout, poolPad);
+ ArmFunction<arm_compute::NEPoolingLayer> pool;
+
+ pool.configure(input,output, dataLayout, poolInfo);
+
+ pool.run(); // run function
+
+ return Status::OK();
+}
+
+//////////////////////////////////////////////////////////////////////////
+PLATFORM_CHECK(maxpool2d, ENGINE_CPU) {
+ auto input = INPUT_VARIABLE(0);
+ auto output = OUTPUT_VARIABLE(0);
+ const int dH = INT_ARG(6);
+ const int dW = INT_ARG(7);
+ // ACL pooling supports QASYMM8/QASYMM8_SIGNED/F16/F32; this helper is only enabled for F32
+ auto dTypeInput = getArmType(input->dataType());
+ auto dTypeOutput = getArmType(output->dataType());
+ bool is_supported = dH == 1 && dW == 1 && isArmcomputeFriendly(*input) && isArmcomputeFriendly(*output)
+ && (dTypeInput == Arm_DataType::F32)
+ && (dTypeOutput == Arm_DataType::F32);
+ return is_supported;
+}
+
+
+
+}
+}
+}
diff --git a/libnd4j/include/ops/ops.h b/libnd4j/include/ops/ops.h
index ea52e9ba0..aca6fec6f 100644
--- a/libnd4j/include/ops/ops.h
+++ b/libnd4j/include/ops/ops.h
@@ -3963,9 +3963,6 @@ namespace simdOps {
}
#endif
-#ifndef __clang__
-#pragma omp declare simd uniform(extraParamsRef)
-#endif
op_def static Y merge(X old, X opOutput, X *extraParamsRef) {
return update(old, opOutput, extraParamsRef);
}
diff --git a/libnd4j/pi_build.sh b/libnd4j/pi_build.sh
new file mode 100755
index 000000000..f96c3f1f1
--- /dev/null
+++ b/libnd4j/pi_build.sh
@@ -0,0 +1,185 @@
+#!/bin/bash
+TARGET=armv7-a
+BLAS_TARGET_NAME=ARMV7
+ARMCOMPUTE_TARGET=armv7a
+#BASE_DIR=${HOME}/pi
+#https://stackoverflow.com/questions/59895/how-to-get-the-source-directory-of-a-bash-script-from-within-the-script-itself
+SOURCE="${BASH_SOURCE[0]}"
+ARMCOMPUTE_DEBUG=1
+LIBND4J_BUILD_MODE=Release
+while [ -h "$SOURCE" ]; do # resolve $SOURCE until the file is no longer a symlink
+ DIR="$( cd -P "$( dirname "$SOURCE" )" >/dev/null 2>&1 && pwd )"
+ SOURCE="$(readlink "$SOURCE")"
+ [[ $SOURCE != /* ]] && SOURCE="$DIR/$SOURCE" # if $SOURCE was a relative symlink, we need to resolve it relative to the path where the symlink file was located
+done
+BASE_DIR="$( cd -P "$( dirname "$SOURCE" )" >/dev/null 2>&1 && pwd )"
+CMAKE=cmake #/snap/bin/cmake
+
+mkdir -p ${BASE_DIR}/helper_bin/
+
+CROSS_COMPILER_URL=https://sourceforge.net/projects/raspberry-pi-cross-compilers/files/Raspberry%20Pi%20GCC%20Cross-Compiler%20Toolchains/Buster/GCC%208.3.0/Raspberry%20Pi%203A%2B%2C%203B%2B%2C%204/cross-gcc-8.3.0-pi_3%2B.tar.gz/download
+CROSS_COMPILER_DIR=${BASE_DIR}/helper_bin/cross_compiler
+
+SCONS_LOCAL_URL=http://prdownloads.sourceforge.net/scons/scons-local-3.1.1.tar.gz
+SCONS_LOCAL_DIR=${BASE_DIR}/helper_bin/scons_local
+
+THIRD_PARTY=${BASE_DIR}/third_party_libs
+
+ARMCOMPUTE_GIT_URL=https://github.com/ARM-software/ComputeLibrary.git
+ARMCOMPUTE_TAG=v20.05
+ARMCOMPUTE_DIR=${THIRD_PARTY}/arm_compute_dir
+
+OPENBLAS_GIT_URL="https://github.com/xianyi/OpenBLAS.git"
+OPENBLAS_DIR=${THIRD_PARTY}/OpenBLAS
+
+
+LIBND4J_SRC_DIR=${BASE_DIR}
+
+LIBND4J_BUILD_DIR=${BASE_DIR}/build_pi
+
+#for some downloads
+XTRACT_STRIP="--strip-components=1"
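+# --strip-components=1 drops the tarball's top-level directory on extraction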
+
+HAS_ARMCOMPUTE=1
+mkdir -p ${BASE_DIR}
+mkdir -p ${THIRD_PARTY}
+
+#change directory to base
+cd $BASE_DIR
+
+function message {
+ echo "BUILDER:::: ${@}"
+}
+
+
+function check_requirements {
+ for i in "${@}"
+ do
+ if [ ! -e "$i" ]; then
+ message "missing: ${i}"
+ exit -2
+ fi
+ done
+}
+
+function download_extract {
+ #$1 is url #2 is dir $3 is extract argument
+ if [ ! -f ${2}_file ]; then
+ message "download"
+ wget --quiet --show-progress -O ${2}_file ${1}
+ fi
+
+ message "extract"
+ #extract
+ mkdir -p ${2}
+ command="tar -xzf ${2}_file --directory=${2} ${3} "
+ message $command
+ $command
+
+ check_requirements "${2}"
+}
+
+function git_check {
+ #$1 is url #$2 is dir #$3 is tag or branch if optional
+ command="git clone --quiet ${1} ${2}"
+ message "$command"
+ $command
+ if [ -n "$3" ]; then
+ cd ${2}
+ command="git checkout ${3}"
+ message "$command"
+ $command
+ cd ${BASE_DIR}
+ fi
+ check_requirements "${2}"
+}
+
+
+if [ ! -d ${CROSS_COMPILER_DIR} ]; then
+ #out file
+ message "download CROSS_COMPILER"
+ download_extract ${CROSS_COMPILER_URL} ${CROSS_COMPILER_DIR} ${XTRACT_STRIP}
+fi
+
+#useful exports
+export PI_FOLDER=${CROSS_COMPILER_DIR}
+export RPI_BIN=${PI_FOLDER}/bin/arm-linux-gnueabihf
+export PI_SYS_ROOT=${PI_FOLDER}/arm-linux-gnueabihf/libc
+export LD_LIBRARY_PATH=${PI_FOLDER}/lib:$LD_LIBRARY_PATH
+export CC=${RPI_BIN}-gcc
+export FC=${RPI_BIN}-gfortran
+export CXX=${RPI_BIN}-g++
+export CPP=${RPI_BIN}-cpp
+export RANLIB=${RPI_BIN}-gcc-ranlib
+export LD="${RPI_BIN}-ld"
+export AR="${RPI_BIN}-ar"
+
+
+#lets build OpenBlas
+if [ ! -d "${OPENBLAS_DIR}" ]; then
+ message "download OpenBLAS"
+ git_check "${OPENBLAS_GIT_URL}" "${OPENBLAS_DIR}"
+fi
+
+if [ ! -f "${THIRD_PARTY}/lib/libopenblas.so" ]; then
+ message "build and install OpenBLAS"
+ cd ${OPENBLAS_DIR}
+
+ command="make TARGET=${BLAS_TARGET_NAME} HOSTCC=gcc CC=${CC} USE_THREAD=0 NOFORTRAN=1 CFLAGS=--sysroot=${PI_SYS_ROOT} LDFLAGS=\"-L${PI_SYS_ROOT}/../lib/ -lm\" &>/dev/null"
+ message $command
+ eval $command
+ message "install it"
+ command="make PREFIX=${THIRD_PARTY} install"
+ message $command
+ $command
+ cd $BASE_DIR
+
+fi
+check_requirements ${THIRD_PARTY}/lib/libopenblas.so
+
+
+
+if [ ! -d ${SCONS_LOCAL_DIR} ]; then
+ #out file
+ message "download Scons local"
+ download_extract ${SCONS_LOCAL_URL} ${SCONS_LOCAL_DIR}
+fi
+check_requirements ${SCONS_LOCAL_DIR}/scons.py
+
+
+if [ ! -d "${ARMCOMPUTE_DIR}" ]; then
+ message "download ArmCompute Source"
+ git_check ${ARMCOMPUTE_GIT_URL} "${ARMCOMPUTE_DIR}" "tags/${ARMCOMPUTE_TAG}"
+fi
+
+#build armcompute
+if [ ! -f "${ARMCOMPUTE_DIR}/build/libarm_compute-static.a" ]; then
+message "build arm compute"
+cd ${ARMCOMPUTE_DIR}
+command="CC=gcc CXX=g++ python3 ${SCONS_LOCAL_DIR}/scons.py Werror=1 -j$(nproc) toolchain_prefix=${RPI_BIN}- debug=${ARMCOMPUTE_DEBUG} neon=1 opencl=0 extra_cxx_flags=-fPIC os=linux build=cross_compile arch=${ARMCOMPUTE_TARGET} &>/dev/null"
+message $command
+eval $command
+cd ${BASE_DIR}
+fi
+check_requirements "${ARMCOMPUTE_DIR}/build/libarm_compute-static.a" "${ARMCOMPUTE_DIR}/build/libarm_compute_core-static.a"
+
+
+
+message "build cmake for LIBND4J. output: ${LIBND4J_BUILD_DIR}"
+
+TOOLCHAIN=${LIBND4J_SRC_DIR}/cmake/rpi.cmake
+cmake_cmd="${CMAKE} -G \"Unix Makefiles\" -B${LIBND4J_BUILD_DIR} -S${LIBND4J_SRC_DIR} -DCMAKE_BUILD_TYPE=${LIBND4J_BUILD_MODE} -DCMAKE_TOOLCHAIN_FILE=${TOOLCHAIN} -DCMAKE_VERBOSE_MAKEFILE:BOOL=ON -DSD_ALL_OPS=true -DSD_CPU=true -DSD_LIBRARY_NAME=nd4jcpu -DSD_BUILD_TESTS=ON -DSD_ARM_BUILD=true -DOPENBLAS_PATH=${THIRD_PARTY} -DSD_ARCH=${TARGET} -DARMCOMPUTE_ROOT=${ARMCOMPUTE_DIR} -DHELPERS_armcompute=${HAS_ARMCOMPUTE}"
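+#-DHELPERS_armcompute=1 turns on the ARM Compute platform kernels added in this
+#change; -DOPENBLAS_PATH points the build at the OpenBLAS installed above.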
+message $cmake_cmd
+eval $cmake_cmd
+
+#build
+message "lets build"
+
+cd ${LIBND4J_BUILD_DIR}
+make -j $(nproc)
+
+
+
+
+
+
diff --git a/libnd4j/tests_cpu/layers_tests/CMakeLists.txt b/libnd4j/tests_cpu/layers_tests/CMakeLists.txt
index 563bf58f6..9478f6fe2 100644
--- a/libnd4j/tests_cpu/layers_tests/CMakeLists.txt
+++ b/libnd4j/tests_cpu/layers_tests/CMakeLists.txt
@@ -52,14 +52,19 @@ elseif(WIN32)
set(CMAKE_CXX_FLAGS " -fPIC")
endif()
else()
- set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3")
set(CMAKE_CXX_FLAGS " -fPIC")
+ set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3")
+ IF(${SD_ARCH} MATCHES "arm*")
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=${SD_ARCH}")
+ else()
+ set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3")
+
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64*")
set(CMAKE_CXX_FLAGS " ${CMAKE_CXX_FLAGS} -mcpu=native")
else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native -mtune=native")
endif()
-
+ endif()
if (SD_CPU AND SD_SANITIZE)
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -fsanitize=address")
else()
@@ -130,7 +135,7 @@ if (SD_CPU)
endif()
add_executable(runtests ${TEST_SOURCES})
- target_link_libraries(runtests samediff_obj ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES} ${CPU_FEATURES} gtest gtest_main)
+ target_link_libraries(runtests samediff_obj ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES} ${CPU_FEATURES} ${ARMCOMPUTE_LIBRARIES} gtest gtest_main)
elseif(SD_CUDA)
add_executable(runtests ${TEST_SOURCES})
diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp
index 169c51124..39277cd87 100644
--- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp
+++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp
@@ -1113,7 +1113,10 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_6) {
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result.at(0);
-
+#if 0
+ exp.printIndexedBuffer("Expected");
+ z->printIndexedBuffer("Z");
+#endif
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
@@ -1132,7 +1135,10 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_7) {
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result.at(0);
-
+#if 0
+ exp.printIndexedBuffer("Expected");
+ z->printIndexedBuffer("Z");
+#endif
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
@@ -1151,7 +1157,10 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_8) {
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result.at(0);
-
+#if 0
+ exp.printIndexedBuffer("Expected");
+ z->printIndexedBuffer("Z");
+#endif
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
@@ -1204,7 +1213,10 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_10) {
auto* output = results.at(0);
ASSERT_EQ(Status::OK(), results.status());
-
+#if 0
+ expOutput.printIndexedBuffer("expOutput");
+ output->printIndexedBuffer("output");
+#endif
ASSERT_TRUE(expOutput.isSameShape(output));
ASSERT_TRUE(expOutput.equalsTo(output));
}
diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp
index 5f1aefe36..beccc1aae 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp
@@ -244,7 +244,8 @@ TEST_F(DeclarableOpsTests19, test_threshold_encode_decode) {
#ifdef _RELEASE
TEST_F(DeclarableOpsTests19, test_threshold_encode_decode_2) {
// [2,1,135079944,1,1,8192,1,99]
- auto initial = NDArrayFactory::create<float>('c', {1, 135079944});
+ constexpr int sizeX = 10 * 1000 * 1000;
+ auto initial = NDArrayFactory::create<float>('c', {1, sizeX});
initial = 1.0f;
auto exp = initial.dup();
auto neg = initial.like();
@@ -254,7 +255,7 @@ TEST_F(DeclarableOpsTests19, test_threshold_encode_decode_2) {
auto enc_result = enc.evaluate({&initial}, {0.5f});
auto encoded = enc_result.at(1);
- ASSERT_EQ(135079944 + 4, encoded->lengthOf());
+ ASSERT_EQ(sizeX + 4, encoded->lengthOf());
ASSERT_NE(exp, initial);
/*
for (int e = 0; e < initial.lengthOf(); e++) {
@@ -419,3 +420,4 @@ TEST_F(DeclarableOpsTests19, test_squeeze_1) {
auto status = op.execute({&x}, {&e}, {axis});
ASSERT_EQ(Status::OK(), status);
}
+
diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp
index 04bb54a61..c68392da1 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp
@@ -1557,8 +1557,6 @@ TEST_F(DeclarableOpsTests5, trace_test1) {
// exp.printIndexedBuffer("EXP TRACE");
// output->printIndexedBuffer("OUT TRACE");
ASSERT_TRUE(exp.equalsTo(output));
-
-
}
//////////////////////////////////////////////////////////////////////
@@ -1575,8 +1573,6 @@ TEST_F(DeclarableOpsTests5, trace_test2) {
ASSERT_EQ(Status::OK(), results.status());
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
-
-
}
//////////////////////////////////////////////////////////////////////
@@ -1593,8 +1589,6 @@ TEST_F(DeclarableOpsTests5, trace_test3) {
ASSERT_EQ(Status::OK(), results.status());
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
-
-
}
//////////////////////////////////////////////////////////////////////
@@ -1611,8 +1605,6 @@ TEST_F(DeclarableOpsTests5, trace_test4) {
ASSERT_EQ(Status::OK(), results.status());
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
-
-
}
//////////////////////////////////////////////////////////////////////
@@ -1629,8 +1621,6 @@ TEST_F(DeclarableOpsTests5, trace_test5) {
ASSERT_EQ(Status::OK(), results.status());
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
-
-
}
//////////////////////////////////////////////////////////////////////
@@ -1638,22 +1628,15 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test1) {
auto input = NDArrayFactory::create<double>('c', {2, 2, 2});
input.linspace(1);
+ NDArray exp1 = input.dup();
+ NDArray exp2('c',{2,2,2}, {5,6,7,8, 1,2,3,4}, sd::DataType::DOUBLE);
sd::ops::random_shuffle op;
auto results = op.evaluate({&input});
auto output = results.at(0);
- bool haveZeros = false;
- for(int i = 0; i < output->lengthOf(); ++i)
- if(output->e(i) == (float)0.)
- haveZeros = true;
-
ASSERT_EQ(Status::OK(), results.status());
- ASSERT_TRUE(input.isSameShape(output));
- ASSERT_TRUE(!input.equalsTo(output));
- ASSERT_TRUE(!haveZeros);
-
-
+ ASSERT_TRUE(output->equalsTo(exp1) || output->equalsTo(exp2));
}
//////////////////////////////////////////////////////////////////////
@@ -1661,16 +1644,14 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test2) {
auto input = NDArrayFactory::create<double>('c', {1, 3, 2});
input.linspace(1);
+ NDArray exp1 = input.dup();
sd::ops::random_shuffle op;
auto results = op.evaluate({&input});
auto output = results.at(0);
ASSERT_EQ(Status::OK(), results.status());
- ASSERT_TRUE(input.isSameShape(output));
- ASSERT_TRUE(input.equalsTo(output));
-
-
+ ASSERT_TRUE(output->equalsTo(exp1));
}
//////////////////////////////////////////////////////////////////////
@@ -1678,129 +1659,132 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test3) {
auto input = NDArrayFactory::create<double>('c', {3, 2, 1});
input.linspace(1);
+ NDArray exp1 = input.dup();
+ NDArray exp2('c',{3,2,1}, {1,2, 5,6, 3,4}, sd::DataType::DOUBLE);
+ NDArray exp3('c',{3,2,1}, {3,4, 1,2, 5,6}, sd::DataType::DOUBLE);
+ NDArray exp4('c',{3,2,1}, {3,4, 5,6, 1,2}, sd::DataType::DOUBLE);
+ NDArray exp5('c',{3,2,1}, {5,6, 1,2, 3,4}, sd::DataType::DOUBLE);
+ NDArray exp6('c',{3,2,1}, {5,6, 3,4, 1,2}, sd::DataType::DOUBLE);
sd::ops::random_shuffle op;
- auto results = op.evaluate({&input});
- auto output = results.at(0);
-
- bool haveZeros = false;
- for(int i = 0; i < output->lengthOf(); ++i)
- if(output->e(i) == (float)0.)
- haveZeros = true;
-
- ASSERT_EQ(Status::OK(), results.status());
- ASSERT_TRUE(input.isSameShape(output));
- ASSERT_TRUE(!input.equalsTo(output));
- ASSERT_TRUE(!haveZeros);
-
-
-}
-//////////////////////////////////////////////////////////////////////
-TEST_F(DeclarableOpsTests5, random_shuffle_test04) {
- auto input = NDArrayFactory::create('c', {4});
- input.linspace(1);
-
- sd::ops::random_shuffle op;
- //NDArray* output;
auto results = op.evaluate({&input}, {}, {}, {}, {}, true);
+
ASSERT_EQ(Status::OK(), results.status());
- auto output = &input; //results.at(0);
- bool haveZeros = false;
- for(int i = 0; i < output->lengthOf(); ++i)
- if(output->e(i) == (float)0.)
- haveZeros = true;
-
- ASSERT_TRUE(input.isSameShape(output));
- //ASSERT_TRUE(!input.equalsTo(output));
- ASSERT_TRUE(!haveZeros);
-
-
+ ASSERT_TRUE(input.equalsTo(exp1) || input.equalsTo(exp2) || input.equalsTo(exp3)
+ || input.equalsTo(exp4) || input.equalsTo(exp5) || input.equalsTo(exp6));
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, random_shuffle_test4) {
- auto input = NDArrayFactory::create('c', {4});
+
+ auto input = NDArrayFactory::create<double>('c', {3, 2, 1});
input.linspace(1);
+ NDArray exp1 = input.dup();
+ NDArray exp2('c',{3,2,1}, {1,2, 5,6, 3,4}, sd::DataType::DOUBLE);
+ NDArray exp3('c',{3,2,1}, {3,4, 1,2, 5,6}, sd::DataType::DOUBLE);
+ NDArray exp4('c',{3,2,1}, {3,4, 5,6, 1,2}, sd::DataType::DOUBLE);
+ NDArray exp5('c',{3,2,1}, {5,6, 1,2, 3,4}, sd::DataType::DOUBLE);
+ NDArray exp6('c',{3,2,1}, {5,6, 3,4, 1,2}, sd::DataType::DOUBLE);
sd::ops::random_shuffle op;
- //NDArray* output;
auto results = op.evaluate({&input});
- ASSERT_EQ(Status::OK(), results.status());
auto output = results.at(0);
- bool haveZeros = false;
- for(int i = 0; i < output->lengthOf(); ++i)
- if(output->e(i) == (float)0.)
- haveZeros = true;
-
- ASSERT_TRUE(input.isSameShape(output));
- //ASSERT_TRUE(!input.equalsTo(output));
- ASSERT_TRUE(!haveZeros);
-
+ ASSERT_EQ(Status::OK(), results.status());
+ ASSERT_TRUE(output->equalsTo(exp1) || output->equalsTo(exp2) || output->equalsTo(exp3)
+ || output->equalsTo(exp4) || output->equalsTo(exp5) || output->equalsTo(exp6));
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, random_shuffle_test5) {
-
- auto input = NDArrayFactory::create('c', {4,1});
+ auto input = NDArrayFactory::create<int>('c', {4});
input.linspace(1);
sd::ops::random_shuffle op;
- auto results = op.evaluate({&input});
+ auto results = op.evaluate({&input}, {}, {}, {}, {}, false);
auto output = results.at(0);
-
- bool haveZeros = false;
- for(int i = 0; i < output->lengthOf(); ++i)
- if(output->e(i) == (float)0.)
- haveZeros = true;
+ // output->printBuffer();
ASSERT_EQ(Status::OK(), results.status());
- ASSERT_TRUE(input.isSameShape(output));
- ASSERT_TRUE(!input.equalsTo(output));
- ASSERT_TRUE(!haveZeros);
-
+ // ASSERT_TRUE(!output->equalsTo(input));
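+ // a shuffle of distinct linspace values must never introduce repeats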
+ bool hasDuplicates = false;
+ for(int i = 0; i < output->lengthOf() - 1; ++i)
+ for(int j = i+1; j < output->lengthOf(); ++j)
+ if(output->t<int>(i) == output->t<int>(j)) {
+ hasDuplicates = true;
+ i = output->lengthOf();
+ break;
+ }
+ ASSERT_TRUE(!hasDuplicates);
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, random_shuffle_test6) {
-
- auto input = NDArrayFactory::create('c', {4,1,1});
+ auto input = NDArrayFactory::create<int>('c', {4,1,1});
input.linspace(1);
sd::ops::random_shuffle op;
- auto results = op.evaluate({&input});
+ auto results = op.evaluate({&input}, {}, {}, {}, {}, false);
auto output = results.at(0);
- bool haveZeros = false;
- for(int i = 0; i < output->lengthOf(); ++i)
- if(output->e(i) == (float)0.)
- haveZeros = true;
-
ASSERT_EQ(Status::OK(), results.status());
- ASSERT_TRUE(input.isSameShape(output));
- ASSERT_TRUE(!input.equalsTo(output));
- ASSERT_TRUE(!haveZeros);
-
+ // ASSERT_TRUE(!output->equalsTo(input));
+ bool hasDuplicates = false;
+ for(int i = 0; i < output->lengthOf() - 1; ++i)
+ for(int j = i+1; j < output->lengthOf(); ++j)
+ if(output->t<int>(i) == output->t<int>(j)) {
+ hasDuplicates = true;
+ i = output->lengthOf();
+ break;
+ }
+ ASSERT_TRUE(!hasDuplicates);
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, random_shuffle_test7) {
-
- auto input = NDArrayFactory::create('c', {1,4});
+ auto input = NDArrayFactory::create<int>('c', {16010});
input.linspace(1);
- auto exp = NDArrayFactory::create('c', {1,4}, {1, 2, 3, 4});
sd::ops::random_shuffle op;
- auto results = op.evaluate({&input});
+ auto results = op.evaluate({&input}, {}, {}, {}, {}, false);
auto output = results.at(0);
-
+ // output->printBuffer();
ASSERT_EQ(Status::OK(), results.status());
- ASSERT_TRUE(input.isSameShape(output));
- ASSERT_TRUE(input.equalsTo(output));
+ ASSERT_TRUE(!output->equalsTo(input));
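+ // sorting the shuffled buffer must reproduce the linspace input exactly,
+ // i.e. the output is a permutation with nothing lost or duplicated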
+ auto vec1 = input.getBufferAsVector<int>();
+ auto vec2 = output->getBufferAsVector<int>();
+ std::sort(vec2.begin(), vec2.end());
+ ASSERT_TRUE(std::equal(vec1.begin(), vec1.end(), vec2.begin()));
+}
+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests5, random_shuffle_test8) {
+ auto input = NDArrayFactory::create<int>('c', {1,4,1});
+ input.linspace(1);
+ NDArray inCopy = input.dup();
+
+ sd::ops::random_shuffle op;
+ auto results = op.evaluate({&input}, {}, {}, {}, {}, false);
+ ASSERT_EQ(Status::OK(), results.status());
+ ASSERT_TRUE(input.equalsTo(inCopy));
+
+}
+
+TEST_F(DeclarableOpsTests5, random_shuffle_test9) {
+
+ auto x = NDArrayFactory::create<int>('c', {4}, {1, 2, 3, 4});
+ auto z = x.ulike();
+
+ sd::ops::random_shuffle op;
+ auto status = op.execute({&x}, {&z});
+ ASSERT_EQ(Status::OK(), status);
+
+ auto vec = z.getBufferAsVector<int>();
+ std::sort(vec.begin(), vec.end());
+ ASSERT_EQ(std::vector<int>({1, 2, 3, 4}), vec);
}
////////////////////////////////////////////////////////////////////////////////////////
diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp
index 949b43d25..f2bd393e4 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp
@@ -251,11 +251,10 @@ TEST_F(DeclarableOpsTests9, concat_test1) {
auto result = op.evaluate({&x0, &x1, &x2}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto output = result.at(0);
+ // output->printCurrentBuffer(false);
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
-
-
}
////////////////////////////////////////////////////////////////////////////////
diff --git a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp
index a8f45cc48..e07a0496d 100644
--- a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp
+++ b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp
@@ -317,7 +317,7 @@ void fill_random(sd::NDArray& arr) {
}
}
-
+
void testLegacy(bool random) {
#if 0
int bases[] = { 3, 2, 4, 5, 7 };
@@ -364,7 +364,7 @@ int k = 4;
#endif
auto dim = NDArrayFactory::create(dimension);
-#if 1
+#if 1
nd4j_printf("C(N:%d K:%d) \n", N, k);
dim.printIndexedBuffer("Dimension");
for (int xind : dimension) {
@@ -385,7 +385,7 @@ for (int e = 0; e < Loop; e++) {
auto outerTime = std::chrono::duration_cast<std::chrono::microseconds>(timeEnd - timeStart).count();
values.emplace_back(outerTime);
}
-
+
std::sort(values.begin(), values.end());
nd4j_printf("Time: %lld us;\n", values[values.size() / 2]);
@@ -411,7 +411,7 @@ void testNewReduction(bool random, bool checkCorrectness = false , char order ='
constexpr int N = 5;
#endif
-
+
for (int i = 0; i < N; i++) {
arr_dimensions.push_back(bases[i]);
}
@@ -451,7 +451,7 @@ void testNewReduction(bool random, bool checkCorrectness = false , char order ='
#endif
auto dim = NDArrayFactory::create(dimension);
-#if 1
+#if 1
nd4j_printf("C(N:%d K:%d) \n", N, k);
dim.printIndexedBuffer("Dimension");
for (int xind : dimension) {
@@ -477,14 +477,14 @@ void testNewReduction(bool random, bool checkCorrectness = false , char order ='
//check for the correctness
NDArray exp = output_bases.size() > 0 ? NDArrayFactory::create('c', output_bases) : NDArrayFactory::create(0);
original_argmax(x, dimension, exp);
-
+
#if 0// defined(DEBUG)
x.printIndexedBuffer("X");
exp.printIndexedBuffer("Expected");
z->printIndexedBuffer("Z");
#endif
-
+
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
@@ -505,7 +505,7 @@ TEST_F(PlaygroundTests, ArgMaxPerfLinspace) {
testNewReduction(false, test_corr);
}
#endif
-
+
TEST_F(PlaygroundTests, ArgMaxPerfRandom) {
testNewReduction(true, test_corr);
}
@@ -513,7 +513,7 @@ TEST_F(PlaygroundTests, ArgMaxPerfRandom) {
TEST_F(PlaygroundTests, ArgMaxPerfRandomOrderF) {
testNewReduction(true, test_corr, 'f');
}
-
+
#if !defined(DEBUG)
TEST_F(PlaygroundTests, ArgMaxPerfLegacyLinspace) {
testLegacy(false);
@@ -1062,39 +1062,6 @@ TEST_F(PlaygroundTests, my) {
delete variableSpace;
}
-TEST_F(PlaygroundTests, my) {
-
- int N = 100;
- int bS=16, iH=128,iW=128, iC=32,oC=64, kH=4,kW=4, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
- int oH=128,oW=128;
-
- int paddingMode = 1; // 1-SAME, 0-VALID;
- int dataFormat = 1; // 1-NHWC, 0-NCHW
-
- // NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32);
- // NDArray output('c', {bS, oC, oH, oW}, sd::DataType::FLOAT32);
- NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32);
- NDArray output('c', {bS, oH, oW, oC}, sd::DataType::FLOAT32);
- // NDArray weights('c', {kH, kW, iC, oC}, sd::DataType::FLOAT32); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW]
- NDArray weights('c', {oC, iC, kH, kW}, sd::DataType::FLOAT32);
- NDArray bias('c', {oC}, sd::DataType::FLOAT32);
-
- input = 5.;
- weights = 3.;
- bias = 1.;
-
- sd::ops::conv2d op;
- auto err = op.execute({&input, &weights, &bias}, {&output}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
-
- auto timeStart = std::chrono::system_clock::now();
- for (int i = 0; i < N; ++i)
- err = op.execute({&input, &weights, &bias}, {&output}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
- auto timeEnd = std::chrono::system_clock::now();
- auto time = std::chrono::duration_cast ((timeEnd - timeStart) / N).count();
-
- printf("time: %i \n", time);
-}
-
///////////////////////////////////////////////////////////////////
TEST_F(PlaygroundTests, lstmLayerCellBp_1) {
@@ -1690,6 +1657,52 @@ TEST_F(DeclarableOpsTests15, gru_bp_1) {
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
}
+#include <ops/declarable/helpers/transforms.h>
+//////////////////////////////////////////////////////////////////////
+TEST_F(PlaygroundTests, my) {
+
+ const int N = 10;
+
+ NDArray input('c', {8000000}, sd::DataType::INT32);
+ input.linspace(1);
+ NDArray output = input.dup();
+
+
+ sd::graph::RandomGenerator rng;
+
+ sd::ops::helpers::randomShuffle(input.getContext(), input, output, rng, true);
+
+ // auto timeStart = std::chrono::system_clock::now();
+ // for (int i = 0; i < N; ++i)
+ // sd::ops::helpers::randomShuffle(input.getContext(), input, output, rng, true);
+ // auto timeEnd = std::chrono::system_clock::now();
+ // auto time = std::chrono::duration_cast ((timeEnd - timeStart) / N).count();
+ // printf("time: %i \n", time);
+
+ // bool hasDuplicates = false;
+ // for(int i = 0; i < output.lengthOf() - 1; ++i)
+ // for(int j = i+1; j < output.lengthOf(); ++j)
+ // if(output.t<int>(i) == output.t<int>(j)) {
+ // hasDuplicates = true;
+ // i = output.lengthOf();
+ // break;
+ // }
+
+ ASSERT_TRUE(!input.equalsTo(output));
+
+ bool hasDuplicates = false;
+ for(int i = 0; i < input.lengthOf() - 1; ++i)
+ for(int j = i+1; j < input.lengthOf(); ++j)
+ if(input.t<int>(i) == input.t<int>(j)) {
+ hasDuplicates = true;
+ i = input.lengthOf();
+ break;
+ }
+ ASSERT_TRUE(!hasDuplicates);
+}
+
+
+}
+
*/
-
diff --git a/libnd4j/tests_cpu/layers_tests/SessionLocalTests.cpp b/libnd4j/tests_cpu/layers_tests/SessionLocalTests.cpp
deleted file mode 100644
index 8481dfde5..000000000
--- a/libnd4j/tests_cpu/layers_tests/SessionLocalTests.cpp
+++ /dev/null
@@ -1,93 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-//
-// @author raver119@gmail.com
-//
-
-#ifndef LIBND4J_SESSIONLOCALTESTS_H
-#define LIBND4J_SESSIONLOCALTESTS_H
-
-#include "testlayers.h"
-#include <graph/SessionLocalStorage.h>
-#include <omp.h>
-
-using namespace sd::graph;
-
-class SessionLocalTests : public testing::Test {
-public:
-
-};
-
-TEST_F(SessionLocalTests, BasicTests_1) {
- VariableSpace variableSpace;
- SessionLocalStorage storage(&variableSpace, nullptr);
-
- if (omp_get_max_threads() <= 1)
- return;
-
- PRAGMA_OMP_PARALLEL_FOR_THREADS(4)
- for (int e = 0; e < 4; e++) {
- storage.startSession();
- }
-
- ASSERT_EQ(4, storage.numberOfSessions());
-
- PRAGMA_OMP_PARALLEL_FOR_THREADS(4)
- for (int e = 0; e < 4; e++) {
- storage.endSession();
- }
-
- ASSERT_EQ(0, storage.numberOfSessions());
-}
-
-
-TEST_F(SessionLocalTests, BasicTests_2) {
- VariableSpace variableSpace;
- SessionLocalStorage storage(&variableSpace, nullptr);
-
- if (omp_get_max_threads() <= 1)
- return;
-
- auto alpha = sd::NDArrayFactory::create_('c',{5,5});
- alpha->assign(0.0);
-
- variableSpace.putVariable(-1, alpha);
-
- PRAGMA_OMP_PARALLEL_FOR_THREADS(4)
- for (int e = 0; e < 4; e++) {
- storage.startSession();
-
- auto varSpace = storage.localVariableSpace();
-
- auto arr = varSpace->getVariable(-1)->getNDArray();
- arr->applyScalar(sd::scalar::Add, (float) e+1, *arr);
- }
-
- float lastValue = 0.0f;
- for (int e = 1; e <= 4; e++) {
- auto varSpace = storage.localVariableSpace((Nd4jLong) e);
-
- auto arr = varSpace->getVariable(-1)->getNDArray();
-
- //nd4j_printf("Last value: %f; Current value: %f\n", lastValue, arr->e(0));
-
- ASSERT_NE(lastValue, arr->e(0));
- lastValue = arr->e(0);
- }
-}
-
-#endif //LIBND4J_SESSIONLOCALTESTS_H
diff --git a/libnd4j/tests_cpu/libnd4j_tests/CMakeLists.txt b/libnd4j/tests_cpu/libnd4j_tests/CMakeLists.txt
index 7e01e2847..bbd632d27 100644
--- a/libnd4j/tests_cpu/libnd4j_tests/CMakeLists.txt
+++ b/libnd4j/tests_cpu/libnd4j_tests/CMakeLists.txt
@@ -45,6 +45,21 @@ if ("${BUILD_MKLDNN}")
set(MKLDNN dnnl)
endif()
+if (${HELPERS_armcompute})
+ find_package(ARMCOMPUTE REQUIRED)
+
+ if(ARMCOMPUTE_FOUND)
+ message("Found ARMCOMPUTE: ${ARMCOMPUTE_LIBRARIES}")
+ set(HAVE_ARMCOMPUTE 1)
+ # Add preprocessor definition for ARM Compute NEON
+ add_definitions(-DARMCOMPUTENEON_ENABLED)
+ #build our library with neon support
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mfpu=neon")
+ include_directories(${ARMCOMPUTE_INCLUDE})
+ endif()
+
+endif()
+
# Download and unpack flatbuffers at configure time
configure_file(../../CMakeLists.txt.in flatbuffers-download/CMakeLists.txt)
execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" .
@@ -217,6 +232,10 @@ if ("${BUILD_MKLDNN}")
file(GLOB_RECURSE CUSTOMOPS_PLATFORM_SOURCES false ../../include/ops/declarable/platform/mkldnn/*.cpp)
endif()
+if(HAVE_ARMCOMPUTE)
+ file(GLOB_RECURSE CUSTOMOPS_ARMCOMPUTE_SOURCES false ../../include/ops/declarable/platform/armcompute/*.cpp ../../include/ops/declarable/platform/armcompute/armcomputeUtils.h)
+endif()
+
message("CPU backend")
add_definitions(-D__CPUBLAS__=true)
@@ -276,8 +295,9 @@ endforeach(TMP_PATH)
add_executable(runtests ${LOOPS_SOURCES} ${LEGACY_SOURCES} ${EXEC_SOURCES} ${HELPERS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
- ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_PLATFORM_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES}
+ ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_PLATFORM_SOURCES}
+ ${CUSTOMOPS_ARMCOMPUTE_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES}
${OPS_SOURCES} ${TEST_SOURCES} ${PERF_SOURCES})
-target_link_libraries(runtests gtest ${MKLDNN} gtest_main ${BLAS_LIBRARIES})
+target_link_libraries(runtests gtest ${MKLDNN} ${ARMCOMPUTE_LIBRARIES} gtest_main ${BLAS_LIBRARIES})
diff --git a/libnd4j/tests_cpu/resources/simpleif_0_alt.fb b/libnd4j/tests_cpu/resources/simpleif_0_alt.fb
new file mode 100644
index 000000000..4a7e751c3
Binary files /dev/null and b/libnd4j/tests_cpu/resources/simpleif_0_alt.fb differ
diff --git a/libnd4j/tests_cpu/resources/simplewhile_1.fb b/libnd4j/tests_cpu/resources/simplewhile_1.fb
new file mode 100644
index 000000000..c4fa26e2a
Binary files /dev/null and b/libnd4j/tests_cpu/resources/simplewhile_1.fb differ
diff --git a/libnd4j/tests_cpu/resources/simplewhile_nested.fb b/libnd4j/tests_cpu/resources/simplewhile_nested.fb
new file mode 100644
index 000000000..9404b98c4
Binary files /dev/null and b/libnd4j/tests_cpu/resources/simplewhile_nested.fb differ
diff --git a/libnd4j/tests_cpu/resources/while_iter3.fb b/libnd4j/tests_cpu/resources/while_iter3.fb
new file mode 100644
index 000000000..4b0e86979
Binary files /dev/null and b/libnd4j/tests_cpu/resources/while_iter3.fb differ
diff --git a/python4j/pom.xml b/python4j/pom.xml
index 57af8f1bb..3f1d026a5 100644
--- a/python4j/pom.xml
+++ b/python4j/pom.xml
@@ -25,7 +25,7 @@
    <modelVersion>4.0.0</modelVersion>
-   <groupId>org.eclipse</groupId>
+   <groupId>org.nd4j</groupId>
    <artifactId>python4j-parent</artifactId>
    <packaging>pom</packaging>
@@ -41,10 +41,14 @@
            <scope>provided</scope>
        </dependency>
+       <dependency>
+           <groupId>org.slf4j</groupId>
+           <artifactId>slf4j-api</artifactId>
+           <version>1.6.6</version>
+       </dependency>
        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
            <version>${logback.version}</version>
-           <scope>test</scope>
+           <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>junit</groupId>
@@ -62,5 +66,10 @@
            <artifactId>jsr305</artifactId>
            <version>3.0.2</version>
        </dependency>
+       <dependency>
+           <groupId>org.slf4j</groupId>
+           <artifactId>slf4j-api</artifactId>
+           <version>1.6.6</version>
+       </dependency>
    </dependencies>
\ No newline at end of file
diff --git a/python4j/python4j-core/pom.xml b/python4j/python4j-core/pom.xml
index b429d8272..26e77b8d1 100644
--- a/python4j/python4j-core/pom.xml
+++ b/python4j/python4j-core/pom.xml
@@ -21,7 +21,7 @@
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <parent>
        <artifactId>python4j-parent</artifactId>
-       <groupId>org.eclipse</groupId>
+       <groupId>org.nd4j</groupId>
        <version>1.0.0-SNAPSHOT</version>
    </parent>
    <packaging>jar</packaging>
@@ -39,6 +39,5 @@
            <artifactId>cpython-platform</artifactId>
            <version>${cpython-platform.version}</version>
-
\ No newline at end of file
diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/Python.java b/python4j/python4j-core/src/main/java/org/nd4j/python4j/Python.java
similarity index 99%
rename from python4j/python4j-core/src/main/java/org/eclipse/python4j/Python.java
rename to python4j/python4j-core/src/main/java/org/nd4j/python4j/Python.java
index fd6fff112..03c2fdaab 100644
--- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/Python.java
+++ b/python4j/python4j-core/src/main/java/org/nd4j/python4j/Python.java
@@ -15,7 +15,7 @@
******************************************************************************/
-package org.eclipse.python4j;
+package org.nd4j.python4j;
import org.bytedeco.cpython.PyObject;
diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonContextManager.java b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonContextManager.java
similarity index 86%
rename from python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonContextManager.java
rename to python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonContextManager.java
index a34d8a239..0090e38d4 100644
--- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonContextManager.java
+++ b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonContextManager.java
@@ -14,13 +14,15 @@
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
-package org.eclipse.python4j;
+package org.nd4j.python4j;
import javax.lang.model.SourceVersion;
+import java.io.Closeable;
import java.util.HashSet;
import java.util.Set;
+import java.util.UUID;
import java.util.concurrent.atomic.AtomicBoolean;
/**
@@ -46,6 +48,31 @@ public class PythonContextManager {
init();
}
+
+ public static class Context implements Closeable{
+ private final String name;
+ private final String previous;
+ private final boolean temp;
+ public Context(){
+ name = "temp_" + UUID.randomUUID().toString().replace("-", "_");
+ temp = true;
+ previous = getCurrentContext();
+ setContext(name);
+ }
+ public Context(String name){
+ this.name = name;
+ temp = false;
+ previous = getCurrentContext();
+ setContext(name);
+ }
+
+ @Override
+ public void close(){
+ setContext(previous);
+ if (temp) deleteContext(name);
+ }
+ }
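+ // Usage sketch: try (Context ctx = new Context()) { ... } switches to a fresh
+ // temporary context and, on close, restores the previous one and deletes the temp.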
+
private static void init() {
if (init.get()) return;
new PythonExecutioner();
@@ -76,7 +103,18 @@ public class PythonContextManager {
}
private static boolean validateContextName(String s) {
- return SourceVersion.isIdentifier(s) && !s.startsWith(COLLAPSED_KEY);
+ for (int i = 0; i < s.length(); i++) {
+ char c = s.toLowerCase().charAt(i);
+ if (i == 0 && c >= '0' && c <= '9') {
+ return false;
+ }
+ if (!(c == '_' || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9'))) {
+ return false;
+ }
+ }
+ return true;
}
private static String getContextPrefix(String contextName) {
@@ -190,6 +228,7 @@ public class PythonContextManager {
setContext(tempContext);
deleteContext(currContext);
setContext(currContext);
+ deleteContext(tempContext);
}
/**
diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonException.java b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonException.java
similarity index 98%
rename from python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonException.java
rename to python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonException.java
index a9bbf596c..e8f64f2be 100644
--- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonException.java
+++ b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonException.java
@@ -14,7 +14,7 @@
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
-package org.eclipse.python4j;
+package org.nd4j.python4j;
/**
diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonExecutioner.java b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonExecutioner.java
similarity index 90%
rename from python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonExecutioner.java
rename to python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonExecutioner.java
index 57e1a22ae..bc48b0e98 100644
--- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonExecutioner.java
+++ b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonExecutioner.java
@@ -15,7 +15,7 @@
******************************************************************************/
-package org.eclipse.python4j;
+package org.nd4j.python4j;
import org.bytedeco.cpython.PyObject;
@@ -42,7 +42,6 @@ public class PythonExecutioner {
private final static String DEFAULT_PYTHON_PATH_PROPERTY = "org.eclipse.python4j.path";
private final static String JAVACPP_PYTHON_APPEND_TYPE = "org.eclipse.python4j.path.append";
private final static String DEFAULT_APPEND_TYPE = "before";
-
static {
init();
}
@@ -55,6 +54,11 @@ public class PythonExecutioner {
initPythonPath();
PyEval_InitThreads();
Py_InitializeEx(0);
+ for (PythonType type: PythonTypes.get()){
+ type.init();
+ }
+ // Constructors of custom types may contain initialization code that should
+ // run on the main thread.
}
/**
@@ -110,6 +114,8 @@ public class PythonExecutioner {
getVariables(Arrays.asList(pyVars));
}
+
+
/**
* Gets the variable with the given name from the interpreter.
*
@@ -205,9 +211,9 @@ public class PythonExecutioner {
*
* @return
*/
- public static List<PythonVariable> getAllVariables() {
+ public static PythonVariables getAllVariables() {
PythonGIL.assertThreadSafe();
- List<PythonVariable> ret = new ArrayList<>();
+ PythonVariables ret = new PythonVariables();
PyObject main = PyImport_ImportModule("__main__");
PyObject globals = PyModule_GetDict(main);
PyObject keys = PyDict_Keys(globals);
@@ -259,7 +265,7 @@ public class PythonExecutioner {
* @param inputs
* @return
*/
- public static List<PythonVariable> execAndReturnAllVariables(String code, List<PythonVariable> inputs) {
+ public static PythonVariables execAndReturnAllVariables(String code, List<PythonVariable> inputs) {
setVariables(inputs);
simpleExec(getWrappedCode(code));
return getAllVariables();
@@ -271,7 +277,7 @@ public class PythonExecutioner {
* @param code
* @return
*/
- public static List<PythonVariable> execAndReturnAllVariables(String code) {
+ public static PythonVariables execAndReturnAllVariables(String code) {
simpleExec(getWrappedCode(code));
return getAllVariables();
}
@@ -279,25 +285,22 @@ public class PythonExecutioner {
private static synchronized void initPythonPath() {
try {
String path = System.getProperty(DEFAULT_PYTHON_PATH_PROPERTY);
+
+ List<File> packagesList = new ArrayList<>();
+ packagesList.addAll(Arrays.asList(cachePackages()));
+ for (PythonType type: PythonTypes.get()){
+ packagesList.addAll(Arrays.asList(type.packages()));
+ }
+ //// TODO: fix in javacpp
+ packagesList.add(new File(python.cachePackage(), "site-packages"));
+
+ File[] packages = packagesList.toArray(new File[0]);
+
if (path == null) {
- File[] packages = cachePackages();
-
- //// TODO: fix in javacpp
- File sitePackagesWindows = new File(python.cachePackage(), "site-packages");
- File[] packages2 = new File[packages.length + 1];
- for (int i = 0; i < packages.length; i++) {
- //System.out.println(packages[i].getAbsolutePath());
- packages2[i] = packages[i];
- }
- packages2[packages.length] = sitePackagesWindows;
- //System.out.println(sitePackagesWindows.getAbsolutePath());
- packages = packages2;
- //////////
-
Py_SetPath(packages);
} else {
StringBuffer sb = new StringBuffer();
- File[] packages = cachePackages();
+
JavaCppPathType pathAppendValue = JavaCppPathType.valueOf(System.getProperty(JAVACPP_PYTHON_APPEND_TYPE, DEFAULT_APPEND_TYPE).toUpperCase());
switch (pathAppendValue) {
case BEFORE:
diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonGC.java b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonGC.java
similarity index 99%
rename from python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonGC.java
rename to python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonGC.java
index 5531b67d3..e18d2072d 100644
--- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonGC.java
+++ b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonGC.java
@@ -15,7 +15,7 @@
******************************************************************************/
-package org.eclipse.python4j;
+package org.nd4j.python4j;
import org.bytedeco.cpython.PyObject;
import org.bytedeco.javacpp.Pointer;
diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonGIL.java b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonGIL.java
similarity index 96%
rename from python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonGIL.java
rename to python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonGIL.java
index 46b3db431..3a88253e0 100644
--- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonGIL.java
+++ b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonGIL.java
@@ -14,11 +14,10 @@
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
-package org.eclipse.python4j;
+package org.nd4j.python4j;
import org.bytedeco.cpython.PyThreadState;
-import org.omg.SendingContext.RunTime;
import java.util.concurrent.atomic.AtomicBoolean;
@@ -90,4 +89,8 @@ public class PythonGIL implements AutoCloseable {
PyEval_SaveThread();
PyEval_RestoreThread(mainThreadState);
}
+
+ public static boolean locked(){
+ return acquired.get();
+ }
}
diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonJob.java b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonJob.java
similarity index 93%
rename from python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonJob.java
rename to python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonJob.java
index cdbb1b81d..f357388f7 100644
--- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonJob.java
+++ b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonJob.java
@@ -14,31 +14,34 @@
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
-package org.eclipse.python4j;
+package org.nd4j.python4j;
import lombok.Builder;
import lombok.Data;
-import lombok.NoArgsConstructor;
+import lombok.extern.slf4j.Slf4j;
import javax.annotation.Nonnull;
import java.util.List;
+import java.util.concurrent.atomic.AtomicBoolean;
-@Data
-@NoArgsConstructor
/**
* PythonJob is the right abstraction for executing multiple python scripts
* in a multi thread stateful environment. The setup-and-run mode allows your
* "setup" code (imports, model loading etc) to be executed only once.
*/
+@Data
+@Slf4j
public class PythonJob {
+
private String code;
private String name;
private String context;
- private boolean setupRunMode;
+ private final boolean setupRunMode;
private PythonObject runF;
+ private final AtomicBoolean setupDone = new AtomicBoolean(false);
static {
new PythonExecutioner();
@@ -63,7 +66,6 @@ public class PythonJob {
if (PythonContextManager.hasContext(context)) {
throw new PythonException("Unable to create python job " + name + ". Context " + context + " already exists!");
}
- if (setupRunMode) setup();
}
@@ -71,17 +73,18 @@ public class PythonJob {
* Clears all variables in current context and calls setup()
*/
public void clearState(){
- String context = this.context;
- PythonContextManager.setContext("main");
- PythonContextManager.deleteContext(context);
- this.context = context;
+ PythonContextManager.setContext(this.context);
+ PythonContextManager.reset();
+ setupDone.set(false);
setup();
}
public void setup(){
+ if (setupDone.get()) return;
try (PythonGIL gil = PythonGIL.lock()) {
PythonContextManager.setContext(context);
PythonObject runF = PythonExecutioner.getVariable("run");
+
if (runF == null || runF.isNone() || !Python.callable(runF)) {
PythonExecutioner.exec(code);
runF = PythonExecutioner.getVariable("run");
@@ -98,10 +101,12 @@ public class PythonJob {
if (!setupF.isNone()) {
setupF.call();
}
+ setupDone.set(true);
}
}
public void exec(List<PythonVariable> inputs, List<PythonVariable> outputs) {
+ if (setupRunMode) setup();
try (PythonGIL gil = PythonGIL.lock()) {
try (PythonGC _ = PythonGC.watch()) {
PythonContextManager.setContext(context);
@@ -139,6 +144,7 @@ public class PythonJob {
}
public List<PythonVariable> execAndReturnAllVariables(List<PythonVariable> inputs){
+ if (setupRunMode) setup();
try (PythonGIL gil = PythonGIL.lock()) {
try (PythonGC _ = PythonGC.watch()) {
PythonContextManager.setContext(context);
diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonObject.java b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonObject.java
similarity index 77%
rename from python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonObject.java
rename to python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonObject.java
index f8ec17ed9..94b60d320 100644
--- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonObject.java
+++ b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonObject.java
@@ -14,7 +14,7 @@
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
-package org.eclipse.python4j;
+package org.nd4j.python4j;
import org.bytedeco.cpython.PyObject;
@@ -147,7 +147,8 @@ public class PythonObject {
}
PythonObject pyArgs;
PythonObject pyKwargs;
- if (args == null) {
+
+ if (args == null || args.isEmpty()) {
pyArgs = new PythonObject(PyTuple_New(0));
} else {
PythonObject argsList = PythonTypes.convert(args);
@@ -158,6 +159,7 @@ public class PythonObject {
} else {
pyKwargs = PythonTypes.convert(kwargs);
}
+
PythonObject ret = new PythonObject(
PyObject_Call(
nativePythonObject,
@@ -165,7 +167,9 @@ public class PythonObject {
pyKwargs == null ? null : pyKwargs.nativePythonObject
)
);
+
PythonGC.keep(ret);
+
return ret;
}
@@ -241,4 +245,48 @@ public class PythonObject {
PyObject_SetItem(nativePythonObject, key.nativePythonObject, value.nativePythonObject);
}
+
+ public PythonObject abs(){
+ return new PythonObject(PyNumber_Absolute(nativePythonObject));
+ }
+ public PythonObject add(PythonObject pythonObject){
+ return new PythonObject(PyNumber_Add(nativePythonObject, pythonObject.nativePythonObject));
+ }
+ public PythonObject sub(PythonObject pythonObject){
+ return new PythonObject(PyNumber_Subtract(nativePythonObject, pythonObject.nativePythonObject));
+ }
+ public PythonObject mod(PythonObject pythonObject){
+ // PyNumber_Remainder implements Python's % operator; PyNumber_Divmod would
+ // return a (quotient, remainder) tuple rather than the remainder alone
+ return new PythonObject(PyNumber_Remainder(nativePythonObject, pythonObject.nativePythonObject));
+ }
+ public PythonObject mul(PythonObject pythonObject){
+ return new PythonObject(PyNumber_Multiply(nativePythonObject, pythonObject.nativePythonObject));
+ }
+ public PythonObject trueDiv(PythonObject pythonObject){
+ return new PythonObject(PyNumber_TrueDivide(nativePythonObject, pythonObject.nativePythonObject));
+ }
+ public PythonObject floorDiv(PythonObject pythonObject){
+ return new PythonObject(PyNumber_FloorDivide(nativePythonObject, pythonObject.nativePythonObject));
+ }
+ public PythonObject matMul(PythonObject pythonObject){
+ return new PythonObject(PyNumber_MatrixMultiply(nativePythonObject, pythonObject.nativePythonObject));
+ }
+
+ public void addi(PythonObject pythonObject){
+ PyNumber_InPlaceAdd(nativePythonObject, pythonObject.nativePythonObject);
+ }
+ public void subi(PythonObject pythonObject){
+ PyNumber_InPlaceSubtract(nativePythonObject, pythonObject.nativePythonObject);
+ }
+ public void muli(PythonObject pythonObject){
+ PyNumber_InPlaceMultiply(nativePythonObject, pythonObject.nativePythonObject);
+ }
+ public void trueDivi(PythonObject pythonObject){
+ PyNumber_InPlaceTrueDivide(nativePythonObject, pythonObject.nativePythonObject);
+ }
+ public void floorDivi(PythonObject pythonObject){
+ PyNumber_InPlaceFloorDivide(nativePythonObject, pythonObject.nativePythonObject);
+ }
+ public void matMuli(PythonObject pythonObject){
+ PyNumber_InPlaceMatrixMultiply(nativePythonObject, pythonObject.nativePythonObject);
+ }
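+
+ // Note: CPython's PyNumber_InPlace* functions return the result object; for
+ // immutable operands (int, float, str) that is a new object which is discarded
+ // here, so the *i variants only truly mutate mutable types.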
}
diff --git a/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonProcess.java b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonProcess.java
new file mode 100644
index 000000000..bce8809f5
--- /dev/null
+++ b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonProcess.java
@@ -0,0 +1,127 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+
+package org.nd4j.python4j;
+
+import org.apache.commons.io.IOUtils;
+import org.bytedeco.javacpp.Loader;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+
+public class PythonProcess {
+ private static String pythonExecutable = Loader.load(org.bytedeco.cpython.python.class);
+ public static String runAndReturn(String... arguments)throws IOException, InterruptedException{
+ String[] allArgs = new String[arguments.length + 1];
+ for (int i = 0; i < arguments.length; i++){
+ allArgs[i + 1] = arguments[i];
+ }
+ allArgs[0] = pythonExecutable;
+ ProcessBuilder pb = new ProcessBuilder(allArgs);
+ Process process = pb.start();
+ String out = IOUtils.toString(process.getInputStream(), StandardCharsets.UTF_8);
+ process.waitFor();
+ return out;
+
+ }
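+ // Both helpers prepend the bundled CPython binary, so e.g.
+ // runAndReturn("-m", "pip", "list") spawns "<python> -m pip list" and returns
+ // its stdout.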
+
+ public static void run(String... arguments)throws IOException, InterruptedException{
+ String[] allArgs = new String[arguments.length + 1];
+ for (int i = 0; i < arguments.length; i++){
+ allArgs[i + 1] = arguments[i];
+ }
+ allArgs[0] = pythonExecutable;
+ ProcessBuilder pb = new ProcessBuilder(allArgs);
+ pb.inheritIO().start().waitFor();
+ }
+ public static void pipInstall(String packageName) throws PythonException{
+ try{
+ run("-m", "pip", "install", packageName);
+ }catch(Exception e){
+ throw new PythonException("Error installing package " + packageName, e);
+ }
+
+ }
+
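+ // e.g. pipInstall("numpy", "1.19.5") runs "pip install numpy==1.19.5"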
+ public static void pipInstall(String packageName, String version){
+ pipInstall(packageName + "==" + version);
+ }
+
+ public static void pipUninstall(String packageName) throws PythonException{
+ try{
+ run("-m", "pip", "uninstall", packageName);
+ }catch(Exception e){
+ throw new PythonException("Error uninstalling package " + packageName, e);
+ }
+
+ }
+ public static void pipInstallFromGit(String gitRepoUrl){
+ if (!gitRepoUrl.contains("://")){
+ gitRepoUrl = "git://" + gitRepoUrl;
+ }
+ try{
+ run("-m", "pip", "install", "git+", gitRepoUrl);
+ }catch(Exception e){
+ throw new PythonException("Error installing package from " + gitRepoUrl, e);
+ }
+
+ }
+
+ public static String getPackageVersion(String packageName){
+ String out;
+ try{
+ out = runAndReturn("-m", "pip", "show", packageName);
+ } catch (Exception e){
+ throw new PythonException("Error finding version for package " + packageName, e);
+ }
+
+ if (!out.contains("Version: ")){
+ throw new PythonException("Can't find package " + packageName);
+ }
+ String pkgVersion = out.split("Version: ")[1].split(System.lineSeparator())[0];
+ return pkgVersion;
+ }
+
+ public static boolean isPackageInstalled(String packageName){
+ try{
+ String out = runAndReturn("-m", "pip", "show", packageName);
+ return !out.isEmpty();
+ }catch (Exception e){
+ throw new PythonException("Error checking if package is installed: " +packageName, e);
+ }
+
+ }
+
+ public static void pipInstallFromRequirementsTxt(String path){
+ try{
+ run("-m", "pip", "install","-r", path);
+ }catch (Exception e){
+ throw new PythonException("Error installing packages from " + path, e);
+ }
+ }
+
+ public static void pipInstallFromSetupScript(String path, boolean inplace){
+
+ try{
+ run(path, inplace?"develop":"install");
+ }catch (Exception e){
+ throw new PythonException("Error installing package from " + path, e);
+ }
+
+ }
+
+}
\ No newline at end of file
diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonType.java b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonType.java
similarity index 72%
rename from python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonType.java
rename to python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonType.java
index b4806aa37..79b0ccaab 100644
--- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonType.java
+++ b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonType.java
@@ -14,9 +14,11 @@
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
-package org.eclipse.python4j;
+package org.nd4j.python4j;
+import java.io.File;
+
public abstract class PythonType<T> {
private final String name;
@@ -43,5 +45,25 @@ public abstract class PythonType {
return name;
}
+ @Override
+ public boolean equals(Object obj) {
+ if (!(obj instanceof PythonType)) {
+ return false;
+ }
+ PythonType other = (PythonType) obj;
+ return this.getClass().equals(other.getClass()) && this.name.equals(other.name);
+ }
+
+ @Override
+ public int hashCode() {
+ // equals() is overridden above, so hashCode() must agree with it to keep the Object contract
+ return 31 * getClass().hashCode() + name.hashCode();
+ }
+
+ public PythonObject pythonType(){
+ // optional hook: subclasses may return the python type object used for isinstance() matching
+ return null;
+ }
+
+ public File[] packages(){
+ // optional hook: extra package directories a subclass needs available to the interpreter
+ return new File[0];
+ }
+
+ public void init(){ //not to be called from constructor
+
+ }
}
diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonTypes.java b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonTypes.java
similarity index 58%
rename from python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonTypes.java
rename to python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonTypes.java
index 0dc20f712..089c8aefe 100644
--- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonTypes.java
+++ b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonTypes.java
@@ -14,11 +14,18 @@
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
-package org.eclipse.python4j;
+package org.nd4j.python4j;
import org.bytedeco.cpython.PyObject;
+import org.bytedeco.javacpp.BytePointer;
+import org.bytedeco.javacpp.Pointer;
+import sun.nio.ch.DirectBuffer;
+import java.lang.reflect.Field;
+import java.nio.Buffer;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
import java.util.*;
import static org.bytedeco.cpython.global.python.*;
@@ -28,7 +35,7 @@ public class PythonTypes {
private static List<PythonType> getPrimitiveTypes() {
- return Arrays.asList(STR, INT, FLOAT, BOOL);
+ return Arrays.asList(STR, INT, FLOAT, BOOL, BYTES);
}
private static List<PythonType> getCollectionTypes() {
@@ -36,8 +43,13 @@ public class PythonTypes {
}
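+ // Discovery uses java.util.ServiceLoader: concrete PythonType implementations are listed,
+ // one binary class name per line, in META-INF/services/org.nd4j.python4j.PythonType.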
private static List<PythonType> getExternalTypes() {
- //TODO service loader
- return new ArrayList<>();
+ List<PythonType> ret = new ArrayList<>();
+ ServiceLoader<PythonType> sl = ServiceLoader.load(PythonType.class);
+ Iterator<PythonType> iter = sl.iterator();
+ while (iter.hasNext()) {
+ ret.add(iter.next());
+ }
+ return ret;
}
public static List<PythonType> get() {
@@ -48,15 +60,17 @@ public class PythonTypes {
return ret;
}
- public static PythonType get(String name) {
+ public static <T> PythonType<T> get(String name) {
for (PythonType pt : get()) {
if (pt.getName().equals(name)) { // TODO use map instead?
return pt;
}
+
}
throw new PythonException("Unknown python type: " + name);
}
+
public static PythonType getPythonTypeForJavaObject(Object javaObject) {
for (PythonType pt : get()) {
if (pt.accepts(javaObject)) {
@@ -66,7 +80,7 @@ public class PythonTypes {
throw new PythonException("Unable to find python type for java type: " + javaObject.getClass());
}
- public static PythonType getPythonTypeForPythonObject(PythonObject pythonObject) {
+ public static <T> PythonType<T> getPythonTypeForPythonObject(PythonObject pythonObject) {
PyObject pyType = PyObject_Type(pythonObject.getNativePythonObject());
try {
String pyTypeStr = PythonTypes.STR.toJava(new PythonObject(pyType, false));
@@ -75,6 +89,14 @@ public class PythonTypes {
String pyTypeStr2 = "<class '" + pt.getName() + "'>";
if (pyTypeStr.equals(pyTypeStr2)) {
return pt;
+ } else {
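+ // Fall back to an isinstance() check so python subclasses of a registered type still match;
+ // PythonGC.watch() scopes and frees the temporary python objects created by this check.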
+ try (PythonGC gc = PythonGC.watch()) {
+ PythonObject pyType2 = pt.pythonType();
+ if (pyType2 != null && Python.isinstance(pythonObject, pyType2)) {
+ return pt;
+ }
+ }
+
}
}
throw new PythonException("Unable to find converter for python object of type " + pyTypeStr);
@@ -212,12 +234,53 @@ public class PythonTypes {
public static final PythonType<List> LIST = new PythonType<List>("list", List.class) {
+ @Override
+ public boolean accepts(Object javaObject) {
+ return (javaObject instanceof List || javaObject.getClass().isArray());
+ }
+
@Override
public List adapt(Object javaObject) {
if (javaObject instanceof List) {
return (List) javaObject;
- } else if (javaObject instanceof Object[]) {
- return Arrays.asList((Object[]) javaObject);
+ } else if (javaObject.getClass().isArray()) {
+ List