diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java
index 743e16710..b0cc17376 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java
@@ -57,10 +57,8 @@ import org.deeplearning4j.nn.weights.WeightInit;
 import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
 import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
 import org.deeplearning4j.util.ModelSerializer;
-import org.junit.AfterClass;
-import org.junit.Before;
-import org.junit.BeforeClass;
-import org.junit.Test;
+import org.junit.*;
+import org.junit.rules.TemporaryFolder;
 import org.nd4j.linalg.activations.Activation;
 import org.nd4j.linalg.activations.impl.ActivationIdentity;
 import org.nd4j.linalg.api.buffer.DataType;
@@ -82,6 +80,7 @@ import org.nd4j.common.resources.Resources;
 
 import java.io.ByteArrayInputStream;
 import java.io.ByteArrayOutputStream;
+import java.io.File;
 import java.io.IOException;
 import java.util.*;
 
@@ -91,6 +90,9 @@ import static org.junit.Assert.*;
 @Slf4j
 public class TestComputationGraphNetwork extends BaseDL4JTest {
 
+    @Rule
+    public TemporaryFolder testDir = new TemporaryFolder();
+
     private static ComputationGraphConfiguration getIrisGraphConfiguration() {
         return new NeuralNetConfiguration.Builder().seed(12345)
                 .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder()
@@ -2177,4 +2179,40 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
         INDArray label = Nd4j.createFromArray(1, 0).reshape(1, 2);
         cg.fit(new DataSet(in, label));
     }
+
+    @Test
+    public void testMergeNchw() throws Exception {
+        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
+                .convolutionMode(ConvolutionMode.Same)
+                .graphBuilder()
+                .addInputs("in")
+                .layer("l0", new ConvolutionLayer.Builder()
+                        .nOut(16)
+                        .kernelSize(2,2).stride(1,1)
+                        .build(), "in")
+                .layer("l1", new ConvolutionLayer.Builder()
+                        .nOut(8)
+                        .kernelSize(2,2).stride(1,1)
+                        .build(), "in")
+                .addVertex("merge", new MergeVertex(), "l0", "l1")
+                .layer("out", new CnnLossLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build(), "merge")
+                .setOutputs("out")
+                .setInputTypes(InputType.convolutional(32, 32, 3, CNN2DFormat.NHWC))
+                .build();
+
+        ComputationGraph cg = new ComputationGraph(conf);
+        cg.init();
+
+        INDArray[] in = new INDArray[]{Nd4j.rand(DataType.FLOAT, 1, 32, 32, 3)};
+        INDArray out = cg.outputSingle(in);
+
+        File dir = testDir.newFolder();
+        File f = new File(dir, "net.zip");
+        cg.save(f);
+
+        ComputationGraph c2 = ComputationGraph.load(f, true);
+        INDArray out2 = c2.outputSingle(in);
+
+        assertEquals(out, out2);
+    }
 }
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/Dropout.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/Dropout.java
index 46a872fd8..acb6afa2c 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/Dropout.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/Dropout.java
@@ -66,8 +66,8 @@ import org.nd4j.shade.jackson.annotation.JsonProperty;
  * @author Alex Black
  */
 @Data
-@JsonIgnoreProperties({"mask", "helper", "helperCountFail"})
-@EqualsAndHashCode(exclude = {"mask", "helper", "helperCountFail"})
+@JsonIgnoreProperties({"mask", "helper", "helperCountFail", "initializedHelper"})
+@EqualsAndHashCode(exclude = {"mask", "helper", "helperCountFail", "initializedHelper"})
 @Slf4j
 public class Dropout implements IDropout {
 
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/MergeVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/MergeVertex.java
index 726a68403..c7a4fec63 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/MergeVertex.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/MergeVertex.java
@@ -17,6 +17,7 @@
 package org.deeplearning4j.nn.conf.graph;
 
+import lombok.Data;
 import lombok.val;
 import org.deeplearning4j.nn.conf.CNN2DFormat;
 import org.deeplearning4j.nn.conf.RNNFormat;
@@ -38,6 +39,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
  * -> [numExamples,depth1 + depth2,width,height]}
  * @author Alex Black
  */
+@Data
 public class MergeVertex extends GraphVertex {
 
     protected int mergeAxis = 1; //default value for backward compatibility (deserialization of old version JSON) - NCHW and NCW format
diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/train/GradientSharingTrainingTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/train/GradientSharingTrainingTest.java
index 68a012b72..c1eff1dce 100644
--- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/train/GradientSharingTrainingTest.java
+++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/train/GradientSharingTrainingTest.java
@@ -141,7 +141,7 @@ public class GradientSharingTrainingTest extends BaseSparkTest {
         SparkComputationGraph sparkNet = new SparkComputationGraph(sc, conf, tm);
         sparkNet.setCollectTrainingStats(tm.getIsCollectTrainingStats());
 
-        System.out.println(Arrays.toString(sparkNet.getNetwork().params().get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
+//        System.out.println(Arrays.toString(sparkNet.getNetwork().params().get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
         File f = testDir.newFolder();
         DataSetIterator iter = new MnistDataSetIterator(16, true, 12345);
         int count = 0;
@@ -208,10 +208,10 @@ public class GradientSharingTrainingTest extends BaseSparkTest {
         }
 
         INDArray paramsAfter = after.params();
-        System.out.println(Arrays.toString(paramsBefore.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
-        System.out.println(Arrays.toString(paramsAfter.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
-        System.out.println(Arrays.toString(
-                Transforms.abs(paramsAfter.sub(paramsBefore)).get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
+//        System.out.println(Arrays.toString(paramsBefore.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
+//        System.out.println(Arrays.toString(paramsAfter.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
+//        System.out.println(Arrays.toString(
+//                Transforms.abs(paramsAfter.sub(paramsBefore)).get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
 
         assertNotEquals(paramsBefore, paramsAfter);
 
@@ -235,7 +235,7 @@ public class GradientSharingTrainingTest extends BaseSparkTest {
     }
 
 
-    @Test
+    @Test @Ignore //AB https://github.com/eclipse/deeplearning4j/issues/8985
     public void differentNetsTrainingTest() throws Exception {
         int batch = 3;
 
diff --git a/libnd4j/CMakeLists.txt b/libnd4j/CMakeLists.txt
index 106401b31..0631763c2 100755
--- a/libnd4j/CMakeLists.txt
+++ b/libnd4j/CMakeLists.txt
@@ -131,6 +131,23 @@ if(NOT SD_CUDA)
     endif()
 endif()
 
+#arm-compute entry
+if(${HELPERS_armcompute})
+    find_package(ARMCOMPUTE REQUIRED)
+
+    if(ARMCOMPUTE_FOUND)
+        message("Found ARMCOMPUTE: ${ARMCOMPUTE_LIBRARIES}")
+        set(HAVE_ARMCOMPUTE 1)
+        # Add preprocessor definition for ARM Compute NEON
+        add_definitions(-DARMCOMPUTENEON_ENABLED)
+        #build our library with neon support
+        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mfpu=neon")
+        include_directories(${ARMCOMPUTE_INCLUDE})
+        message("----${ARMCOMPUTE_INCLUDE}---")
+    endif()
+
+endif()
+
 # new mkl-dnn entry
 if (${HELPERS_mkldnn})
 
diff --git a/libnd4j/blas/CMakeLists.txt b/libnd4j/blas/CMakeLists.txt
index fb1dc066e..b6bd1f7c0 100755
--- a/libnd4j/blas/CMakeLists.txt
+++ b/libnd4j/blas/CMakeLists.txt
@@ -146,6 +146,10 @@ if (HAVE_MKLDNN)
     file(GLOB_RECURSE CUSTOMOPS_MKLDNN_SOURCES false ../include/ops/declarable/platform/mkldnn/*.cpp ../include/ops/declarable/platform/mkldnn/mkldnnUtils.h)
 endif()
 
+if(HAVE_ARMCOMPUTE)
+    file(GLOB_RECURSE CUSTOMOPS_ARMCOMPUTE_SOURCES false ../include/ops/declarable/platform/armcompute/*.cpp ../include/ops/declarable/platform/armcompute/*.h)
+endif()
+
 if(SD_CUDA)
     message("Build cublas")
     find_package(CUDA)
@@ -243,7 +247,7 @@ if(SD_CUDA)
                 ${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES} ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
                 ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES} ${CUSTOMOPS_CUDNN_SOURCES} ${CUSTOMOPS_MKLDNN_SOURCES}
-                ${CUSTOMOPS_GENERIC_SOURCES}
+                ${CUSTOMOPS_ARMCOMPUTE_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES}
                 )
 
     if (WIN32)
@@ -351,8 +355,8 @@ elseif(SD_CPU)
     add_definitions(-D__CPUBLAS__=true)
     add_library(samediff_obj OBJECT ${LEGACY_SOURCES} ${LOOPS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
-            ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_MKLDNN_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES}
-            ${OPS_SOURCES} ${PERF_SOURCES})
+            ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_MKLDNN_SOURCES}
+            ${CUSTOMOPS_ARMCOMPUTE_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES})
     if(IOS)
         add_library(${SD_LIBRARY_NAME} STATIC $<TARGET_OBJECTS:samediff_obj>)
     else()
@@ -378,12 +382,12 @@ elseif(SD_CPU)
     if (NOT BLAS_LIBRARIES)
         set(BLAS_LIBRARIES "")
     endif()
-    target_link_libraries(${SD_LIBRARY_NAME} ${MKLDNN} ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${BLAS_LIBRARIES} ${CPU_FEATURES})
+    target_link_libraries(${SD_LIBRARY_NAME} ${MKLDNN} ${MKLDNN_LIBRARIES} ${ARMCOMPUTE_LIBRARIES} ${OPENBLAS_LIBRARIES} ${BLAS_LIBRARIES} ${CPU_FEATURES})
 
     if ("${SD_ALL_OPS}" AND "${SD_BUILD_MINIFIER}")
         message(STATUS "Building minifier...")
         add_executable(minifier ../minifier/minifier.cpp ../minifier/graphopt.cpp)
-        target_link_libraries(minifier samediff_obj ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES} ${CPU_FEATURES})
+        target_link_libraries(minifier samediff_obj ${MKLDNN_LIBRARIES} ${ARMCOMPUTE_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES} ${CPU_FEATURES})
     endif()
 
     if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND "${CMAKE_CXX_COMPILER_VERSION}" VERSION_LESS 4.9)
diff --git a/libnd4j/cmake/FindARMCOMPUTE.cmake b/libnd4j/cmake/FindARMCOMPUTE.cmake
new file mode 100644
index 000000000..ae0e1fbba
--- /dev/null
+++ b/libnd4j/cmake/FindARMCOMPUTE.cmake
@@ -0,0 +1,74 @@
+################################################################################
+# Copyright (c) 2020 Konduit K.K.
+#
+# This program and the accompanying materials are made available under the
+# terms of the Apache License, Version 2.0 which is available at
+# https://www.apache.org/licenses/LICENSE-2.0.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+################################################################################
+
+
+
+### Find ARM COMPUTE LIBRARY STATIC libraries
+
+SET (COMPUTE_INCLUDE_DIRS
+    /usr/include
+    ${ARMCOMPUTE_ROOT}
+    ${ARMCOMPUTE_ROOT}/include
+    ${ARMCOMPUTE_ROOT}/applications
+    ${ARMCOMPUTE_ROOT}/applications/arm_compute
+)
+
+
+SET (COMPUTE_LIB_DIRS
+    /lib
+    /usr/lib
+    ${ARMCOMPUTE_ROOT}
+    ${ARMCOMPUTE_ROOT}/lib
+    ${ARMCOMPUTE_ROOT}/build
+)
+
+find_path(ARMCOMPUTE_INCLUDE arm_compute/core/CL/ICLKernel.h
+          PATHS ${COMPUTE_INCLUDE_DIRS}
+          NO_DEFAULT_PATH NO_CMAKE_FIND_ROOT_PATH)
+
+find_path(ARMCOMPUTE_INCLUDE arm_compute/core/CL/ICLKernel.h)
+
+find_path(HALF_INCLUDE half/half.hpp)
+find_path(HALF_INCLUDE half/half.hpp
+          PATHS ${ARMCOMPUTE_ROOT}/include
+          NO_DEFAULT_PATH NO_CMAKE_FIND_ROOT_PATH)
+include_directories(SYSTEM ${HALF_INCLUDE})
+
+# Find the Arm Compute libraries if not already specified
+if (NOT DEFINED ARMCOMPUTE_LIBRARIES)
+
+    find_library(ARMCOMPUTE_LIBRARY NAMES arm_compute-static
+                 PATHS ${COMPUTE_LIB_DIRS}
+                 PATH_SUFFIXES "Release"
+                 NO_DEFAULT_PATH NO_CMAKE_FIND_ROOT_PATH)
+
+    find_library(ARMCOMPUTE_CORE_LIBRARY NAMES arm_compute_core-static
+                 PATHS ${COMPUTE_LIB_DIRS}
+                 PATH_SUFFIXES "Release"
+                 NO_DEFAULT_PATH NO_CMAKE_FIND_ROOT_PATH)
+    # In case it wasn't there, try a default search (will work in cases where
+    # the library has been installed into a standard location)
+    find_library(ARMCOMPUTE_LIBRARY NAMES arm_compute-static)
+    find_library(ARMCOMPUTE_CORE_LIBRARY NAMES arm_compute_core-static)
+
+    set(ARMCOMPUTE_LIBRARIES ${ARMCOMPUTE_LIBRARY} ${ARMCOMPUTE_CORE_LIBRARY} )
+endif()
+
+
+INCLUDE(FindPackageHandleStandardArgs)
+
+FIND_PACKAGE_HANDLE_STANDARD_ARGS(ARMCOMPUTE REQUIRED_VARS ARMCOMPUTE_INCLUDE ARMCOMPUTE_LIBRARIES)
+
diff --git a/libnd4j/include/cnpy/cnpy.h b/libnd4j/include/cnpy/cnpy.h
index ea847c3e7..c84623599 100644
--- a/libnd4j/include/cnpy/cnpy.h
+++ b/libnd4j/include/cnpy/cnpy.h
@@ -69,7 +69,7 @@ namespace cnpy {
         }
     };
 
-    struct ND4J_EXPORT npz_t : public std::unordered_map<std::string, NpyArray> {
+    struct ND4J_EXPORT npz_t : public std::map<std::string, NpyArray> {
         void destruct() {
             npz_t::iterator it = this->begin();
             for(; it != this->end(); ++it) (*it).second.destruct();
diff --git a/libnd4j/include/config.h.in b/libnd4j/include/config.h.in
index 1e63552d0..c858dd765 100644
--- a/libnd4j/include/config.h.in
+++ b/libnd4j/include/config.h.in
@@ -3,6 +3,8 @@
 
 #cmakedefine HAVE_MKLDNN
 
+#cmakedefine HAVE_ARMCOMPUTE
+
 #cmakedefine MKLDNN_PATH "@MKLDNN_PATH@"
 
 #cmakedefine HAVE_OPENBLAS
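A note on how the new define is meant to be consumed: once CMake sets HAVE_ARMCOMPUTE, the generated config header exposes it to C++ code, so ARM-specific paths can be compiled conditionally. A minimal sketch, assuming the generated header is reachable as <config.h> (the function itself is illustrative, not part of this patch):

#include <config.h>

constexpr bool armComputeEnabled() {
#if defined(HAVE_ARMCOMPUTE)
    return true;   // armcompute platform helpers were compiled in
#else
    return false;
#endif
}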
diff --git a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool_with_argmax.cpp b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool_with_argmax.cpp
index eced3c2b4..b03d19451 100644
--- a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool_with_argmax.cpp
+++ b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool_with_argmax.cpp
@@ -45,18 +45,18 @@ namespace sd {
         DECLARE_TYPES(max_pool_with_argmax) {
             getOpDescriptor()
                     ->setAllowedInputTypes(sd::DataType::ANY)
-                    ->setAllowedOutputTypes(0, DataType::INHERIT)
-                    ->setAllowedOutputTypes(1, {ALL_INTS});
+                    ->setAllowedOutputTypes(0, {ALL_FLOATS, ALL_INTS})
+                    ->setAllowedOutputTypes(1, {ALL_INDICES});
         }
 
         DECLARE_SHAPE_FN(max_pool_with_argmax) {
+            auto in = inputShape->at(0);
+            auto dtype = block.numD() ? D_ARG(0) : sd::DataType::INT64;
+            auto valuesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in));
+            auto indicesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in, dtype));
 
-            auto in = inputShape->at(0);
-            auto valuesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in));
-            auto indicesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in, DataType::INT64));
-
-            return SHAPELIST(valuesShape, indicesShape);
+            return SHAPELIST(valuesShape, indicesShape);
         }
     }
 }
diff --git a/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp b/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp
index 8f45c696b..7e66d4b11 100644
--- a/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp
+++ b/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp
@@ -215,7 +215,9 @@ namespace helpers {
         auto maxValue = T(0); //sd::math::nd4j_abs(compoundBuffer[xInitialIndex]);
         auto result = -1;
         //auto loop = PRAGMA_THREADS_FOR {
-        auto start = column, stop = rowNum, increment = 1;
+        auto start = column;
+        auto stop = rowNum;
+        auto increment = 1;
         for (auto rowCounter = start; rowCounter < stop; rowCounter++) {
             Nd4jLong xPos[] = {rowCounter, column};
             auto xIndex = shape::getOffset(compoundShape, xPos, 0);
diff --git a/libnd4j/include/ops/declarable/helpers/cpu/max_pooling.cpp b/libnd4j/include/ops/declarable/helpers/cpu/max_pooling.cpp
index a458b5eff..ebb9d53fa 100644
--- a/libnd4j/include/ops/declarable/helpers/cpu/max_pooling.cpp
+++ b/libnd4j/include/ops/declarable/helpers/cpu/max_pooling.cpp
@@ -73,7 +73,7 @@ namespace helpers {
     }
 
     void maxPoolingFunctor(sd::LaunchContext * context, sd::graph::Context& block, NDArray* input, NDArray* values, std::vector<int> const& params, NDArray* indices) {
-        BUILD_SINGLE_SELECTOR(input->dataType(), maxPoolingFunctor_, (block, input, values, params, indices), FLOAT_TYPES);
+        BUILD_SINGLE_SELECTOR(input->dataType(), maxPoolingFunctor_, (block, input, values, params, indices), LIBND4J_TYPES);
     }
 }
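The next file rewrites the CPU randomShuffle helper around two primitives from the MergeShuffle paper: a textbook Fisher-Yates pass over each chunk, and a merge step that riffles two already-shuffled ranges together. For orientation, here is a self-contained sketch of the Fisher-Yates step on a plain std::vector; std::mt19937_64 stands in for sd::graph::RandomGenerator, and the function name is hypothetical:

#include <random>
#include <utility>
#include <vector>

void fisherYatesDemo(std::vector<int>& v, std::mt19937_64& gen) {
    if (v.empty()) return;
    for (size_t i = v.size() - 1; i > 0; --i) {
        // j drawn uniformly from [0, i] inclusive - the "% (i + 1)" in the patch below
        std::uniform_int_distribution<size_t> pick(0, i);
        std::swap(v[i], v[pick(gen)]);
    }
}

Drawing j from [0, i] (rather than [0, i-1], as the removed code effectively did with "% i") is what makes the permutation unbiased.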
diff --git a/libnd4j/include/ops/declarable/helpers/cpu/randomShuffle.cpp b/libnd4j/include/ops/declarable/helpers/cpu/randomShuffle.cpp
index ea529112d..2ffbfc95f 100644
--- a/libnd4j/include/ops/declarable/helpers/cpu/randomShuffle.cpp
+++ b/libnd4j/include/ops/declarable/helpers/cpu/randomShuffle.cpp
@@ -16,7 +16,8 @@
 //
 // @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018
-//
+// implementation is based on the following article:
+// "MergeShuffle: A Very Fast, Parallel Random Permutation Algorithm", https://arxiv.org/abs/1508.03167
 
@@ -31,96 +32,167 @@ namespace ops {
 namespace helpers {
 
 //////////////////////////////////////////////////////////////////////////
+// Fisher-Yates shuffle
 template <typename T>
-void randomShuffle_(NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {
+static void fisherYates(sd::graph::RandomGenerator& rng, T* buff, const Nd4jLong& len, const Nd4jLong& ews, Nd4jLong ind) {
+
+    for(Nd4jLong i = len-1; i > 0; --i) {
+        const Nd4jLong j = rng.relativeLong(ind++) % (i + 1);
+        if(i != j)
+            math::nd4j_swap(buff[i*ews], buff[j*ews]);
+    }
+}
+
+//////////////////////////////////////////////////////////////////////////
+// mutual shuffle of two adjacent, already shuffled ranges of lengths len1 and (totLen - len1) respectively
+template <typename T>
+static void mergeShuffle(sd::graph::RandomGenerator& rng, T* buff, const Nd4jLong& len1, const Nd4jLong& totLen, const Nd4jLong& ews, Nd4jLong ind) {
+
+    Nd4jLong beg = 0;    // beginning
+    Nd4jLong mid = len1; // middle
+
+    while (true) {
+        if(rng.relativeLong(ind++) % 2) {
+            if(mid == totLen)
+                break;
+            math::nd4j_swap(buff[ews * beg], buff[ews * mid++]);
+        } else {
+            if(beg == mid)
+                break;
+        }
+        ++beg;
+    }
+
+    // fisherYates
+    while (beg < totLen) {
+        const Nd4jLong j = rng.relativeLong(ind++) % (beg + 1);
+        if(beg != j)
+            math::nd4j_swap(buff[ews * beg], buff[ews * j]);
+        ++beg;
+    }
+}
+
+//////////////////////////////////////////////////////////////////////////
+template <typename T>
+static void randomShuffle_(NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {
 
-    // check edge cases first
-    int temp;
     const int firstDim = input.sizeAt(0);
+    int temp;
+
     if(input.lengthOf() == 1 || firstDim == 1) {
 
         if(!isInplace)
             output.assign(input);
     }
-    else if (input.isVector() || shape::isLikeVector(input.shapeInfo(), temp)) {
+    else if (shape::isCommonVector(input.shapeInfo(), temp)) {
 
-        // apply Fisher-Yates shuffle
-        if(isInplace) {
-            //PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance().tadThreshold())
-            for(int i = firstDim-1; i > 0; --i) {
-                int r = rng.relativeInt(i) % i;
-                if(i == r)
-                    continue;
-                T t0 = input.t<T>(i);
-                T t1 = input.t<T>(r);
-                //math::nd4j_swap<T>(input(i), input(r));
-                input.r<T>(i) = t1;
-                input.r<T>(r) = t0;
-            }
+        NDArray* arr = &input;
+
+        if (!isInplace) {
+            output.assign(input);
+            arr = &output;
         }
-        else {
-            std::vector<int> indices(firstDim);
-            std::iota(indices.begin(), indices.end(), 0);
-            output.p<T>(Nd4jLong(0), input.e<T>(0));
-            // FIXME: parallelism!!
-            for(int i = firstDim-1; i > 0; --i) {
-                int r = rng.relativeInt(i) % i;
-                output.r<T>(i) = input.t<T>(indices[r]);
-                if(i == r)
-                    continue;
-                output.r<T>(r) = input.t<T>(indices[i]);
-                math::nd4j_swap<int>(indices[i], indices[r]);
-            }
-            rng.rewindH(firstDim-1);
-        }
+
+        const Nd4jLong ews = arr->ews();
+
+        const Nd4jLong len       = arr->lengthOf();
+        const Nd4jLong threshold = 1<<22; // this number was deduced from the diagram in the article
+
+        int power = 0;
+        while ((len >> power) > threshold)
+            ++power;
+
+        const Nd4jLong numChunks = 1 << power;
+
+        auto funcFisherYates = PRAGMA_THREADS_FOR {
+            for (auto i = start; i < stop; ++i) {
+                Nd4jLong offset  = (len * i) >> power;
+                Nd4jLong currLen = ((len * (i + 1)) >> power) - offset;
+                fisherYates<T>(rng, arr->bufferAsT<T>() + offset*ews, currLen, ews, offset);
+            }
+        };
+
+        auto funcMerge = PRAGMA_THREADS_FOR {
+            for (int64_t i = start, k = 1; i < stop; i += increment, ++k) {
+                Nd4jLong offset = len * i >> power;
+                Nd4jLong len1   = (len * (i + increment/2) >> power) - offset;
+                Nd4jLong totLen = (len * (i + increment) >> power) - offset;
+                mergeShuffle<T>(rng, arr->bufferAsT<T>() + offset*ews, len1, totLen, ews, len * k + offset);
+            }
+        };
+
+        samediff::Threads::parallel_for(funcFisherYates, 0, numChunks);
+
+        for (int j = 1; j < numChunks; j += j)
+            samediff::Threads::parallel_for(funcMerge, 0, numChunks, 2*j);
+
+        // #pragma omp parallel for
+        // for (uint i = 0; i < numChunks; ++i) {
+
+        //     Nd4jLong offset = (len * i) >> power;
+        //     Nd4jLong currLen = ((len * (i + 1)) >> power) - offset;
+        //     fisherYates<T>(rng, arr->bufferAsT<T>() + offset*ews, currLen, ews, offset);
+        // }
+
+        // for (uint j = 1; j < numChunks; j += j) {
+        //     #pragma omp parallel for
+        //     for (auto i = 0; i < numChunks; i += 2*j) {
+        //         Nd4jLong offset = len * i >> power;
+        //         Nd4jLong len1 = (len * (i + j) >> power) - offset;
+        //         Nd4jLong totLen = (len * (i + 2*j) >> power) - offset;
+        //         mergeShuffle(rng, arr->bufferAsT<T>() + offset*ews, len1, totLen, ews, len * j + offset);
+        //     }
+        // }
+
+        rng.rewindH((len + 1) * power);
     }
     else {
 
-        // evaluate sub-arrays list of input array through all dimensions excluding first one
-        std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input.rankOf(), {0});
-        auto subArrsListIn = input.allTensorsAlongDimension(dimensions);
+        auto dimsToExclude = ShapeUtils::evalDimsToExclude(input.rankOf(), {0});
 
-        // apply Fisher-Yates shuffle
         if(isInplace) {
-            //PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance().elementwiseThreshold())
-            for(int i = firstDim - 1; i > 0; --i) {
-                int r = rng.relativeInt(i) % i;
-                if(i == r)
-                    continue;
-                subArrsListIn.at(i)->swapUnsafe(*subArrsListIn.at(r));
+            auto subArrsList = input.allTensorsAlongDimension(dimsToExclude);
+
+            // Fisher-Yates shuffle
+            for(int i = firstDim - 1; i > 0; --i) {
+                const int j = rng.relativeInt(i) % (i + 1);
+                if(i != j)
+                    subArrsList.at(i)->swapUnsafe(*subArrsList.at(j));
             }
         }
         else {
-            // evaluate sub-arrays list of output array through all dimensions excluding first one
-            auto subArrsListOut = output.allTensorsAlongDimension(dimensions);
+
+            auto subArrsListIn  = input.allTensorsAlongDimension(dimsToExclude);
+            auto subArrsListOut = output.allTensorsAlongDimension(dimsToExclude);
+
             std::vector<int> indices(firstDim);
-            std::iota(indices.begin(), indices.end(), 0);
-            bool isZeroShuffled = false;
-            //PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance().tadThreshold())
-            for(int i = firstDim - 1; i > 0; --i) {
-                int r = rng.relativeInt(i) % i;
-                subArrsListOut.at(i)->assign(subArrsListIn.at(indices[r]));
-                if(r == 0)
-                    isZeroShuffled = true;
-                if(i == r)
-                    continue;
-                subArrsListOut.at(r)->assign(subArrsListIn.at(indices[i]));
-                math::nd4j_swap<int>(indices[i], indices[r]);
-            }
-            if(!isZeroShuffled)
-                subArrsListOut.at(0)->assign(subArrsListIn.at(0));
+            std::iota(indices.begin(), indices.end(), 0); // 0,1,2,3, ... firstDim-1
+
+            // shuffle indices
+            fisherYates<int>(rng, indices.data(), firstDim, 1, 0);
+
+            auto func = PRAGMA_THREADS_FOR {
+                for (auto i = start; i < stop; ++i)
+                    subArrsListOut.at(i)->assign(subArrsListIn.at(indices[i]));
+            };
+
+            samediff::Threads::parallel_for(func, 0, firstDim);
         }
+
         rng.rewindH(firstDim-1);
     }
-
 }
 
-    void randomShuffle(sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {
-        BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (input, output, rng, isInplace), LIBND4J_TYPES);
-    }
+void randomShuffle(sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {
+    BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (input, output, rng, isInplace), LIBND4J_TYPES);
+}
 
 }
 }
 }
+
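The fixed-point chunking above deserves a worked example: with power = 2 (numChunks = 4) and len = 10, the offsets (len * i) >> power evaluate to 0, 2, 5, 7, 10, giving chunk lengths 2, 3, 2, 3 that always sum to len without any rounding bookkeeping. The merge passes then fold neighbouring chunks with doubling stride. A sequential sketch of that schedule, under the assumption that shuffleChunk/mergePair stand in for the fisherYates/mergeShuffle calls above (both names are hypothetical):

#include <cstdint>

void mergeShuffleSchedule(int64_t len, int power) {
    const int64_t numChunks = int64_t(1) << power;
    for (int64_t i = 0; i < numChunks; ++i) {
        int64_t offset  = (len * i) >> power;
        int64_t currLen = ((len * (i + 1)) >> power) - offset;
        // shuffleChunk(offset, currLen);          // independent, parallel-safe
        (void)currLen;
    }
    for (int64_t j = 1; j < numChunks; j += j) {   // rounds with stride 2, 4, 8, ...
        for (int64_t i = 0; i < numChunks; i += 2 * j) {
            int64_t offset = (len * i) >> power;
            int64_t len1   = ((len * (i + j)) >> power) - offset;      // left half, already shuffled
            int64_t total  = ((len * (i + 2 * j)) >> power) - offset;  // left + right halves
            // mergePair(offset, len1, total);     // pairs in one round are independent
            (void)len1; (void)total;
        }
    }
}

Each round merges disjoint pairs, which is why the patch can hand every round to samediff::Threads::parallel_for with increment 2*j.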
diff --git a/libnd4j/include/ops/declarable/helpers/cuda/concat.cu b/libnd4j/include/ops/declarable/helpers/cuda/concat.cu
index cbcd35ffe..400c25f88 100644
--- a/libnd4j/include/ops/declarable/helpers/cuda/concat.cu
+++ b/libnd4j/include/ops/declarable/helpers/cuda/concat.cu
@@ -53,7 +53,7 @@ __global__ static void concatCuda(void* pVx, void* pxShapeInfo, void* vz, const
 
     int coords[MAX_RANK];
 
-    for (uint64_t i = tid; i < zLen; i += totalThreads) {
+    for (Nd4jLong i = tid; i < zLen; i += totalThreads) {
         shape::index2coords(i, zShapeInfo, coords);
 
         const auto zOffset = shape::getOffset(zShapeInfo, coords);
@@ -162,9 +162,9 @@ void concat(sd::LaunchContext * context, const std::vector<const NDArray*>& inArrs, NDArray& output, const int axis) {
 //    }
 //    else {        // general (slower) case
 
-        const int threadsPerBlock = 256;
-        const int blocksPerGrid = 512;
-        const int sharedMem = 512;
+        const int threadsPerBlock = MAX_NUM_THREADS / 2;
+        const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
+        const int sharedMem = 256;
 
         // prepare arrays of pointers on buffers and shapes
         std::vector<const void*> hInBuffers(numOfInArrs);
diff --git a/libnd4j/include/ops/declarable/helpers/cuda/max_pooling.cu b/libnd4j/include/ops/declarable/helpers/cuda/max_pooling.cu
index 6e70d4510..8c30e510f 100644
--- a/libnd4j/include/ops/declarable/helpers/cuda/max_pooling.cu
+++ b/libnd4j/include/ops/declarable/helpers/cuda/max_pooling.cu
@@ -88,7 +88,7 @@ namespace helpers {
     void maxPoolingFunctor(sd::LaunchContext * context, sd::graph::Context& block, NDArray* input, NDArray* values, std::vector<int> const& params, NDArray* indices) {
         NDArray::prepareSpecialUse({values, indices}, {input});
         auto yType = indices == nullptr ? sd::DataType::INT64 : indices->dataType();
-        BUILD_DOUBLE_SELECTOR(input->dataType(), yType, maxPoolingFunctor_, (block, input, values, params, indices), FLOAT_TYPES, INDEXING_TYPES);
+        BUILD_DOUBLE_SELECTOR(input->dataType(), yType, maxPoolingFunctor_, (block, input, values, params, indices), LIBND4J_TYPES, INDEXING_TYPES);
         NDArray::registerSpecialUse({values, indices}, {input});
     }
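The new CUDA shuffle below mirrors the CPU path: fisherYatesCuda shuffles one chunk per block, staging it in shared memory so the sequential Fisher-Yates loop (which only thread 0 can run, since every swap depends on the previous one) touches fast memory, and mergeShuffleCuda then folds neighbouring ranges round by round, halving the number of active blocks each time. A small sketch of the round count this implies; the helper name is hypothetical and the launch details are as reconstructed in randomShuffle_ further down:

#include <cstdint>

inline int mergeRounds(int64_t numChunks) {
    int rounds = 0;
    for (int64_t j = 1; j < numChunks; j += j)  // j = 1, 2, 4, ...
        ++rounds;                               // round k gives each block 2^k chunks
    return rounds;                              // equals power, i.e. log2(numChunks)
}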
diff --git a/libnd4j/include/ops/declarable/helpers/cuda/randomShuffle.cu b/libnd4j/include/ops/declarable/helpers/cuda/randomShuffle.cu
new file mode 100644
index 000000000..bb7998e60
--- /dev/null
+++ b/libnd4j/include/ops/declarable/helpers/cuda/randomShuffle.cu
@@ -0,0 +1,228 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author Yurii Shyrma (iuriish@yahoo.com)
+// the implemented algorithm is a GPU adaptation of the algorithm described in the following article:
+// "MergeShuffle: A Very Fast, Parallel Random Permutation Algorithm", https://arxiv.org/abs/1508.03167
+//
+
+#include <ops/declarable/helpers/transforms.h>
+#include <helpers/ShapeUtils.h>
+#include <helpers/PointersManager.h>
+#include <execution/Threads.h>
+#include <array/ResultSet.h>
+#include <numeric>
+
+namespace sd {
+namespace ops {
+namespace helpers {
+
+//////////////////////////////////////////////////////////////////////////
+template <typename T>
+static __global__ void fisherYatesCuda(sd::graph::RandomGenerator* rng, void* vx, const Nd4jLong ews, const Nd4jLong len, const int power) {
+
+    T* x = reinterpret_cast<T*>(vx);
+
+    __shared__ T* shmem, temp;
+    __shared__ Nd4jLong ind, blockOffset, lenPerBlock;
+
+    if (threadIdx.x == 0) {
+        extern __shared__ unsigned char sharedMemory[];
+        shmem = reinterpret_cast<T*>(sharedMemory);
+
+        blockOffset = (len * blockIdx.x) >> power;
+        lenPerBlock = ((len * (blockIdx.x + 1)) >> power) - blockOffset;
+        ind = blockOffset;
+    }
+    __syncthreads();
+
+    // copy from global memory to shared memory
+    if(threadIdx.x < lenPerBlock)
+        shmem[threadIdx.x] = x[(blockOffset + threadIdx.x) * ews];
+    __syncthreads();
+
+    // *** apply Fisher-Yates shuffle to lenPerBlock number of elements
+    if (threadIdx.x == 0) {
+        for(Nd4jLong i = lenPerBlock - 1; i > 0; --i) {
+            const Nd4jLong j = rng->relativeLong(ind++) % (i + 1);
+            if(i != j) {
+                temp = shmem[i];
+                shmem[i] = shmem[j];
+                shmem[j] = temp;
+            }
+        }
+    }
+    __syncthreads();
+
+    // copy from shared memory to global memory
+    if(threadIdx.x < lenPerBlock)
+        x[(blockOffset + threadIdx.x) * ews] = shmem[threadIdx.x];
+}
+
+template <typename T>
+static __global__ void mergeShuffleCuda(sd::graph::RandomGenerator* rng, void* vx, const Nd4jLong ews, const Nd4jLong len, const int power, const Nd4jLong iterNum) {
+
+    T* x = reinterpret_cast<T*>(vx);
+
+    __shared__ Nd4jLong ind, blockOffset, factor, beg, mid, totLen, iterExp;
+
+    // *** apply mergeShuffle algorithm
+    if(threadIdx.x == 0) {
+
+        factor = blockIdx.x << iterNum;
+        iterExp = 1 << (iterNum - 1);
+        blockOffset = (len * factor) >> power;
+        mid = ((len * (factor + iterExp)) >> power) - blockOffset;      // middle
+        totLen = ((len * (factor + 2*iterExp)) >> power) - blockOffset;
+        ind = iterNum * len + blockOffset;
+        beg = 0;                                                        // beginning
+
+        // printf("m %lld, blockIdx.x %lld, factor %lld, blockOffset %lld, mid %lld, totLen %lld \n", m,k,factor,blockOffset,mid,totLen);
+
+        while (true) {
+            if(rng->relativeLong(ind++) % 2) {
+                if(mid == totLen)
+                    break;
+                math::nd4j_swap(x[(blockOffset + beg) * ews], x[(blockOffset + mid++) * ews]);
+            } else {
+                if(beg == mid)
+                    break;
+            }
+            ++beg;
+        }
+
+        // Fisher-Yates
+        while (beg < totLen) {
+            const Nd4jLong e = rng->relativeLong(ind++) % (beg + 1);
+            if(beg != e)
+                math::nd4j_swap(x[(blockOffset + beg) * ews], x[(blockOffset + e) * ews]);
+            ++beg;
+        }
+    }
+}
+
+//////////////////////////////////////////////////////////////////////////
+// Fisher-Yates shuffle
+template <typename T>
+static void fisherYates(sd::graph::RandomGenerator& rng, T* buff, const Nd4jLong& len, const Nd4jLong& ews, Nd4jLong ind) {
+
+    for(Nd4jLong i = len-1; i > 0; --i) {
+        const Nd4jLong j = rng.relativeLong(ind++) % (i + 1);
+        if(i != j)
+            math::nd4j_swap(buff[i*ews], buff[j*ews]);
+    }
+}
+
+//////////////////////////////////////////////////////////////////////////
+template <typename T>
+static void randomShuffle_(sd::LaunchContext* context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {
+
+    const int firstDim = input.sizeAt(0);
+    int temp;
+
+    if(input.lengthOf() == 1 || firstDim == 1) {
+
+        if(!isInplace)
+            output.assign(input);
+    }
+    else if (shape::isCommonVector(input.shapeInfo(), temp)) {
+
+        NDArray* arr = &input;
+
+        if (!isInplace) {
+            output.assign(input);
+            arr = &output;
+        }
+
+        const Nd4jLong len = arr->lengthOf();
+
+        const int threadsPerBlock = MAX_NUM_THREADS;
+
+        int power = 0;
+        while ((len >> power) > threadsPerBlock)
+            ++power;
+
+        const int blocksPerGrid = 1 << power;
+        const int sharedMem = threadsPerBlock * input.sizeOfT() + 256;
+
+        PointersManager manager(context, "NDArray::randomShuffle cuda");
+
+        sd::graph::RandomGenerator* pRng = reinterpret_cast<sd::graph::RandomGenerator*>(manager.replicatePointer(&rng, sizeof(sd::graph::RandomGenerator)));
+
+        NDArray::prepareSpecialUse({arr}, {arr});
+        fisherYatesCuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *context->getCudaStream()>>>(pRng, arr->specialBuffer(), arr->ews(), len, power);
+        for (Nd4jLong j = 1, i = 1; j < blocksPerGrid; j += j, ++i)
+            mergeShuffleCuda<T><<<blocksPerGrid/(2*j), threadsPerBlock, 256, *context->getCudaStream()>>>(pRng, arr->specialBuffer(), arr->ews(), len, power, i);
+        NDArray::registerSpecialUse({arr}, {arr});
+
+        manager.synchronize();
+
+        rng.rewindH((len + 1) * power);
+    }
+    else {
+
+        auto dimsToExclude = ShapeUtils::evalDimsToExclude(input.rankOf(), {0});
+
+        if(isInplace) {
+
+            auto subArrsList = input.allTensorsAlongDimension(dimsToExclude);
+
+            // Fisher-Yates shuffle
+            for(int i = firstDim - 1; i > 0; --i) {
+                const int j = rng.relativeInt(i) % (i + 1);
+                if(i != j)
+                    subArrsList.at(i)->swapUnsafe(*subArrsList.at(j));
+            }
+        }
+        else {
+
+            auto subArrsListIn  = input.allTensorsAlongDimension(dimsToExclude);
+            auto subArrsListOut = output.allTensorsAlongDimension(dimsToExclude);
+
+            std::vector<int> indices(firstDim);
+            std::iota(indices.begin(), indices.end(), 0); // 0,1,2,3, ... firstDim-1
+
+            // shuffle indices
+            fisherYates<int>(rng, indices.data(), firstDim, 1, 0);
+
+            auto func = PRAGMA_THREADS_FOR {
+                for (auto i = start; i < stop; ++i)
+                    subArrsListOut.at(i)->assign(subArrsListIn.at(indices[i]));
+            };
+
+            samediff::Threads::parallel_for(func, 0, firstDim);
+        }
+
+        rng.rewindH(firstDim-1);
+    }
+}
+
+/////////////////////////////////////////////////////////////////////////
+void randomShuffle(sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {
+    BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (context, input, output, rng, isInplace), LIBND4J_TYPES);
+}
+
+// BUILD_SINGLE_TEMPLATE(template void randomShuffle_, (sd::LaunchContext* context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace), LIBND4J_TYPES);
+
+}
+}
+}
\ No newline at end of file
diff --git a/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu b/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu
index 8d7f700dd..80e0e0858 100644
--- a/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu
+++ b/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu
@@ -300,129 +300,6 @@ void tileBP(sd::LaunchContext * context, const NDArray& gradO /*input*/, NDArray& gradI /*output*/, const std::vector<Nd4jLong> reps) {
     manager.synchronize();
 }
 
-    template <typename T>
-    static __global__ void swapShuffleKernel(T* input, Nd4jLong const* shape, Nd4jLong firstDim, sd::graph::RandomGenerator* rng) {
-        auto tid = blockIdx.x * blockDim.x;
-        auto step = blockDim.x * gridDim.x;
-
-        for (int i = firstDim - 1 - tid - threadIdx.x; i > 0; i -= step) {
-            int r = rng->relativeInt(i) % i;
-            if (i != r) {
-                const auto iOffset = shape::getIndexOffset(i, shape);
-                const auto rOffset = shape::getIndexOffset(r, shape);
-                T e0 = input[iOffset];
-                T e1 = input[rOffset];
-                //math::nd4j_swap<T>(input(i), input(r));
-                input[iOffset] = e1;
-                input[rOffset] = e0;
-            }
-        }
-    }
-    template <typename T>
-    static __global__ void fillShuffleKernel(T* input, Nd4jLong const* inputShape, T* output, Nd4jLong const* outputShape, Nd4jLong firstDim, int* indices, sd::graph::RandomGenerator* rng) {
-
-//        PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance().tadThreshold())
-        auto tid = blockIdx.x * blockDim.x;
-        auto step = blockDim.x * gridDim.x;
-
-        for(int i = firstDim - 1 - tid - threadIdx.x; i > 0; i -= step) {
-            int r = rng->relativeInt(i) % i;
-            output[shape::getIndexOffset(i, outputShape)] = input[shape::getIndexOffset(indices[r], inputShape)];
-            if(i != r) {
-                output[shape::getIndexOffset(r, outputShape)] = input[shape::getIndexOffset(indices[i], inputShape)];
-//                output.p(r, input.e<T>(indices[i]));
-//                math::nd4j_swap<int>(indices[i], indices[r]);
-                atomicExch(&indices[i], indices[r]);
-            }
-        }
-    }
-    //////////////////////////////////////////////////////////////////////////
-    template <typename T>
-    void randomShuffle_(sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {
-
-        // check edge cases first
-        int temp;
-        const int firstDim = input.sizeAt(0);
-        auto stream = context->getCudaStream();
-        NDArray::prepareSpecialUse({&output}, {&input});
-        if(input.lengthOf() == 1 || firstDim == 1) {
-            if(!isInplace)
-                output.assign(input);
-        }
-        else if (input.isVector() || shape::isLikeVector(input.shapeInfo(), temp)) {
-
-            // apply Fisher-Yates shuffle
-            sd::graph::RandomGenerator* dRandom = nullptr;
-            cudaMalloc(&dRandom, sizeof(sd::graph::RandomGenerator));
-            cudaMemcpy(dRandom, &rng, sizeof(sd::graph::RandomGenerator), cudaMemcpyHostToDevice);
-            T* inputBuf = reinterpret_cast<T*>(input.specialBuffer());
-            if(isInplace) {
-                swapShuffleKernel<T><<<128, 256, 1024, *stream>>>(inputBuf, input.specialShapeInfo(), firstDim, dRandom);
-            }
-            else {
-                std::vector<int> indices(firstDim);
-                std::iota(indices.begin(), indices.end(), 0);
-                cudaMemcpy(output.specialBuffer(), input.specialBuffer(), sizeof(T), cudaMemcpyDeviceToDevice);
-                //output.p<T>(Nd4jLong(0), input.e<T>(0));
-                PointersManager pointersManager(context, "helper::randomShuffle_");
-                int* indicesDev = reinterpret_cast<int*>(pointersManager.replicatePointer(indices.data(), indices.size() * sizeof(int)));
-                T* outputBuf = reinterpret_cast<T*>(output.specialBuffer());
-                fillShuffleKernel<T><<<128, 256, 1024, *stream>>>(inputBuf, input.specialShapeInfo(), outputBuf, output.specialShapeInfo(), firstDim, indicesDev, dRandom);
-                pointersManager.synchronize();
-            }
-//            rng.rewindH(firstDim - 1);
-            cudaFree(dRandom);
-        }
-        else {
-
-            // evaluate sub-arrays list of input array through all dimensions excluding first one
-            std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input.rankOf(), {0});
-            auto subArrsListIn = input.allTensorsAlongDimension(dimensions);
-
-            // apply Fisher-Yates shuffle
-            if(isInplace) {
-                for(int i = firstDim - 1; i > 0; --i) {
-                    int r = rng.relativeInt(i) % i;
-
-                    if(i != r)
-                        subArrsListIn.at(i)->swapUnsafe(*subArrsListIn.at(r));
-                }
-            }
-            else {
-                // evaluate sub-arrays list of output array through all dimensions excluding first one
-                auto subArrsListOut = output.allTensorsAlongDimension(dimensions);
-                std::vector<int> indices(firstDim);
-                std::iota(indices.begin(), indices.end(), 0);
-                bool isZeroShuffled = false;
-
-                for(int i = firstDim - 1; i > 0; --i) {
-                    int r = rng.relativeInt(i) % i;
-                    subArrsListOut.at(i)->assign(subArrsListIn.at(indices[r]));
-                    if(r == 0)
-                        isZeroShuffled = true;
-
-                    if(i != r) {
-                        subArrsListOut.at(r)->assign(subArrsListIn.at(indices[i]));
-                        math::nd4j_swap<int>(indices[i], indices[r]);
-                    }
-                }
-                if(!isZeroShuffled)
-                    subArrsListOut.at(0)->assign(subArrsListIn.at(0));
-            }
-            rng.rewindH(firstDim-1);
-        }
-        NDArray::registerSpecialUse({&output}, {&input});
-
-    }
-
-    void randomShuffle(sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {
-        BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (context, input, output, rng, isInplace), LIBND4J_TYPES);
-    }
-
-    BUILD_SINGLE_TEMPLATE(template void randomShuffle_, (sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace), LIBND4J_TYPES);
-
-
     //////////////////////////////////////////////////////////////////////////
     void eye(sd::LaunchContext * context, NDArray& output) {
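Before the ARM Compute glue that follows, a quick orientation on layout: an NDArray in 'c' order lists shape and strides from the slowest- to the fastest-varying axis in elements, while arm_compute::TensorShape/Strides index dimensions from fastest to slowest with strides in bytes. The helpers below therefore reverse the axis order and scale by element size. A toy version of that mapping (the function name and raw-pointer signature are illustrative; the real helper also pads unused dimensions with 1):

#include <cstddef>
#include <cstdint>

// NDArray (C order):  shape [2, 3, 4], element strides [12, 4, 1], float32
// arm_compute view:   shape {4, 3, 2}, byte strides {4, 16, 48}
void reverseLayout(int rank, const int64_t* ndShape, const int64_t* ndStrides,
                   size_t elementSize, size_t* armShape, size_t* armStrides) {
    for (int i = 0, j = rank - 1; i < rank; i++, j--) {
        armShape[i]   = static_cast<size_t>(ndShape[j]);
        armStrides[i] = static_cast<size_t>(ndStrides[j]) * elementSize; // bytes
    }
}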
diff --git a/libnd4j/include/ops/declarable/platform/armcompute/armcomputeUtils.cpp b/libnd4j/include/ops/declarable/platform/armcompute/armcomputeUtils.cpp
new file mode 100644
index 000000000..66b472252
--- /dev/null
+++ b/libnd4j/include/ops/declarable/platform/armcompute/armcomputeUtils.cpp
@@ -0,0 +1,278 @@
+/*******************************************************************************
+ * Copyright (c) 2019 Konduit K.K.
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+ // Created by Abdelrauf 2020
+
+
+#include <cstdint>
+#include <cstring>
+#include <iostream>
+#include <memory>
+#include <stdexcept>
+#include <vector>
+
+#include "armcomputeUtils.h"
+
+
+namespace sd {
+namespace ops {
+namespace platforms {
+
+
+
+Arm_DataType getArmType ( const DataType &dType){
+    Arm_DataType ret;
+    switch (dType){
+        case HALF :
+            ret = Arm_DataType::F16;
+            break;
+        case FLOAT32 :
+            ret = Arm_DataType::F32;
+            break;
+        case DOUBLE :
+            ret = Arm_DataType::F64;
+            break;
+        case INT8 :
+            ret = Arm_DataType::S8;
+            break;
+        case INT16 :
+            ret = Arm_DataType::S16;
+            break;
+        case INT32 :
+            ret = Arm_DataType::S32;
+            break;
+        case INT64 :
+            ret = Arm_DataType::S64;
+            break;
+        case UINT8 :
+            ret = Arm_DataType::U8;
+            break;
+        case UINT16 :
+            ret = Arm_DataType::U16;
+            break;
+        case UINT32 :
+            ret = Arm_DataType::U32;
+            break;
+        case UINT64 :
+            ret = Arm_DataType::U64;
+            break;
+        case BFLOAT16 :
+            ret = Arm_DataType::BFLOAT16;
+            break;
+        default:
+            ret = Arm_DataType::UNKNOWN;
+    };
+
+    return ret;
+}
+
+bool isArmcomputeFriendly(const NDArray& arr) {
+    auto dType = getArmType(arr.dataType());
+    int rank = (int)(arr.rankOf());
+    return dType != Arm_DataType::UNKNOWN &&
+           rank<=arm_compute::MAX_DIMS &&
+           arr.ordering() == 'c' &&
+           arr.ews()==1 &&
+           shape::strideDescendingCAscendingF(arr.shapeInfo()) == true;
+}
+
+Arm_TensorInfo getArmTensorInfo(int rank, Nd4jLong* bases, sd::DataType ndArrayType, arm_compute::DataLayout layout) {
+    constexpr int numChannels = 1;
+    auto dType = getArmType(ndArrayType);
+
+    Arm_TensorShape shape;
+    shape.set_num_dimensions(rank);
+    for (int i = 0, j = rank - 1; i < rank; i++, j--) {
+        shape[i] = static_cast<size_t>(bases[j]);
+    }
+    // fill the rest unused with 1
+    for (int i = rank; i < arm_compute::MAX_DIMS; i++) {
+        shape[i] = 1;
+    }
+
+    return Arm_TensorInfo(shape, numChannels, dType, layout);
+}
+
+Arm_TensorInfo getArmTensorInfo(const NDArray& arr,
+                                arm_compute::DataLayout layout) {
+    auto dType = getArmType(arr.dataType());
+
+    //
+    constexpr int numChannels = 1;
+    int rank = (int)(arr.rankOf());
+    auto bases = arr.shapeOf();
+    auto arrStrides = arr.stridesOf();
+
+    // https://arm-software.github.io/ComputeLibrary/v20.05/_dimensions_8h_source.xhtml
+    // note: under the hood it is stored as std::array<T, num_max_dimensions> _id;
+    // TensorShape is derived from Dimensions
+    // as well as Strides : public Dimensions
+    Arm_TensorShape shape;
+    Arm_Strides strides;
+    shape.set_num_dimensions(rank);
+    strides.set_num_dimensions(rank);
+    size_t element_size = arm_compute::data_size_from_type(dType);
+    for (int i = 0, j = rank - 1; i < rank; i++, j--) {
+        shape[i] = static_cast<size_t>(bases[j]);
+        strides[i] = static_cast<size_t>(arrStrides[j]) * element_size;
+    }
+    // fill the rest unused with 1
+    for (int i = rank; i < arm_compute::MAX_DIMS; i++) {
+        shape[i] = 1;
+    }
+    size_t total_size;
+    size_t size_ind = rank - 1;
+    total_size = shape[size_ind] * strides[size_ind];
+
+    Arm_TensorInfo info;
+    info.init(shape, numChannels, dType, strides, 0, total_size);
+    info.set_data_layout(layout);
+
+    return info;
+}
+
+Arm_Tensor getArmTensor(const NDArray& arr, arm_compute::DataLayout layout) {
+    // - Ownership of the backing memory is not transferred to the tensor itself.
+    // - The tensor mustn't be memory managed.
+    // - Padding requirements should be accounted by the client code.
+    //   In other words, if padding is required by the tensor after the function
+    //   configuration step, then the imported backing memory should account for it.
+    //   Padding can be checked through the TensorInfo::padding() interface.
+
+    // Import existing pointer as backing memory
+    auto info = getArmTensorInfo(arr, layout);
+    Arm_Tensor tensor;
+    tensor.allocator()->init(info);
+    void* buff = (void*)arr.buffer();
+    tensor.allocator()->import_memory(buff);
+    return tensor;
+}
+
+void copyFromTensor(const Arm_Tensor& inTensor, NDArray& output) {
+    //only for C order
+    if (output.ordering() != 'c') return;
+    auto shapeInfo = output.shapeInfo();
+    auto bases = &(shapeInfo[1]);
+    Nd4jLong rank = shapeInfo[0];
+    auto strides = output.stridesOf();
+    int width = bases[rank - 1];
+    uint8_t* outputBuffer = (uint8_t*)output.buffer();
+    size_t offset = 0;
+    arm_compute::Window window;
+    arm_compute::Iterator tensor_it(&inTensor, window);
+
+    int element_size = inTensor.info()->element_size();
+    window.use_tensor_dimensions(inTensor.info()->tensor_shape(), /* first_dimension =*/arm_compute::Window::DimY);
+
+//  if (output.ews() == 1) {
+        auto copySize = width * element_size;
+        auto dest = outputBuffer;
+        arm_compute::execute_window_loop(window, [&](const arm_compute::Coordinates& id)
+            {
+                auto src = tensor_it.ptr();
+                memcpy(dest, src, copySize);
+                dest += copySize;
+            },
+            tensor_it);
+//  }
+//  else {
+//      Nd4jLong coords[MAX_RANK] = {};
+//      if(strides[rank-1]!=1){
+//          throw std::runtime_error( "not implemented for subarrays whose last stride is not 1");
+//          //TODO: implement to work with all subarrays properly
+//      }
+//      arm_compute::execute_window_loop(window, [&](const arm_compute::Coordinates& id)
+//          {
+//              auto src = tensor_it.ptr();
+//              auto dest = outputBuffer + offset * element_size;
+//              memcpy(dest, src, width * element_size);
+//              offset = sd::inc_coords(bases, strides, coords, offset, rank, 1);
+//          },
+//          tensor_it);
+//  }
+}
+
+void copyToTensor(const NDArray& input, Arm_Tensor& outTensor) {
+    //only for C order
+    if (input.ordering() != 'c') return;
+    auto shapeInfo = input.shapeInfo();
+    auto bases = &(shapeInfo[1]);
+    Nd4jLong rank = shapeInfo[0];
+    auto strides = input.stridesOf();
+    uint8_t *inputBuffer = (uint8_t*)input.buffer();
+    int width = bases[rank - 1];
+    size_t offset = 0;
+    arm_compute::Window window;
+    arm_compute::Iterator tensor_it(&outTensor, window);
+    int element_size = outTensor.info()->element_size();
+
+    window.use_tensor_dimensions(outTensor.info()->tensor_shape(), /* first_dimension =*/arm_compute::Window::DimY);
+
+//  if (input.ews() == 1) {
+        auto copySize = width * element_size;
+        auto src = inputBuffer;
+        arm_compute::execute_window_loop(window, [&](const arm_compute::Coordinates& id)
+            {
+                auto dest = tensor_it.ptr();
+                memcpy(dest, src, copySize);
+                src += copySize;
+            },
+            tensor_it);
+//  }
+//  else {
+//      Nd4jLong coords[MAX_RANK] = {};
+//      if(strides[rank-1]!=1){
+//          throw std::runtime_error( "not implemented for subarrays whose last stride is not 1");
+//          //TODO: implement to work with all subarrays properly
+//      }
+//      arm_compute::execute_window_loop(window, [&](const arm_compute::Coordinates& id)
+//          {
+//              auto dest = tensor_it.ptr();
+//              auto src = inputBuffer + offset * element_size;
+//              offset = sd::inc_coords(bases, strides, coords, offset, rank, 1);
+//          },
+//          tensor_it);
+//  }
+}
+
+
+// armcompute should be built with debug option
+void print_tensor(Arm_ITensor& tensor, const char* msg) {
+    auto info = tensor.info();
+    auto padding = info->padding();
+    std::cout << msg << "\ntotal: " << info->total_size() << "\n";
+
+    for (int i = 0; i < arm_compute::MAX_DIMS; i++) {
+        std::cout << info->dimension(i) << ",";
+    }
+    std::cout << std::endl;
+    for (int i = 0; i < arm_compute::MAX_DIMS; i++) {
+        std::cout << info->strides_in_bytes()[i] << ",";
+    }
+    std::cout << "\npadding: l " << padding.left << ", r " << padding.right
+              << ", t " << padding.top << ", b " << padding.bottom << std::endl;
+
+#ifdef ARM_COMPUTE_ASSERTS_ENABLED
+    //note: it did not print correctly for NHWC
+    std::cout << msg << ":\n";
+    tensor.print(std::cout);
+    std::cout << std::endl;
+#endif
+}
+
+}
+}
+}
diff --git a/libnd4j/include/ops/declarable/platform/armcompute/armcomputeUtils.h b/libnd4j/include/ops/declarable/platform/armcompute/armcomputeUtils.h
new file mode 100644
index 000000000..72a4e6e89
--- /dev/null
+++ b/libnd4j/include/ops/declarable/platform/armcompute/armcomputeUtils.h
@@ -0,0 +1,133 @@
+/*******************************************************************************
+ * Copyright (c) 2019 Konduit K.K.
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+
+#ifndef DEV_TESTSARMCOMPUTEUTILS_H
+#define DEV_TESTSARMCOMPUTEUTILS_H
+
+
+#include <legacy/NativeOps.h>
+#include <array/NDArray.h>
+#include <graph/Context.h>
+#include <ops/declarable/PlatformHelper.h>
+#include <ops/declarable/OpRegistrator.h>
+#include <system/platform_boilerplate.h>
+#include <execution/Threads.h>
+
+#include <arm_compute/core/Types.h>
+#include <arm_compute/core/TensorInfo.h>
+#include <arm_compute/core/TensorShape.h>
+#include <arm_compute/core/Strides.h>
+#include <arm_compute/core/Helpers.h>
+#include <arm_compute/core/ITensor.h>
+#include <arm_compute/core/Window.h>
+#include <arm_compute/core/Validate.h>
+#include <arm_compute/core/utils/misc/ShapeCalculator.h>
+#include <arm_compute/runtime/Tensor.h>
+#include <arm_compute/runtime/NEON/NEFunctions.h>
+
+using namespace samediff;
+
+
+namespace sd {
+    namespace ops {
+        namespace platforms {
+
+            using Arm_DataType = arm_compute::DataType;
+            using Arm_Tensor = arm_compute::Tensor;
+            using Arm_ITensor = arm_compute::ITensor;
+            using Arm_TensorInfo = arm_compute::TensorInfo;
+            using Arm_TensorShape = arm_compute::TensorShape;
+            using Arm_Strides = arm_compute::Strides;
+            /**
+             * Here we actually declare our platform helpers
+             */
+
+
+            DECLARE_PLATFORM(maxpool2d, ENGINE_CPU);
+
+            DECLARE_PLATFORM(avgpool2d, ENGINE_CPU);
+
+            //utils
+            Arm_DataType getArmType(const sd::DataType& dType);
+
+            Arm_TensorInfo getArmTensorInfo(int rank, Nd4jLong* bases, sd::DataType ndArrayType, arm_compute::DataLayout layout = arm_compute::DataLayout::UNKNOWN);
+
+            Arm_TensorInfo getArmTensorInfo(const NDArray& arr, arm_compute::DataLayout layout = arm_compute::DataLayout::UNKNOWN);
+
+            Arm_Tensor getArmTensor(const NDArray& arr, arm_compute::DataLayout layout = arm_compute::DataLayout::UNKNOWN);
+
+            void copyFromTensor(const Arm_Tensor& inTensor, NDArray& output);
+            void copyToTensor(const NDArray& input, Arm_Tensor& outTensor);
+            void print_tensor(Arm_ITensor& tensor, const char* msg);
+            bool isArmcomputeFriendly(const NDArray& arr);
+
+
+            template<typename F>
+            class ArmFunction {
+               public:
+
+                template<typename ...Args>
+                void configure(NDArray *input , NDArray *output, arm_compute::DataLayout layout, Args&& ...args) {
+
+                    auto inInfo = getArmTensorInfo(*input, layout);
+                    auto outInfo = getArmTensorInfo(*output, layout);
+                    in.allocator()->init(inInfo);
+                    out.allocator()->init(outInfo);
+                    armFunction.configure(&in, &out, std::forward<Args>(args)...);
+                    if (in.info()->has_padding()) {
+                        //allocate and copy
+                        in.allocator()->allocate();
+                        //copy
+                        copyToTensor(*input, in);
+
+                    }
+                    else {
+                        //import buffer
+                        void* buff = input->buffer();
+                        in.allocator()->import_memory(buff);
+                    }
+                    if (out.info()->has_padding()) {
+                        //store pointer to our array to copy after run
+                        out.allocator()->allocate();
+                        outNd = output;
+                    }
+                    else {
+                        //import
+                        void* buff = output->buffer();
+                        out.allocator()->import_memory(buff);
+                    }
+
+                }
+
+                void run() {
+                    armFunction.run();
+                    if (outNd) {
+                        copyFromTensor(out, *outNd);
+                    }
+                }
+
+               private:
+                Arm_Tensor in;
+                Arm_Tensor out;
+                NDArray *outNd = nullptr;
+                F armFunction{};
+            };
+        }
+    }
+}
+
+
+
+#endif //DEV_TESTSARMCOMPUTEUTILS_H
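The ArmFunction wrapper above reduces every NEON-backed op to a configure/run pair: buffers are imported zero-copy via import_memory when the configured tensor needs no padding, and staged through copyToTensor/copyFromTensor otherwise. A minimal usage sketch for the pooling case, mirroring what the two implementations below actually do (input, output, dataLayout and poolInfo come from the surrounding op; only the wrapper calls are shown):

// hedged sketch, not a definitive API beyond what the files below use
ArmFunction<arm_compute::NEPoolingLayer> pool;
pool.configure(input, output, dataLayout, poolInfo); // wraps/copies buffers as needed
pool.run();                                          // executes, copies back if out was padded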
diff --git a/libnd4j/include/ops/declarable/platform/armcompute/avgpooling2d.cpp b/libnd4j/include/ops/declarable/platform/armcompute/avgpooling2d.cpp
new file mode 100644
index 000000000..d8413104d
--- /dev/null
+++ b/libnd4j/include/ops/declarable/platform/armcompute/avgpooling2d.cpp
@@ -0,0 +1,106 @@
+/*******************************************************************************
+ * Copyright (c) 2019 Konduit K.K.
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+ // Created by Abdelrauf (rauf@konduit.ai) 2020
+
+#include <ops/declarable/PlatformHelper.h>
+#include <ops/declarable/OpRegistrator.h>
+#include <system/platform_boilerplate.h>
+#include <ops/declarable/helpers/convolutions.h>
+
+
+#include "armcomputeUtils.h"
+
+
+namespace sd {
+namespace ops {
+namespace platforms {
+
+
+//////////////////////////////////////////////////////////////////////////
+PLATFORM_IMPL(avgpool2d, ENGINE_CPU) {
+
+    auto input = INPUT_VARIABLE(0);
+    auto output = OUTPUT_VARIABLE(0);
+
+    // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
+
+    const auto kH = INT_ARG(0);
+    const auto kW = INT_ARG(1);
+    const auto sH = INT_ARG(2);
+    const auto sW = INT_ARG(3);
+    auto pH = INT_ARG(4);
+    auto pW = INT_ARG(5);
+    const auto dH = INT_ARG(6);
+    const auto dW = INT_ARG(7);
+    const auto paddingMode = INT_ARG(8);
+    const auto extraParam0 = INT_ARG(9);
+    const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1;       // INT_ARG(10): 0-NCHW, 1-NHWC
+
+    REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D ARMCOMPUTE op: input should have rank of 4, but got %i instead", input->rankOf());
+    REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D ARMCOMPUTE op: dilation must not be zero, but got instead {%i, %i}", dH, dW);
+
+    bool exclude_padding= (extraParam0 == 0) ? true : false;
+
+    auto dataLayout = isNCHW ? arm_compute::DataLayout::NCHW : arm_compute::DataLayout::NHWC;
+
+    // Calculate individual paddings
+    unsigned int pad_left, pad_top, pad_right, pad_bottom;
+    int bS, iC, iH, iW, oC, oH, oW;                     // batch size, input channels, input height/width, output channels, output height/width;
+    int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;   // corresponding indexes
+    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
+
+    if(paddingMode){
+        ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
+    }
+    pad_left   = pW;
+    pad_top    = pH;
+    pad_right  = (oW - 1) * sW - iW + kW - pW ;
+    pad_bottom = (oH - 1) * sH - iH + kH - pH ;
+
+#if 0
+    nd4j_printf("avgpool kH = %d, kW = %d, sH = %d, sW = %d , pH = %d , pW = %d, dH = %d, dW = %d, paddingMode = %d , isNCHW %d exclude pad %d \n" , kH , kW , sH , sW , pH
+     , pW , dH , dW , paddingMode,isNCHW?1:0 ,exclude_padding?1:0);
+#endif
+    auto poolPad = arm_compute::PadStrideInfo(sW, sH, pad_left,pad_right, pad_top, pad_bottom, arm_compute::DimensionRoundingType::FLOOR);
+    auto poolInfo = arm_compute::PoolingLayerInfo(arm_compute::PoolingType::AVG, arm_compute::Size2D(kW, kH), dataLayout, poolPad, exclude_padding);
+    ArmFunction<arm_compute::NEPoolingLayer> pool;
+    pool.configure(input,output, dataLayout, poolInfo);
+
+    pool.run(); // run function
+
+    return Status::OK();
+}
+
+//////////////////////////////////////////////////////////////////////////
+PLATFORM_CHECK(avgpool2d, ENGINE_CPU) {
+    auto input = INPUT_VARIABLE(0);
+    auto output = OUTPUT_VARIABLE(0);
+    const int dH = INT_ARG(6);
+    const int dW = INT_ARG(7);
+    // Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32
+    auto dTypeInput = getArmType(input->dataType());
+    auto dTypeOutput = getArmType(output->dataType());
+    bool is_supported = dH==1 && dW==1 && isArmcomputeFriendly(*input) && isArmcomputeFriendly(*output)
+                        && (dTypeInput ==Arm_DataType::F32)
+                        && (dTypeOutput ==Arm_DataType::F32);
+    return is_supported;
+}
+
+
+
+}
+}
+}
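The asymmetric padding computed above is easiest to see with numbers from the SAME-mode merge test earlier in this patch (iH = iW = 32, kernel 2x2, stride 1, dilation 1): calcPadding2D yields oH = oW = 32 with the left/top share pH = pW = 0 (it halves the required total, rounding down), so the remainder lands on the right/bottom edge. In isolation:

// SAME example: iW = 32, kW = 2, sW = 1  ->  oW = 32
// total cells to pad:  (oW - 1) * sW - iW + kW          = 31 - 32 + 2     = 1
// pad_left  = pW                                        = 0
// pad_right = (oW - 1) * sW - iW + kW - pW              = 1
// and identically pad_top = 0, pad_bottom = 1 for the vertical axis

This matches TF-style SAME semantics, where the odd padding cell goes to the right/bottom.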
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + // Created by Abdelrauf 2020 + + +#include +#include +#include +#include + + +#include "armcomputeUtils.h" + + +namespace sd { +namespace ops { +namespace platforms { + + +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(maxpool2d, ENGINE_CPU) { + + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D ARMCOMPUTE OP: input array should have rank of 4, but got %i instead", input->rankOf()); + + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + const int kH = INT_ARG(0); + const int kW = INT_ARG(1); + const int sH = INT_ARG(2); + const int sW = INT_ARG(3); + int pH = INT_ARG(4); + int pW = INT_ARG(5); + const int dH = INT_ARG(6); + const int dW = INT_ARG(7); + const int paddingMode = INT_ARG(8); + // const int extraParam0 = INT_ARG(9); + const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 1-NHWC, 0-NCHW + + REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D ARMCOMPUTE op: dilation must not be zero, but got instead {%i, %i}", dH, dW); + + auto dataLayout = isNCHW ? arm_compute::DataLayout::NCHW : arm_compute::DataLayout::NHWC; + + // Calculate individual paddings + unsigned int pad_left, pad_top, pad_right, pad_bottom; + int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + + if(paddingMode){ + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); + } + pad_left = pW; + pad_top = pH; + pad_right = (oW - 1) * sW - iW + kW - pW ; + pad_bottom = (oH - 1) * sH - iH + kH - pH ; +#if 0 + nd4j_printf("maxpool kH = %d, kW = %d, sH = %d, sW = %d , pH = %d , pW = %d, dH = %d, dW = %d, paddingMode = %d , isNCHW %d \n" , kH , kW , sH , sW , pH + , pW , dH , dW , paddingMode,isNCHW?1:0); +#endif + + auto poolPad = arm_compute::PadStrideInfo(sW, sH, pad_left,pad_right, pad_top, pad_bottom, arm_compute::DimensionRoundingType::FLOOR); + auto poolInfo = arm_compute::PoolingLayerInfo(arm_compute::PoolingType::MAX, arm_compute::Size2D(kW, kH), dataLayout, poolPad); + ArmFunction pool; + + pool.configure(input,output, dataLayout, poolInfo); + + pool.run(); // run function + + return Status::OK(); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_CHECK(maxpool2d, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + const int dH = INT_ARG(6); + const int dW = INT_ARG(7); + // Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32 + auto dTypeInput = getArmType(input->dataType()); + auto dTypeOutput = getArmType(output->dataType()); + bool is_supported = dH==1 && dW==1 && isArmcomputeFriendly(*input) && isArmcomputeFriendly(*output) + && (dTypeInput ==Arm_DataType::F32) + && (dTypeOutput ==Arm_DataType::F32); + return is_supported; +} + + + +} +} +} diff --git a/libnd4j/include/ops/ops.h b/libnd4j/include/ops/ops.h index ea52e9ba0..aca6fec6f 100644 --- a/libnd4j/include/ops/ops.h +++ b/libnd4j/include/ops/ops.h @@ -3963,9 +3963,6 @@ namespace
simdOps { } #endif -#ifndef __clang__ -#pragma omp declare simd uniform(extraParamsRef) -#endif op_def static Y merge(X old, X opOutput, X *extraParamsRef) { return update(old, opOutput, extraParamsRef); } diff --git a/libnd4j/pi_build.sh b/libnd4j/pi_build.sh new file mode 100755 index 000000000..f96c3f1f1 --- /dev/null +++ b/libnd4j/pi_build.sh @@ -0,0 +1,185 @@ +#!/bin/bash +TARGET=armv7-a +BLAS_TARGET_NAME=ARMV7 +ARMCOMPUTE_TARGET=armv7a +#BASE_DIR=${HOME}/pi +#https://stackoverflow.com/questions/59895/how-to-get-the-source-directory-of-a-bash-script-from-within-the-script-itself +SOURCE="${BASH_SOURCE[0]}" +ARMCOMPUTE_DEBUG=1 +LIBND4J_BUILD_MODE=Release +while [ -h "$SOURCE" ]; do # resolve $SOURCE until the file is no longer a symlink + DIR="$( cd -P "$( dirname "$SOURCE" )" >/dev/null 2>&1 && pwd )" + SOURCE="$(readlink "$SOURCE")" + [[ $SOURCE != /* ]] && SOURCE="$DIR/$SOURCE" # if $SOURCE was a relative symlink, we need to resolve it relative to the path where the symlink file was located +done +BASE_DIR="$( cd -P "$( dirname "$SOURCE" )" >/dev/null 2>&1 && pwd )" +CMAKE=cmake #/snap/bin/cmake + +mkdir -p ${BASE_DIR}/helper_bin/ + +CROSS_COMPILER_URL=https://sourceforge.net/projects/raspberry-pi-cross-compilers/files/Raspberry%20Pi%20GCC%20Cross-Compiler%20Toolchains/Buster/GCC%208.3.0/Raspberry%20Pi%203A%2B%2C%203B%2B%2C%204/cross-gcc-8.3.0-pi_3%2B.tar.gz/download +CROSS_COMPILER_DIR=${BASE_DIR}/helper_bin/cross_compiler + +SCONS_LOCAL_URL=http://prdownloads.sourceforge.net/scons/scons-local-3.1.1.tar.gz +SCONS_LOCAL_DIR=${BASE_DIR}/helper_bin/scons_local + +THIRD_PARTY=${BASE_DIR}/third_party_libs + +ARMCOMPUTE_GIT_URL=https://github.com/ARM-software/ComputeLibrary.git +ARMCOMPUTE_TAG=v20.05 +ARMCOMPUTE_DIR=${THIRD_PARTY}/arm_compute_dir + +OPENBLAS_GIT_URL="https://github.com/xianyi/OpenBLAS.git" +OPENBLAS_DIR=${THIRD_PARTY}/OpenBLAS + + +LIBND4J_SRC_DIR=${BASE_DIR} + +LIBND4J_BUILD_DIR=${BASE_DIR}/build_pi + +#for some downloads +XRTACT_STRIP="--strip-components=1" + +HAS_ARMCOMPUTE=1 +mkdir -p ${BASE_DIR} +mkdir -p ${THIRD_PARTY} + +#change directory to base +cd $BASE_DIR + +function message { + echo "BUILDER:::: ${@}" +} + + +function check_requirements { + for i in "${@}" + do + if [ ! -e "$i" ]; then + message "missing: ${i}" + exit -2 + fi + done +} + +function download_extract { + #$1 is url #2 is dir $3 is extract argument + if [ ! -f ${2}_file ]; then + message "download" + wget --quiet --show-progress -O ${2}_file ${1} + fi + + message "extract" + #extract + mkdir -p ${2} + command="tar -xzf ${2}_file --directory=${2} ${3} " + message $command + $command + + check_requirements "${2}" +} + +function git_check { + #$1 is url #$2 is dir #$3 is tag or branch if optional + command="git clone --quiet ${1} ${2}" + message "$command" + $command + if [ -n "$3" ]; then + cd ${2} + command="git checkout ${3}" + message "$command" + $command + cd ${BASE_DIR} + fi + check_requirements "${2}" +} + + +if [ ! 
-d ${CROSS_COMPILER_DIR} ]; then + #out file + message "download CROSS_COMPILER" + download_extract ${CROSS_COMPILER_URL} ${CROSS_COMPILER_DIR} ${XRTACT_STRIP} +fi + +#useful exports +export PI_FOLDER=${CROSS_COMPILER_DIR} +export RPI_BIN=${PI_FOLDER}/bin/arm-linux-gnueabihf +export PI_SYS_ROOT=${PI_FOLDER}/arm-linux-gnueabihf/libc +export LD_LIBRARY_PATH=${PI_FOLDER}/lib:$LD_LIBRARY_PATH +export CC=${RPI_BIN}-gcc +export FC=${RPI_BIN}-gfortran +export CXX=${RPI_BIN}-g++ +export CPP=${RPI_BIN}-cpp +export RANLIB=${RPI_BIN}-gcc-ranlib +export LD="${RPI_BIN}-ld" +export AR="${RPI_BIN}-ar" + + +#lets build OpenBlas +if [ ! -d "${OPENBLAS_DIR}" ]; then + message "download OpenBLAS" + git_check "${OPENBLAS_GIT_URL}" "${OPENBLAS_DIR}" +fi + +if [ ! -f "${THIRD_PARTY}/lib/libopenblas.so" ]; then + message "build and install OpenBLAS" + cd ${OPENBLAS_DIR} + + command="make TARGET=${BLAS_TARGET_NAME} HOSTCC=gcc CC=${CC} USE_THREAD=0 NOFORTRAN=1 CFLAGS=--sysroot=${PI_SYS_ROOT} LDFLAGS=\"-L${PI_SYS_ROOT}/../lib/ -lm\" &>/dev/null" + message $command + eval $command + message "install it" + command="make PREFIX=${THIRD_PARTY} install" + message $command + $command + cd $BASE_DIR + +fi +check_requirements ${THIRD_PARTY}/lib/libopenblas.so + + + +if [ ! -d ${SCONS_LOCAL_DIR} ]; then + #out file + message "download Scons local" + download_extract ${SCONS_LOCAL_URL} ${SCONS_LOCAL_DIR} +fi +check_requirements ${SCONS_LOCAL_DIR}/scons.py + + +if [ ! -d "${ARMCOMPUTE_DIR}" ]; then + message "download ArmCompute Source" + git_check ${ARMCOMPUTE_GIT_URL} "${ARMCOMPUTE_DIR}" "tags/${ARMCOMPUTE_TAG}" +fi + +#build armcompute +if [ ! -f "${ARMCOMPUTE_DIR}/build/libarm_compute-static.a" ]; then +message "build arm compute" +cd ${ARMCOMPUTE_DIR} +command="CC=gcc CXX=g++ python3 ${SCONS_LOCAL_DIR}/scons.py Werror=1 -j$(nproc) toolchain_prefix=${RPI_BIN}- debug=${ARMCOMPUTE_DEBUG} neon=1 opencl=0 extra_cxx_flags=-fPIC os=linux build=cross_compile arch=${ARMCOMPUTE_TARGET} &>/dev/null" +message $command +eval $command +cd ${BASE_DIR} +fi +check_requirements "${ARMCOMPUTE_DIR}/build/libarm_compute-static.a" "${ARMCOMPUTE_DIR}/build/libarm_compute_core-static.a" + + + +message "build cmake for LIBND4J. 
output: ${LIBND4J_BUILD_DIR}" + +TOOLCHAIN=${LIBND4J_SRC_DIR}/cmake/rpi.cmake +cmake_cmd="${CMAKE} -G \"Unix Makefiles\" -B${LIBND4J_BUILD_DIR} -S${LIBND4J_SRC_DIR} -DCMAKE_BUILD_TYPE=${LIBND4J_BUILD_MODE} -DCMAKE_TOOLCHAIN_FILE=${TOOLCHAIN} -DCMAKE_VERBOSE_MAKEFILE:BOOL=ON -DSD_ALL_OPS=true -DSD_CPU=true -DSD_LIBRARY_NAME=nd4jcpu -DSD_BUILD_TESTS=ON -DSD_ARM_BUILD=true -DOPENBLAS_PATH=${THIRD_PARTY} -DSD_ARCH=${TARGET} -DARMCOMPUTE_ROOT=${ARMCOMPUTE_DIR} -DHELPERS_armcompute=${HAS_ARMCOMPUTE}" +message $cmake_cmd +eval $cmake_cmd + +#build +message "lets build" + +cd ${LIBND4J_BUILD_DIR} +make -j $(nproc) + + + + + + diff --git a/libnd4j/tests_cpu/layers_tests/CMakeLists.txt b/libnd4j/tests_cpu/layers_tests/CMakeLists.txt index 563bf58f6..9478f6fe2 100644 --- a/libnd4j/tests_cpu/layers_tests/CMakeLists.txt +++ b/libnd4j/tests_cpu/layers_tests/CMakeLists.txt @@ -52,14 +52,19 @@ elseif(WIN32) set(CMAKE_CXX_FLAGS " -fPIC") endif() else() - set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3") set(CMAKE_CXX_FLAGS " -fPIC") + set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3") + IF(${SD_ARCH} MATCHES "arm*") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=${SD_ARCH}") + else() + set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3") + if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64*") set(CMAKE_CXX_FLAGS " ${CMAKE_CXX_FLAGS} -mcpu=native") else() set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native -mtune=native") endif() - + endif() if (SD_CPU AND SD_SANITIZE) set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -fsanitize=address") else() @@ -130,7 +135,7 @@ if (SD_CPU) endif() add_executable(runtests ${TEST_SOURCES}) - target_link_libraries(runtests samediff_obj ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES} ${CPU_FEATURES} gtest gtest_main) + target_link_libraries(runtests samediff_obj ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES} ${CPU_FEATURES} ${ARMCOMPUTE_LIBRARIES} gtest gtest_main) elseif(SD_CUDA) add_executable(runtests ${TEST_SOURCES}) diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp index 169c51124..39277cd87 100644 --- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp @@ -1113,7 +1113,10 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_6) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); auto z = result.at(0); - +#if 0 + exp.printIndexedBuffer("Expected"); + z->printIndexedBuffer("Z"); +#endif ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); @@ -1132,7 +1135,10 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_7) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); auto z = result.at(0); - +#if 0 + exp.printIndexedBuffer("Expected"); + z->printIndexedBuffer("Z"); +#endif ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); @@ -1151,7 +1157,10 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_8) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); auto z = result.at(0); - +#if 0 + exp.printIndexedBuffer("Expected"); + z->printIndexedBuffer("Z"); +#endif ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); } @@ -1204,7 +1213,10 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_10) { auto* output = results.at(0); ASSERT_EQ(Status::OK(), results.status()); - +#if 0 + expOutput.printIndexedBuffer("expOutput"); + output->printIndexedBuffer("output"); +#endif ASSERT_TRUE(expOutput.isSameShape(output)); ASSERT_TRUE(expOutput.equalsTo(output)); } diff --git 
a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp index 5f1aefe36..beccc1aae 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp @@ -244,7 +244,8 @@ TEST_F(DeclarableOpsTests19, test_threshold_encode_decode) { #ifdef _RELEASE TEST_F(DeclarableOpsTests19, test_threshold_encode_decode_2) { // [2,1,135079944,1,1,8192,1,99] - auto initial = NDArrayFactory::create('c', {1, 135079944}); + constexpr int sizeX= 10*1000*1000; + auto initial = NDArrayFactory::create('c', {1, sizeX}); initial = 1.0f; auto exp = initial.dup(); auto neg = initial.like(); @@ -254,7 +255,7 @@ TEST_F(DeclarableOpsTests19, test_threshold_encode_decode_2) { auto enc_result = enc.evaluate({&initial}, {0.5f}); auto encoded = enc_result.at(1); - ASSERT_EQ(135079944 + 4, encoded->lengthOf()); + ASSERT_EQ(sizeX + 4, encoded->lengthOf()); ASSERT_NE(exp, initial); /* for (int e = 0; e < initial.lengthOf(); e++) { @@ -419,3 +420,4 @@ TEST_F(DeclarableOpsTests19, test_squeeze_1) { auto status = op.execute({&x}, {&e}, {axis}); ASSERT_EQ(Status::OK(), status); } + diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp index 04bb54a61..c68392da1 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp @@ -1557,8 +1557,6 @@ TEST_F(DeclarableOpsTests5, trace_test1) { // exp.printIndexedBuffer("EXP TRACE"); // output->printIndexedBuffer("OUT TRACE"); ASSERT_TRUE(exp.equalsTo(output)); - - } ////////////////////////////////////////////////////////////////////// @@ -1575,8 +1573,6 @@ TEST_F(DeclarableOpsTests5, trace_test2) { ASSERT_EQ(Status::OK(), results.status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - - } ////////////////////////////////////////////////////////////////////// @@ -1593,8 +1589,6 @@ TEST_F(DeclarableOpsTests5, trace_test3) { ASSERT_EQ(Status::OK(), results.status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - - } ////////////////////////////////////////////////////////////////////// @@ -1611,8 +1605,6 @@ TEST_F(DeclarableOpsTests5, trace_test4) { ASSERT_EQ(Status::OK(), results.status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - - } ////////////////////////////////////////////////////////////////////// @@ -1629,8 +1621,6 @@ TEST_F(DeclarableOpsTests5, trace_test5) { ASSERT_EQ(Status::OK(), results.status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - - } ////////////////////////////////////////////////////////////////////// @@ -1638,22 +1628,15 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test1) { auto input = NDArrayFactory::create('c', {2, 2, 2}); input.linspace(1); + NDArray exp1 = input.dup(); + NDArray exp2('c',{2,2,2}, {5,6,7,8, 1,2,3,4}, sd::DataType::DOUBLE); sd::ops::random_shuffle op; auto results = op.evaluate({&input}); auto output = results.at(0); - bool haveZeros = false; - for(int i = 0; i < output->lengthOf(); ++i) - if(output->e(i) == (float)0.) 
- haveZeros = true; - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(input.isSameShape(output)); - ASSERT_TRUE(!input.equalsTo(output)); - ASSERT_TRUE(!haveZeros); - - + ASSERT_TRUE(output->equalsTo(exp1) || output->equalsTo(exp2)); } ////////////////////////////////////////////////////////////////////// @@ -1661,16 +1644,14 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test2) { auto input = NDArrayFactory::create('c', {1, 3, 2}); input.linspace(1); + NDArray exp1 = input.dup(); sd::ops::random_shuffle op; auto results = op.evaluate({&input}); auto output = results.at(0); ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(input.isSameShape(output)); - ASSERT_TRUE(input.equalsTo(output)); - - + ASSERT_TRUE(output->equalsTo(exp1)); } ////////////////////////////////////////////////////////////////////// @@ -1678,129 +1659,132 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test3) { auto input = NDArrayFactory::create('c', {3, 2, 1}); input.linspace(1); + NDArray exp1 = input.dup(); + NDArray exp2('c',{3,2,1}, {1,2, 5,6, 3,4}, sd::DataType::DOUBLE); + NDArray exp3('c',{3,2,1}, {3,4, 1,2, 5,6}, sd::DataType::DOUBLE); + NDArray exp4('c',{3,2,1}, {3,4, 5,6, 1,2}, sd::DataType::DOUBLE); + NDArray exp5('c',{3,2,1}, {5,6, 1,2, 3,4}, sd::DataType::DOUBLE); + NDArray exp6('c',{3,2,1}, {5,6, 3,4, 1,2}, sd::DataType::DOUBLE); sd::ops::random_shuffle op; - auto results = op.evaluate({&input}); - auto output = results.at(0); - - bool haveZeros = false; - for(int i = 0; i < output->lengthOf(); ++i) - if(output->e(i) == (float)0.) - haveZeros = true; - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(input.isSameShape(output)); - ASSERT_TRUE(!input.equalsTo(output)); - ASSERT_TRUE(!haveZeros); - - -} -////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests5, random_shuffle_test04) { - auto input = NDArrayFactory::create('c', {4}); - input.linspace(1); - - sd::ops::random_shuffle op; - //NDArray* output; auto results = op.evaluate({&input}, {}, {}, {}, {}, true); + ASSERT_EQ(Status::OK(), results.status()); - auto output = &input; //results.at(0); - bool haveZeros = false; - for(int i = 0; i < output->lengthOf(); ++i) - if(output->e(i) == (float)0.) - haveZeros = true; - - ASSERT_TRUE(input.isSameShape(output)); - //ASSERT_TRUE(!input.equalsTo(output)); - ASSERT_TRUE(!haveZeros); - - + ASSERT_TRUE(input.equalsTo(exp1) || input.equalsTo(exp2) || input.equalsTo(exp3) + || input.equalsTo(exp4) || input.equalsTo(exp5) || input.equalsTo(exp6)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, random_shuffle_test4) { - auto input = NDArrayFactory::create('c', {4}); + + auto input = NDArrayFactory::create('c', {3, 2, 1}); input.linspace(1); + NDArray exp1 = input.dup(); + NDArray exp2('c',{3,2,1}, {1,2, 5,6, 3,4}, sd::DataType::DOUBLE); + NDArray exp3('c',{3,2,1}, {3,4, 1,2, 5,6}, sd::DataType::DOUBLE); + NDArray exp4('c',{3,2,1}, {3,4, 5,6, 1,2}, sd::DataType::DOUBLE); + NDArray exp5('c',{3,2,1}, {5,6, 1,2, 3,4}, sd::DataType::DOUBLE); + NDArray exp6('c',{3,2,1}, {5,6, 3,4, 1,2}, sd::DataType::DOUBLE); sd::ops::random_shuffle op; - //NDArray* output; auto results = op.evaluate({&input}); - ASSERT_EQ(Status::OK(), results.status()); auto output = results.at(0); - bool haveZeros = false; - for(int i = 0; i < output->lengthOf(); ++i) - if(output->e(i) == (float)0.) 
- haveZeros = true; - - ASSERT_TRUE(input.isSameShape(output)); - //ASSERT_TRUE(!input.equalsTo(output)); - ASSERT_TRUE(!haveZeros); - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(output->equalsTo(exp1) || output->equalsTo(exp2) || output->equalsTo(exp3) + || output->equalsTo(exp4) || output->equalsTo(exp5) || output->equalsTo(exp6)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, random_shuffle_test5) { - - auto input = NDArrayFactory::create('c', {4,1}); + auto input = NDArrayFactory::create('c', {4}); input.linspace(1); sd::ops::random_shuffle op; - auto results = op.evaluate({&input}); + auto results = op.evaluate({&input}, {}, {}, {}, {}, false); auto output = results.at(0); - - bool haveZeros = false; - for(int i = 0; i < output->lengthOf(); ++i) - if(output->e(i) == (float)0.) - haveZeros = true; + // output->printBuffer(); ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(input.isSameShape(output)); - ASSERT_TRUE(!input.equalsTo(output)); - ASSERT_TRUE(!haveZeros); - + // ASSERT_TRUE(!output->equalsTo(input)); + bool hasDublicates = false; + for(int i = 0; i < output->lengthOf() - 1; ++i) + for(int j = i+1; j < output->lengthOf(); ++j) + if(output->t(i) == output->t(j)) { + hasDublicates = true; + i = output->lengthOf(); + break; + } + ASSERT_TRUE(!hasDublicates); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, random_shuffle_test6) { - - auto input = NDArrayFactory::create('c', {4,1,1}); + auto input = NDArrayFactory::create('c', {4,1,1}); input.linspace(1); sd::ops::random_shuffle op; - auto results = op.evaluate({&input}); + auto results = op.evaluate({&input}, {}, {}, {}, {}, false); auto output = results.at(0); - bool haveZeros = false; - for(int i = 0; i < output->lengthOf(); ++i) - if(output->e(i) == (float)0.) 
- haveZeros = true; - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(input.isSameShape(output)); - ASSERT_TRUE(!input.equalsTo(output)); - ASSERT_TRUE(!haveZeros); - + // ASSERT_TRUE(!output->equalsTo(input)); + bool hasDublicates = false; + for(int i = 0; i < output->lengthOf() - 1; ++i) + for(int j = i+1; j < output->lengthOf(); ++j) + if(output->t(i) == output->t(j)) { + hasDublicates = true; + i = output->lengthOf(); + break; + } + ASSERT_TRUE(!hasDublicates); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, random_shuffle_test7) { - - auto input = NDArrayFactory::create('c', {1,4}); + auto input = NDArrayFactory::create('c', {16010}); input.linspace(1); - auto exp = NDArrayFactory::create('c', {1,4}, {1, 2, 3, 4}); sd::ops::random_shuffle op; - auto results = op.evaluate({&input}); + auto results = op.evaluate({&input}, {}, {}, {}, {}, false); auto output = results.at(0); - + // output->printBuffer(); ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(input.isSameShape(output)); - ASSERT_TRUE(input.equalsTo(output)); + ASSERT_TRUE(!output->equalsTo(input)); + auto vec1 = input.getBufferAsVector(); + auto vec2 = output->getBufferAsVector(); + std::sort(vec2.begin(), vec2.end()); + ASSERT_TRUE(std::equal(vec1.begin(), vec1.end(), vec2.begin())); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, random_shuffle_test8) { + auto input = NDArrayFactory::create('c', {1,4,1}); + input.linspace(1); + NDArray inCopy = input.dup(); + + sd::ops::random_shuffle op; + auto results = op.evaluate({&input}, {}, {}, {}, {}, false); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(input.equalsTo(inCopy)); + +} + +TEST_F(DeclarableOpsTests5, random_shuffle_test9) { + + auto x = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + auto z = x.ulike(); + + sd::ops::random_shuffle op; + auto status = op.execute({&x}, {&z}); + ASSERT_EQ(Status::OK(), status); + + auto vec = z.getBufferAsVector(); + std::sort(vec.begin(), vec.end()); + ASSERT_EQ(std::vector({1, 2, 3, 4}), vec); } //////////////////////////////////////////////////////////////////////////////////////// diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp index 949b43d25..f2bd393e4 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp @@ -251,11 +251,10 @@ TEST_F(DeclarableOpsTests9, concat_test1) { auto result = op.evaluate({&x0, &x1, &x2}, {}, {1}); ASSERT_EQ(ND4J_STATUS_OK, result.status()); auto output = result.at(0); + // output->printCurrentBuffer(false); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - - } //////////////////////////////////////////////////////////////////////////////// diff --git a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp index a8f45cc48..e07a0496d 100644 --- a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp @@ -317,7 +317,7 @@ void fill_random(sd::NDArray& arr) { } } - + void testLegacy(bool random) { #if 0 int bases[] = { 3, 2, 4, 5, 7 }; @@ -364,7 +364,7 @@ int k = 4; #endif auto dim = NDArrayFactory::create(dimension); -#if 1 +#if 1 nd4j_printf("C(N:%d K:%d) \n", N, k); dim.printIndexedBuffer("Dimension"); for (int xind : dimension) { @@ -385,7 +385,7 @@ for (int e = 0; e < Loop; e++) { auto outerTime = 
std::chrono::duration_cast(timeEnd - timeStart).count(); values.emplace_back(outerTime); } - + std::sort(values.begin(), values.end()); nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); @@ -411,7 +411,7 @@ void testNewReduction(bool random, bool checkCorrectness = false , char order =' constexpr int N = 5; #endif - + for (int i = 0; i < N; i++) { arr_dimensions.push_back(bases[i]); } @@ -451,7 +451,7 @@ void testNewReduction(bool random, bool checkCorrectness = false , char order =' #endif auto dim = NDArrayFactory::create(dimension); -#if 1 +#if 1 nd4j_printf("C(N:%d K:%d) \n", N, k); dim.printIndexedBuffer("Dimension"); for (int xind : dimension) { @@ -477,14 +477,14 @@ void testNewReduction(bool random, bool checkCorrectness = false , char order =' //check for the correctness NDArray exp = output_bases.size() > 0 ? NDArrayFactory::create('c', output_bases) : NDArrayFactory::create(0); original_argmax(x, dimension, exp); - + #if 0// defined(DEBUG) x.printIndexedBuffer("X"); exp.printIndexedBuffer("Expected"); z->printIndexedBuffer("Z"); #endif - + ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); } @@ -505,7 +505,7 @@ TEST_F(PlaygroundTests, ArgMaxPerfLinspace) { testNewReduction(false, test_corr); } #endif - + TEST_F(PlaygroundTests, ArgMaxPerfRandom) { testNewReduction(true, test_corr); } @@ -513,7 +513,7 @@ TEST_F(PlaygroundTests, ArgMaxPerfRandom) { TEST_F(PlaygroundTests, ArgMaxPerfRandomOrderF) { testNewReduction(true, test_corr, 'f'); } - + #if !defined(DEBUG) TEST_F(PlaygroundTests, ArgMaxPerfLegacyLinspace) { testLegacy(false); @@ -1062,39 +1062,6 @@ TEST_F(PlaygroundTests, my) { delete variableSpace; } -TEST_F(PlaygroundTests, my) { - - int N = 100; - int bS=16, iH=128,iW=128, iC=32,oC=64, kH=4,kW=4, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=128,oW=128; - - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - // NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); - // NDArray output('c', {bS, oC, oH, oW}, sd::DataType::FLOAT32); - NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); - NDArray output('c', {bS, oH, oW, oC}, sd::DataType::FLOAT32); - // NDArray weights('c', {kH, kW, iC, oC}, sd::DataType::FLOAT32); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW] - NDArray weights('c', {oC, iC, kH, kW}, sd::DataType::FLOAT32); - NDArray bias('c', {oC}, sd::DataType::FLOAT32); - - input = 5.; - weights = 3.; - bias = 1.; - - sd::ops::conv2d op; - auto err = op.execute({&input, &weights, &bias}, {&output}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - - auto timeStart = std::chrono::system_clock::now(); - for (int i = 0; i < N; ++i) - err = op.execute({&input, &weights, &bias}, {&output}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto timeEnd = std::chrono::system_clock::now(); - auto time = std::chrono::duration_cast ((timeEnd - timeStart) / N).count(); - - printf("time: %i \n", time); -} - /////////////////////////////////////////////////////////////////// TEST_F(PlaygroundTests, lstmLayerCellBp_1) { @@ -1690,6 +1657,52 @@ TEST_F(DeclarableOpsTests15, gru_bp_1) { const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); } +#include +////////////////////////////////////////////////////////////////////// +TEST_F(PlaygroundTests, my) { + + const int N = 10; + + NDArray input('c', {8000000}, sd::DataType::INT32); + input.linspace(1); + NDArray output = input.dup(); + + + sd::graph::RandomGenerator rng; + + sd::ops::helpers::randomShuffle(input.getContext(), 
input, output, rng, true); + + // auto timeStart = std::chrono::system_clock::now(); + // for (int i = 0; i < N; ++i) + // sd::ops::helpers::randomShuffle(input.getContext(), input, output, rng, true); + // auto timeEnd = std::chrono::system_clock::now(); + // auto time = std::chrono::duration_cast ((timeEnd - timeStart) / N).count(); + // printf("time: %i \n", time); + + // bool hasDublicates = false; + // for(int i = 0; i < output.lengthOf() - 1; ++i) + // for(int j = i+1; j < output.lengthOf(); ++j) + // if(output.t(i) == output.t(j)) { + // hasDublicates = true; + // i = output.lengthOf(); + // break; + // } + + ASSERT_TRUE(!input.equalsTo(output)); + + bool hasDublicates = false; + for(int i = 0; i < input.lengthOf() - 1; ++i) + for(int j = i+1; j < input.lengthOf(); ++j) + if(input.t(i) == input.t(j)) { + hasDublicates = true; + i = input.lengthOf(); + break; + } + ASSERT_TRUE(!hasDublicates); +} + + +} + */ - diff --git a/libnd4j/tests_cpu/layers_tests/SessionLocalTests.cpp b/libnd4j/tests_cpu/layers_tests/SessionLocalTests.cpp deleted file mode 100644 index 8481dfde5..000000000 --- a/libnd4j/tests_cpu/layers_tests/SessionLocalTests.cpp +++ /dev/null @@ -1,93 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. 
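
All of the rewritten shuffle tests above assert the same invariant in different forms: the output must be a permutation of the input, so no element may be lost, duplicated, or zero-filled. random_shuffle_test7 states it as sort-and-compare, which is the O(n log n) equivalent of the quadratic duplicate scan used in the smaller tests. A hedged Java sketch of that invariant check:

    import java.util.Arrays;

    // Hedged sketch: verify a shuffle produced a permutation, mirroring the
    // sort-and-compare assertion in random_shuffle_test7.
    public class PermutationCheck {
        static boolean isPermutation(int[] input, int[] shuffled) {
            if (input.length != shuffled.length) return false;
            int[] a = input.clone(), b = shuffled.clone();
            Arrays.sort(a);
            Arrays.sort(b);
            return Arrays.equals(a, b); // same multiset => nothing lost or duplicated
        }

        public static void main(String[] args) {
            System.out.println(isPermutation(new int[]{1, 2, 3, 4}, new int[]{3, 1, 4, 2})); // true
            System.out.println(isPermutation(new int[]{1, 2, 3, 4}, new int[]{1, 1, 4, 2})); // false
        }
    }
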
- * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author raver119@gmail.com -// - -#ifndef LIBND4J_SESSIONLOCALTESTS_H -#define LIBND4J_SESSIONLOCALTESTS_H - -#include "testlayers.h" -#include -#include - -using namespace sd::graph; - -class SessionLocalTests : public testing::Test { -public: - -}; - -TEST_F(SessionLocalTests, BasicTests_1) { - VariableSpace variableSpace; - SessionLocalStorage storage(&variableSpace, nullptr); - - if (omp_get_max_threads() <= 1) - return; - - PRAGMA_OMP_PARALLEL_FOR_THREADS(4) - for (int e = 0; e < 4; e++) { - storage.startSession(); - } - - ASSERT_EQ(4, storage.numberOfSessions()); - - PRAGMA_OMP_PARALLEL_FOR_THREADS(4) - for (int e = 0; e < 4; e++) { - storage.endSession(); - } - - ASSERT_EQ(0, storage.numberOfSessions()); -} - - -TEST_F(SessionLocalTests, BasicTests_2) { - VariableSpace variableSpace; - SessionLocalStorage storage(&variableSpace, nullptr); - - if (omp_get_max_threads() <= 1) - return; - - auto alpha = sd::NDArrayFactory::create_('c',{5,5}); - alpha->assign(0.0); - - variableSpace.putVariable(-1, alpha); - - PRAGMA_OMP_PARALLEL_FOR_THREADS(4) - for (int e = 0; e < 4; e++) { - storage.startSession(); - - auto varSpace = storage.localVariableSpace(); - - auto arr = varSpace->getVariable(-1)->getNDArray(); - arr->applyScalar(sd::scalar::Add, (float) e+1, *arr); - } - - float lastValue = 0.0f; - for (int e = 1; e <= 4; e++) { - auto varSpace = storage.localVariableSpace((Nd4jLong) e); - - auto arr = varSpace->getVariable(-1)->getNDArray(); - - //nd4j_printf("Last value: %f; Current value: %f\n", lastValue, arr->e(0)); - - ASSERT_NE(lastValue, arr->e(0)); - lastValue = arr->e(0); - } -} - -#endif //LIBND4J_SESSIONLOCALTESTS_H diff --git a/libnd4j/tests_cpu/libnd4j_tests/CMakeLists.txt b/libnd4j/tests_cpu/libnd4j_tests/CMakeLists.txt index 7e01e2847..bbd632d27 100644 --- a/libnd4j/tests_cpu/libnd4j_tests/CMakeLists.txt +++ b/libnd4j/tests_cpu/libnd4j_tests/CMakeLists.txt @@ -45,6 +45,21 @@ if ("${BUILD_MKLDNN}") set(MKLDNN dnnl) endif() +if (${HELPERS_armcompute}) + find_package(ARMCOMPUTE REQUIRED) + + if(ARMCOMPUTE_FOUND) + message("Found ARMCOMPUTE: ${ARMCOMPUTE_LIBRARIES}") + set(HAVE_ARMCOMPUTE 1) + # Add preprocessor definition for ARM Compute NEON + add_definitions(-DARMCOMPUTENEON_ENABLED) + #build our library with neon support + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mfpu=neon") + include_directories(${ARMCOMPUTE_INCLUDE}) + endif() + +endif() + # Download and unpack flatbuffers at configure time configure_file(../../CMakeLists.txt.in flatbuffers-download/CMakeLists.txt) execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" . 
@@ -217,6 +232,10 @@ if ("${BUILD_MKLDNN}") file(GLOB_RECURSE CUSTOMOPS_PLATFORM_SOURCES false ../../include/ops/declarable/platform/mkldnn/*.cpp) endif() +if(HAVE_ARMCOMPUTE) + file(GLOB_RECURSE CUSTOMOPS_ARMCOMPUTE_SOURCES false ../include/ops/declarable/platform/armcompute/*.cpp ../include/ops/declarable/platform/armcompute/armcomputeUtils.h) +endif() + message("CPU backend") add_definitions(-D__CPUBLAS__=true) @@ -276,8 +295,9 @@ endforeach(TMP_PATH) add_executable(runtests ${LOOPS_SOURCES} ${LEGACY_SOURCES} ${EXEC_SOURCES} ${HELPERS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES} - ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_PLATFORM_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES} + ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_PLATFORM_SOURCES} + ${CUSTOMOPS_ARMCOMPUTE_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES} ${OPS_SOURCES} ${TEST_SOURCES} ${PERF_SOURCES}) -target_link_libraries(runtests gtest ${MKLDNN} gtest_main ${BLAS_LIBRARIES}) +target_link_libraries(runtests gtest ${MKLDNN} ${ARMCOMPUTE_LIBRARIES} gtest_main ${BLAS_LIBRARIES}) diff --git a/libnd4j/tests_cpu/resources/simpleif_0_alt.fb b/libnd4j/tests_cpu/resources/simpleif_0_alt.fb new file mode 100644 index 000000000..4a7e751c3 Binary files /dev/null and b/libnd4j/tests_cpu/resources/simpleif_0_alt.fb differ diff --git a/libnd4j/tests_cpu/resources/simplewhile_1.fb b/libnd4j/tests_cpu/resources/simplewhile_1.fb new file mode 100644 index 000000000..c4fa26e2a Binary files /dev/null and b/libnd4j/tests_cpu/resources/simplewhile_1.fb differ diff --git a/libnd4j/tests_cpu/resources/simplewhile_nested.fb b/libnd4j/tests_cpu/resources/simplewhile_nested.fb new file mode 100644 index 000000000..9404b98c4 Binary files /dev/null and b/libnd4j/tests_cpu/resources/simplewhile_nested.fb differ diff --git a/libnd4j/tests_cpu/resources/while_iter3.fb b/libnd4j/tests_cpu/resources/while_iter3.fb new file mode 100644 index 000000000..4b0e86979 Binary files /dev/null and b/libnd4j/tests_cpu/resources/while_iter3.fb differ diff --git a/python4j/pom.xml b/python4j/pom.xml index 57af8f1bb..3f1d026a5 100644 --- a/python4j/pom.xml +++ b/python4j/pom.xml @@ -25,7 +25,7 @@ 4.0.0 - org.eclipse + org.nd4j python4j-parent pom @@ -41,10 +41,14 @@ provided + org.slf4j + slf4j-api + 1.6.6 + ch.qos.logback logback-classic ${logback.version} - test + test junit @@ -62,5 +66,10 @@ jsr305 3.0.2 + + org.slf4j + slf4j-api + 1.6.6 + \ No newline at end of file diff --git a/python4j/python4j-core/pom.xml b/python4j/python4j-core/pom.xml index b429d8272..26e77b8d1 100644 --- a/python4j/python4j-core/pom.xml +++ b/python4j/python4j-core/pom.xml @@ -21,7 +21,7 @@ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> python4j-parent - org.eclipse + org.nd4j 1.0.0-SNAPSHOT jar @@ -39,6 +39,5 @@ cpython-platform ${cpython-platform.version} - \ No newline at end of file diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/Python.java b/python4j/python4j-core/src/main/java/org/nd4j/python4j/Python.java similarity index 99% rename from python4j/python4j-core/src/main/java/org/eclipse/python4j/Python.java rename to python4j/python4j-core/src/main/java/org/nd4j/python4j/Python.java index fd6fff112..03c2fdaab 100644 --- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/Python.java +++ b/python4j/python4j-core/src/main/java/org/nd4j/python4j/Python.java @@ -15,7 +15,7 @@ 
******/ -package org.eclipse.python4j; +package org.nd4j.python4j; import org.bytedeco.cpython.PyObject; diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonContextManager.java b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonContextManager.java similarity index 86% rename from python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonContextManager.java rename to python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonContextManager.java index a34d8a239..0090e38d4 100644 --- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonContextManager.java +++ b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonContextManager.java @@ -14,13 +14,15 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.eclipse.python4j; +package org.nd4j.python4j; import javax.lang.model.SourceVersion; +import java.io.Closeable; import java.util.HashSet; import java.util.Set; +import java.util.UUID; import java.util.concurrent.atomic.AtomicBoolean; /** @@ -46,6 +48,31 @@ public class PythonContextManager { init(); } + + public static class Context implements Closeable{ + private final String name; + private final String previous; + private final boolean temp; + public Context(){ + name = "temp_" + UUID.randomUUID().toString().replace("-", "_"); + temp = true; + previous = getCurrentContext(); + setContext(name); + } + public Context(String name){ + this.name = name; + temp = false; + previous = getCurrentContext(); + setContext(name); + } + + @Override + public void close(){ + setContext(previous); + if (temp) deleteContext(name); + } + } + private static void init() { if (init.get()) return; new PythonExecutioner(); @@ -76,7 +103,18 @@ } private static boolean validateContextName(String s) { - return SourceVersion.isIdentifier(s) && !s.startsWith(COLLAPSED_KEY); + for (int i=0; i<s.length(); i++){ + char c = s.toLowerCase().charAt(i); + if (i == 0){ + if (c >= '0' && c <= '9'){ + return false; + } + } + if (!(c=='_' || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9'))){ + return false; + } + } + return true; } private static String getContextPrefix(String contextName) { @@ -190,6 +228,7 @@ setContext(tempContext); deleteContext(currContext); setContext(currContext); + deleteContext(tempContext); } /** diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonException.java b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonException.java similarity index 98% rename from python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonException.java rename to python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonException.java index a9bbf596c..e8f64f2be 100644 --- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonException.java +++ b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonException.java @@ -14,7 +14,7 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.eclipse.python4j; +package org.nd4j.python4j; /**
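
The new PythonContextManager.Context class above is effectively a scope guard for interpreter namespaces: it remembers the current context, switches to a named or freshly generated temporary one, and on close() restores the previous context, deleting the temporary one. A hedged usage sketch, assuming the executioner is initialized and GIL acquisition is the caller's job as elsewhere in python4j:

    import org.nd4j.python4j.PythonContextManager;
    import org.nd4j.python4j.PythonExecutioner;
    import org.nd4j.python4j.PythonGIL;

    public class ContextScopes {
        public static void main(String[] args) {
            try (PythonGIL gil = PythonGIL.lock()) {
                // Named context: state persists under "job1" across uses.
                try (PythonContextManager.Context ctx = new PythonContextManager.Context("job1")) {
                    PythonExecutioner.exec("x = 42");
                } // previous context restored here

                // Unnamed context: temporary, deleted automatically on close.
                try (PythonContextManager.Context scratch = new PythonContextManager.Context()) {
                    PythonExecutioner.exec("y = 'scratch work'");
                }
            }
        }
    }

Note that the generated temp_<uuid> names must themselves pass validateContextName, which is what the hand-rolled identifier check above (first character not a digit, then only underscores and alphanumerics) guarantees.
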
diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonExecutioner.java b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonExecutioner.java similarity index 90% rename from python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonExecutioner.java rename to python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonExecutioner.java index 57e1a22ae..bc48b0e98 100644 --- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonExecutioner.java +++ b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonExecutioner.java @@ -15,7 +15,7 @@ ******************************************************************************/ -package org.eclipse.python4j; +package org.nd4j.python4j; import org.bytedeco.cpython.PyObject; @@ -42,7 +42,6 @@ public class PythonExecutioner { private final static String DEFAULT_PYTHON_PATH_PROPERTY = "org.eclipse.python4j.path"; private final static String JAVACPP_PYTHON_APPEND_TYPE = "org.eclipse.python4j.path.append"; private final static String DEFAULT_APPEND_TYPE = "before"; - static { init(); } @@ -55,6 +54,11 @@ initPythonPath(); PyEval_InitThreads(); Py_InitializeEx(0); + for (PythonType type: PythonTypes.get()){ + type.init(); + } + // Constructors of custom types may contain initialization code that should + // run on the main thread. } /** @@ -110,6 +114,8 @@ getVariables(Arrays.asList(pyVars)); } + + /** * Gets the variable with the given name from the interpreter. * @@ -205,9 +211,9 @@ * * @return */ - public static List getAllVariables() { + public static PythonVariables getAllVariables() { PythonGIL.assertThreadSafe(); - List ret = new ArrayList<>(); + PythonVariables ret = new PythonVariables(); PyObject main = PyImport_ImportModule("__main__"); PyObject globals = PyModule_GetDict(main); PyObject keys = PyDict_Keys(globals); @@ -259,7 +265,7 @@ * @param inputs * @return */ - public static List execAndReturnAllVariables(String code, List inputs) { + public static PythonVariables execAndReturnAllVariables(String code, List inputs) { setVariables(inputs); simpleExec(getWrappedCode(code)); return getAllVariables(); @@ -271,7 +277,7 @@ * @param code * @return */ - public static List execAndReturnAllVariables(String code) { + public static PythonVariables execAndReturnAllVariables(String code) { simpleExec(getWrappedCode(code)); return getAllVariables(); } @@ -279,25 +285,22 @@ private static synchronized void initPythonPath() { try { String path = System.getProperty(DEFAULT_PYTHON_PATH_PROPERTY); + + List packagesList = new ArrayList<>(); + packagesList.addAll(Arrays.asList(cachePackages())); + for (PythonType type: PythonTypes.get()){ + packagesList.addAll(Arrays.asList(type.packages())); + } + //// TODO: fix in javacpp + packagesList.add(new File(python.cachePackage(), "site-packages")); + + File[] packages = packagesList.toArray(new File[0]); + if (path == null) { - File[] packages = cachePackages(); - - //// TODO: fix in javacpp - File sitePackagesWindows = new File(python.cachePackage(), "site-packages"); - File[] packages2 = new File[packages.length + 1]; - for (int i = 0; i < packages.length; i++) { - //System.out.println(packages[i].getAbsolutePath()); - packages2[i] = packages[i]; - } - packages2[packages.length] = sitePackagesWindows; - //System.out.println(sitePackagesWindows.getAbsolutePath()); - packages = packages2; - ////////// - Py_SetPath(packages); } else { StringBuffer sb = new StringBuffer(); - File[] packages = cachePackages(); + JavaCppPathType pathAppendValue = JavaCppPathType.valueOf(System.getProperty(JAVACPP_PYTHON_APPEND_TYPE, DEFAULT_APPEND_TYPE).toUpperCase()); switch
(pathAppendValue) { case BEFORE: diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonGC.java b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonGC.java similarity index 99% rename from python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonGC.java rename to python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonGC.java index 5531b67d3..e18d2072d 100644 --- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonGC.java +++ b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonGC.java @@ -15,7 +15,7 @@ ******************************************************************************/ -package org.eclipse.python4j; +package org.nd4j.python4j; import org.bytedeco.cpython.PyObject; import org.bytedeco.javacpp.Pointer; diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonGIL.java b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonGIL.java similarity index 96% rename from python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonGIL.java rename to python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonGIL.java index 46b3db431..3a88253e0 100644 --- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonGIL.java +++ b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonGIL.java @@ -14,11 +14,10 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.eclipse.python4j; +package org.nd4j.python4j; import org.bytedeco.cpython.PyThreadState; -import org.omg.SendingContext.RunTime; import java.util.concurrent.atomic.AtomicBoolean; @@ -90,4 +89,8 @@ public class PythonGIL implements AutoCloseable { PyEval_SaveThread(); PyEval_RestoreThread(mainThreadState); } + + public static boolean locked(){ + return acquired.get(); + } } diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonJob.java b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonJob.java similarity index 93% rename from python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonJob.java rename to python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonJob.java index cdbb1b81d..f357388f7 100644 --- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonJob.java +++ b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonJob.java @@ -14,31 +14,34 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.eclipse.python4j; +package org.nd4j.python4j; import lombok.Builder; import lombok.Data; -import lombok.NoArgsConstructor; +import lombok.extern.slf4j.Slf4j; import javax.annotation.Nonnull; import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; -@Data -@NoArgsConstructor /** * PythonJob is the right abstraction for executing multiple python scripts * in a multi thread stateful environment. The setup-and-run mode allows your * "setup" code (imports, model loading etc) to be executed only once. */ +@Data +@Slf4j public class PythonJob { + private String code; private String name; private String context; - private boolean setupRunMode; + private final boolean setupRunMode; private PythonObject runF; + private final AtomicBoolean setupDone = new AtomicBoolean(false); static { new PythonExecutioner(); @@ -63,7 +66,6 @@ public class PythonJob { if (PythonContextManager.hasContext(context)) { throw new PythonException("Unable to create python job " + name + ". 
Context " + context + " already exists!"); } - if (setupRunMode) setup(); } @@ -71,17 +73,18 @@ public class PythonJob { * Clears all variables in current context and calls setup() */ public void clearState(){ - String context = this.context; - PythonContextManager.setContext("main"); - PythonContextManager.deleteContext(context); - this.context = context; + PythonContextManager.setContext(this.context); + PythonContextManager.reset(); + setupDone.set(false); setup(); } public void setup(){ + if (setupDone.get()) return; try (PythonGIL gil = PythonGIL.lock()) { PythonContextManager.setContext(context); PythonObject runF = PythonExecutioner.getVariable("run"); + if (runF == null || runF.isNone() || !Python.callable(runF)) { PythonExecutioner.exec(code); runF = PythonExecutioner.getVariable("run"); @@ -98,10 +101,12 @@ public class PythonJob { if (!setupF.isNone()) { setupF.call(); } + setupDone.set(true); } } public void exec(List inputs, List outputs) { + if (setupRunMode)setup(); try (PythonGIL gil = PythonGIL.lock()) { try (PythonGC _ = PythonGC.watch()) { PythonContextManager.setContext(context); @@ -139,6 +144,7 @@ public class PythonJob { } public List execAndReturnAllVariables(List inputs){ + if (setupRunMode)setup(); try (PythonGIL gil = PythonGIL.lock()) { try (PythonGC _ = PythonGC.watch()) { PythonContextManager.setContext(context); diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonObject.java b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonObject.java similarity index 77% rename from python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonObject.java rename to python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonObject.java index f8ec17ed9..94b60d320 100644 --- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonObject.java +++ b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonObject.java @@ -14,7 +14,7 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.eclipse.python4j; +package org.nd4j.python4j; import org.bytedeco.cpython.PyObject; @@ -147,7 +147,8 @@ public class PythonObject { } PythonObject pyArgs; PythonObject pyKwargs; - if (args == null) { + + if (args == null || args.isEmpty()) { pyArgs = new PythonObject(PyTuple_New(0)); } else { PythonObject argsList = PythonTypes.convert(args); @@ -158,6 +159,7 @@ public class PythonObject { } else { pyKwargs = PythonTypes.convert(kwargs); } + PythonObject ret = new PythonObject( PyObject_Call( nativePythonObject, @@ -165,7 +167,9 @@ public class PythonObject { pyKwargs == null ? 
null : pyKwargs.nativePythonObject ) ); + PythonGC.keep(ret); + return ret; } @@ -241,4 +245,48 @@ public class PythonObject { PyObject_SetItem(nativePythonObject, key.nativePythonObject, value.nativePythonObject); } + + public PythonObject abs(){ + return new PythonObject(PyNumber_Absolute(nativePythonObject)); + } + public PythonObject add(PythonObject pythonObject){ + return new PythonObject(PyNumber_Add(nativePythonObject, pythonObject.nativePythonObject)); + } + public PythonObject sub(PythonObject pythonObject){ + return new PythonObject(PyNumber_Subtract(nativePythonObject, pythonObject.nativePythonObject)); + } + public PythonObject mod(PythonObject pythonObject){ + return new PythonObject(PyNumber_Divmod(nativePythonObject, pythonObject.nativePythonObject)); + } + public PythonObject mul(PythonObject pythonObject){ + return new PythonObject(PyNumber_Multiply(nativePythonObject, pythonObject.nativePythonObject)); + } + public PythonObject trueDiv(PythonObject pythonObject){ + return new PythonObject(PyNumber_TrueDivide(nativePythonObject, pythonObject.nativePythonObject)); + } + public PythonObject floorDiv(PythonObject pythonObject){ + return new PythonObject(PyNumber_FloorDivide(nativePythonObject, pythonObject.nativePythonObject)); + } + public PythonObject matMul(PythonObject pythonObject){ + return new PythonObject(PyNumber_MatrixMultiply(nativePythonObject, pythonObject.nativePythonObject)); + } + + public void addi(PythonObject pythonObject){ + PyNumber_InPlaceAdd(nativePythonObject, pythonObject.nativePythonObject); + } + public void subi(PythonObject pythonObject){ + PyNumber_InPlaceSubtract(nativePythonObject, pythonObject.nativePythonObject); + } + public void muli(PythonObject pythonObject){ + PyNumber_InPlaceMultiply(nativePythonObject, pythonObject.nativePythonObject); + } + public void trueDivi(PythonObject pythonObject){ + PyNumber_InPlaceTrueDivide(nativePythonObject, pythonObject.nativePythonObject); + } + public void floorDivi(PythonObject pythonObject){ + PyNumber_InPlaceFloorDivide(nativePythonObject, pythonObject.nativePythonObject); + } + public void matMuli(PythonObject pythonObject){ + PyNumber_InPlaceMatrixMultiply(nativePythonObject, pythonObject.nativePythonObject); + } } diff --git a/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonProcess.java b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonProcess.java new file mode 100644 index 000000000..bce8809f5 --- /dev/null +++ b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonProcess.java @@ -0,0 +1,127 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
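
The operator wrappers added to PythonObject above are thin passthroughs to CPython's number protocol (PyNumber_Add, PyNumber_Multiply, and so on), with the in-place variants mutating the receiver. A hedged sketch of how they compose, assuming numeric PythonObjects are obtained via PythonTypes.INT and that calls happen under the GIL:

    import org.nd4j.python4j.PythonGC;
    import org.nd4j.python4j.PythonGIL;
    import org.nd4j.python4j.PythonObject;
    import org.nd4j.python4j.PythonTypes;

    public class NumberProtocol {
        public static void main(String[] args) {
            try (PythonGIL gil = PythonGIL.lock()) {
                try (PythonGC gc = PythonGC.watch()) {
                    PythonObject a = PythonTypes.INT.toPython(6L);
                    PythonObject b = PythonTypes.INT.toPython(7L);
                    PythonObject sum = a.add(b);   // PyNumber_Add
                    PythonObject prod = a.mul(b);  // PyNumber_Multiply
                    System.out.println(sum.toInt() + " " + prod.toInt()); // 13 42
                }
            }
        }
    }

One caveat worth flagging: mod() as written calls PyNumber_Divmod, which returns a (quotient, remainder) tuple rather than the remainder alone; PyNumber_Remainder is the call that matches the method name.
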
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +package org.nd4j.python4j; + +import org.apache.commons.io.IOUtils; +import org.bytedeco.javacpp.Loader; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; + +public class PythonProcess { + private static String pythonExecutable = Loader.load(org.bytedeco.cpython.python.class); + public static String runAndReturn(String... arguments)throws IOException, InterruptedException{ + String[] allArgs = new String[arguments.length + 1]; + for (int i = 0; i < arguments.length; i++){ + allArgs[i + 1] = arguments[i]; + } + allArgs[0] = pythonExecutable; + ProcessBuilder pb = new ProcessBuilder(allArgs); + Process process = pb.start(); + String out = IOUtils.toString(process.getInputStream(), StandardCharsets.UTF_8); + process.waitFor(); + return out; + + } + + public static void run(String... arguments)throws IOException, InterruptedException{ + String[] allArgs = new String[arguments.length + 1]; + for (int i = 0; i < arguments.length; i++){ + allArgs[i + 1] = arguments[i]; + } + allArgs[0] = pythonExecutable; + ProcessBuilder pb = new ProcessBuilder(allArgs); + pb.inheritIO().start().waitFor(); + } + public static void pipInstall(String packageName) throws PythonException{ + try{ + run("-m", "pip", "install", packageName); + }catch(Exception e){ + throw new PythonException("Error installing package " + packageName, e); + } + + } + + public static void pipInstall(String packageName, String version){ + pipInstall(packageName + "==" + version); + } + + public static void pipUninstall(String packageName) throws PythonException{ + try{ + run("-m", "pip", "uninstall", packageName); + }catch(Exception e){ + throw new PythonException("Error uninstalling package " + packageName, e); + } + + } + public static void pipInstallFromGit(String gitRepoUrl){ + if (!gitRepoUrl.contains("://")){ + gitRepoUrl = "git://" + gitRepoUrl; + } + try{ + run("-m", "pip", "install", "git+" + gitRepoUrl); + }catch(Exception e){ + throw new PythonException("Error installing package from " + gitRepoUrl, e); + } + + } + + public static String getPackageVersion(String packageName){ + String out; + try{ + out = runAndReturn("-m", "pip", "show", packageName); + } catch (Exception e){ + throw new PythonException("Error finding version for package " + packageName, e); + } + + if (!out.contains("Version: ")){ + throw new PythonException("Can't find package " + packageName); + } + String pkgVersion = out.split("Version: ")[1].split(System.lineSeparator())[0]; + return pkgVersion; + } + + public static boolean isPackageInstalled(String packageName){ + try{ + String out = runAndReturn("-m", "pip", "show", packageName); + return !out.isEmpty(); + }catch (Exception e){ + throw new PythonException("Error checking if package is installed: " +packageName, e); + } + + } + + public static void pipInstallFromRequirementsTxt(String path){ + try{ + run("-m", "pip", "install","-r", path); + }catch (Exception e){ + throw new PythonException("Error installing packages from " + path, e); + } + } + + public static void pipInstallFromSetupScript(String path, boolean inplace){ + + try{ + run(path, inplace?"develop":"install"); + }catch (Exception e){ + throw new PythonException("Error installing package from " + path, e); + } + + } + +} \ No newline at end of file
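
Because PythonProcess resolves the interpreter through javacpp's Loader, pip management becomes scriptable from Java against the bundled CPython rather than whatever happens to be on the system path. A hedged usage sketch; the package name and version pin are arbitrary examples:

    import org.nd4j.python4j.PythonProcess;

    public class PipExample {
        public static void main(String[] args) {
            // Runs: python -m pip install requests==2.23.0 against the bundled interpreter.
            PythonProcess.pipInstall("requests", "2.23.0");

            if (PythonProcess.isPackageInstalled("requests")) {
                System.out.println("requests " + PythonProcess.getPackageVersion("requests"));
            }
        }
    }

diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonType.java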
b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonType.java similarity index 72% rename from python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonType.java rename to python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonType.java index b4806aa37..79b0ccaab 100644 --- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonType.java +++ b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonType.java @@ -14,9 +14,11 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.eclipse.python4j; +package org.nd4j.python4j; +import java.io.File; + public abstract class PythonType { private final String name; @@ -43,5 +45,25 @@ public abstract class PythonType { return name; } + @Override + public boolean equals(Object obj){ + if (!(obj instanceof PythonType)){ + return false; + } + PythonType other = (PythonType)obj; + return this.getClass().equals(other.getClass()) && this.name.equals(other.name); + } + + public PythonObject pythonType(){ + return null; + } + + public File[] packages(){ + return new File[0]; + } + + public void init(){ //not to be called from constructor + + } } diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonTypes.java b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonTypes.java similarity index 58% rename from python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonTypes.java rename to python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonTypes.java index 0dc20f712..089c8aefe 100644 --- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonTypes.java +++ b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonTypes.java @@ -14,11 +14,18 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.eclipse.python4j; +package org.nd4j.python4j; import org.bytedeco.cpython.PyObject; +import org.bytedeco.javacpp.BytePointer; +import org.bytedeco.javacpp.Pointer; +import sun.nio.ch.DirectBuffer; +import java.lang.reflect.Field; +import java.nio.Buffer; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.util.*; import static org.bytedeco.cpython.global.python.*; @@ -28,7 +35,7 @@ public class PythonTypes { private static List getPrimitiveTypes() { - return Arrays.asList(STR, INT, FLOAT, BOOL); + return Arrays.asList(STR, INT, FLOAT, BOOL, BYTES); } private static List getCollectionTypes() { @@ -36,8 +43,13 @@ public class PythonTypes { } private static List getExternalTypes() { - //TODO service loader - return new ArrayList<>(); + List ret = new ArrayList<>(); + ServiceLoader sl = ServiceLoader.load(PythonType.class); + Iterator iter = sl.iterator(); + while (iter.hasNext()) { + ret.add(iter.next()); + } + return ret; } public static List get() { @@ -48,15 +60,17 @@ public class PythonTypes { return ret; } - public static PythonType get(String name) { + public static PythonType get(String name) { for (PythonType pt : get()) { if (pt.getName().equals(name)) { // TODO use map instead? 
                return pt;
            }
+        }
        throw new PythonException("Unknown python type: " + name);
    }
+
    public static PythonType getPythonTypeForJavaObject(Object javaObject) {
        for (PythonType pt : get()) {
            if (pt.accepts(javaObject)) {
@@ -66,7 +80,7 @@
        throw new PythonException("Unable to find python type for java type: " + javaObject.getClass());
    }
-    public static PythonType getPythonTypeForPythonObject(PythonObject pythonObject) {
+    public static PythonType getPythonTypeForPythonObject(PythonObject pythonObject) {
        PyObject pyType = PyObject_Type(pythonObject.getNativePythonObject());
        try {
            String pyTypeStr = PythonTypes.STR.toJava(new PythonObject(pyType, false));
@@ -75,6 +89,14 @@
                String pyTypeStr2 = "<class '" + pt.getName() + "'>";
                if (pyTypeStr.equals(pyTypeStr2)) {
                    return pt;
+                } else {
+                    try (PythonGC gc = PythonGC.watch()) {
+                        PythonObject pyType2 = pt.pythonType();
+                        if (pyType2 != null && Python.isinstance(pythonObject, pyType2)) {
+                            return pt;
+                        }
+                    }
+                }
            }
            throw new PythonException("Unable to find converter for python object of type " + pyTypeStr);
@@ -212,12 +234,53 @@
    public static final PythonType LIST = new PythonType("list", List.class) {
+        @Override
+        public boolean accepts(Object javaObject) {
+            return (javaObject instanceof List || javaObject.getClass().isArray());
+        }
+
        @Override
        public List adapt(Object javaObject) {
            if (javaObject instanceof List) {
                return (List) javaObject;
-            } else if (javaObject instanceof Object[]) {
-                return Arrays.asList((Object[]) javaObject);
+            } else if (javaObject.getClass().isArray()) {
+                List ret = new ArrayList<>();
+                if (javaObject instanceof Object[]) {
+                    Object[] arr = (Object[]) javaObject;
+                    return new ArrayList<>(Arrays.asList(arr));
+                } else if (javaObject instanceof short[]) {
+                    short[] arr = (short[]) javaObject;
+                    for (short x : arr) ret.add(x);
+                    return ret;
+                } else if (javaObject instanceof int[]) {
+                    int[] arr = (int[]) javaObject;
+                    for (int x : arr) ret.add(x);
+                    return ret;
+                } else if (javaObject instanceof byte[]) {
+                    byte[] arr = (byte[]) javaObject;
+                    for (int x : arr) ret.add(x); // bytes are widened to ints
+                    return ret;
+                } else if (javaObject instanceof long[]) {
+                    long[] arr = (long[]) javaObject;
+                    for (long x : arr) ret.add(x);
+                    return ret;
+                } else if (javaObject instanceof float[]) {
+                    float[] arr = (float[]) javaObject;
+                    for (float x : arr) ret.add(x);
+                    return ret;
+                } else if (javaObject instanceof double[]) {
+                    double[] arr = (double[]) javaObject;
+                    for (double x : arr) ret.add(x);
+                    return ret;
+                } else if (javaObject instanceof boolean[]) {
+                    boolean[] arr = (boolean[]) javaObject;
+                    for (boolean x : arr) ret.add(x);
+                    return ret;
+                } else {
+                    throw new PythonException("Unsupported array type: " + javaObject.getClass().toString());
+                }
+            } else {
                throw new PythonException("Cannot cast object of type " + javaObject.getClass().getName() + " to List");
            }
@@ -327,7 +390,13 @@
            }
            Object v = javaObject.get(k);
            PythonObject pyVal;
-            pyVal = PythonTypes.convert(v);
+            if (v instanceof PythonObject) {
+                pyVal = (PythonObject) v;
+            } else if (v instanceof PyObject) {
+                pyVal = new PythonObject((PyObject) v);
+            } else {
+                pyVal = PythonTypes.convert(v);
+            }
            int errCode = PyDict_SetItem(pyDict, pyKey.getNativePythonObject(), pyVal.getNativePythonObject());
            if (errCode != 0) {
                String keyStr = pyKey.toString();
@@ -341,4 +410,127 @@
        return new PythonObject(pyDict);
    }
    };
+
+
+    public static final PythonType BYTES = new PythonType("bytes", byte[].class) {
+
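+        // bytes <-> byte[] conversion below copies element by element (via each byte's
+        // int value); simple, but not zero-copy. For zero-copy access to large buffers,
+        // see the (disabled) memoryview type below and python4j-numpy.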
@Override + public byte[] toJava(PythonObject pythonObject) { + try (PythonGC gc = PythonGC.watch()) { + if (!(Python.isinstance(pythonObject, Python.bytesType()))) { + throw new PythonException("Expected bytes. Received: " + pythonObject); + } + PythonObject pySize = Python.len(pythonObject); + byte[] ret = new byte[pySize.toInt()]; + for (int i = 0; i < ret.length; i++) { + ret[i] = (byte)pythonObject.get(i).toInt(); + } + return ret; + } + } + + @Override + public PythonObject toPython(byte[] javaObject) { + try(PythonGC gc = PythonGC.watch()){ + PythonObject ret = Python.bytes(LIST.toPython(LIST.adapt(javaObject))); + PythonGC.keep(ret); + return ret; + } + } + @Override + public boolean accepts(Object javaObject) { + return javaObject instanceof byte[]; + } + @Override + public byte[] adapt(Object javaObject) { + if (javaObject instanceof byte[]){ + return (byte[])javaObject; + } + throw new PythonException("Cannot cast object of type " + javaObject.getClass().getName() + " to byte[]"); + } + + }; + + /** + * Crashes on Adopt OpenJDK + * Use implementation in python4j-numpy instead for zero-copy byte buffers. + */ +// public static final PythonType MEMORYVIEW = new PythonType("memoryview", BytePointer.class) { +// @Override +// public BytePointer toJava(PythonObject pythonObject) { +// try (PythonGC gc = PythonGC.watch()) { +// if (!(Python.isinstance(pythonObject, Python.memoryviewType()))) { +// throw new PythonException("Expected memoryview. Received: " + pythonObject); +// } +// PythonObject pySize = Python.len(pythonObject); +// PythonObject ctypes = Python.importModule("ctypes"); +// PythonObject charType = ctypes.attr("c_char"); +// PythonObject charArrayType = new PythonObject(PyNumber_Multiply(charType.getNativePythonObject(), +// pySize.getNativePythonObject())); +// PythonObject fromBuffer = charArrayType.attr("from_buffer"); +// if (pythonObject.attr("readonly").toBoolean()) { +// pythonObject = Python.bytearray(pythonObject); +// } +// PythonObject arr = fromBuffer.call(pythonObject); +// PythonObject cast = ctypes.attr("cast"); +// PythonObject voidPtrType = ctypes.attr("c_void_p"); +// PythonObject voidPtr = cast.call(arr, voidPtrType); +// long address = voidPtr.attr("value").toLong(); +// long size = pySize.toLong(); +// try { +// Field addressField = Buffer.class.getDeclaredField("address"); +// addressField.setAccessible(true); +// Field capacityField = Buffer.class.getDeclaredField("capacity"); +// capacityField.setAccessible(true); +// ByteBuffer buff = ByteBuffer.allocateDirect(0).order(ByteOrder.nativeOrder()); +// addressField.setLong(buff, address); +// capacityField.setInt(buff, (int) size); +// BytePointer ret = new BytePointer(buff); +// ret.limit(size); +// return ret; +// +// } catch (Exception e) { +// throw new RuntimeException(e); +// } +// +// } +// } +// +// @Override +// public PythonObject toPython(BytePointer javaObject) { +// long address = javaObject.address(); +// long size = javaObject.limit(); +// try (PythonGC gc = PythonGC.watch()) { +// PythonObject ctypes = Python.importModule("ctypes"); +// PythonObject charType = ctypes.attr("c_char"); +// PythonObject pySize = new PythonObject(size); +// PythonObject charArrayType = new PythonObject(PyNumber_Multiply(charType.getNativePythonObject(), +// pySize.getNativePythonObject())); +// PythonObject fromAddress = charArrayType.attr("from_address"); +// PythonObject arr = fromAddress.call(new PythonObject(address)); +// PythonObject memoryView = Python.memoryview(arr).attr("cast").call("b"); +// 
PythonGC.keep(memoryView); +// return memoryView; +// } +// +// } +// +// @Override +// public boolean accepts(Object javaObject) { +// return javaObject instanceof Pointer || javaObject instanceof DirectBuffer; +// } +// +// @Override +// public BytePointer adapt(Object javaObject) { +// if (javaObject instanceof BytePointer) { +// return (BytePointer) javaObject; +// } else if (javaObject instanceof Pointer) { +// return new BytePointer((Pointer) javaObject); +// } else if (javaObject instanceof DirectBuffer) { +// return new BytePointer((ByteBuffer) javaObject); +// } else { +// throw new PythonException("Cannot cast object of type " + javaObject.getClass().getName() + " to BytePointer"); +// } +// } +// }; + } diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonVariable.java b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonVariable.java similarity index 98% rename from python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonVariable.java rename to python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonVariable.java index 3deb4d2e7..038904ec9 100644 --- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonVariable.java +++ b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonVariable.java @@ -14,7 +14,7 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.eclipse.python4j; +package org.nd4j.python4j; @lombok.Data public class PythonVariable { diff --git a/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonVariables.java b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonVariables.java new file mode 100644 index 000000000..ed9ccff5d --- /dev/null +++ b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonVariables.java @@ -0,0 +1,47 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.python4j; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * Some syntax sugar for lookup by name + */ +public class PythonVariables extends ArrayList { + public PythonVariable get(String variableName) { + for (PythonVariable pyVar: this){ + if (pyVar.getName().equals(variableName)){ + return pyVar; + } + } + return null; + } + + public boolean add(String variableName, PythonType variableType, Object value){ + return this.add(new PythonVariable<>(variableName, variableType, value)); + } + + public PythonVariables(PythonVariable... 
variables){ + this(Arrays.asList(variables)); + } + public PythonVariables(List list){ + super(); + addAll(list); + } +} diff --git a/python4j/python4j-core/src/main/resources/org/nd4j/python4j/pythonexec/__init__.py b/python4j/python4j-core/src/main/resources/org/nd4j/python4j/pythonexec/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python4j/python4j-core/src/main/resources/org/eclipse/python4j/pythonexec/pythonexec.py b/python4j/python4j-core/src/main/resources/org/nd4j/python4j/pythonexec/pythonexec.py similarity index 100% rename from python4j/python4j-core/src/main/resources/org/eclipse/python4j/pythonexec/pythonexec.py rename to python4j/python4j-core/src/main/resources/org/nd4j/python4j/pythonexec/pythonexec.py diff --git a/python4j/python4j-core/src/test/java/PythonBasicExecutionTest.java b/python4j/python4j-core/src/test/java/PythonBasicExecutionTest.java index 9f5b43dba..c26b5c874 100644 --- a/python4j/python4j-core/src/test/java/PythonBasicExecutionTest.java +++ b/python4j/python4j-core/src/test/java/PythonBasicExecutionTest.java @@ -15,9 +15,12 @@ ******************************************************************************/ -import org.eclipse.python4j.*; import org.junit.Assert; import org.junit.Test; +import org.nd4j.python4j.PythonContextManager; +import org.nd4j.python4j.PythonExecutioner; +import org.nd4j.python4j.PythonTypes; +import org.nd4j.python4j.PythonVariable; import javax.annotation.concurrent.NotThreadSafe; import java.util.*; diff --git a/python4j/python4j-core/src/test/java/PythonCollectionsTest.java b/python4j/python4j-core/src/test/java/PythonCollectionsTest.java index 7e63d9d28..ba4d8e14a 100644 --- a/python4j/python4j-core/src/test/java/PythonCollectionsTest.java +++ b/python4j/python4j-core/src/test/java/PythonCollectionsTest.java @@ -15,9 +15,9 @@ ******************************************************************************/ -import org.eclipse.python4j.PythonException; -import org.eclipse.python4j.PythonObject; -import org.eclipse.python4j.PythonTypes; +import org.nd4j.python4j.PythonException; +import org.nd4j.python4j.PythonObject; +import org.nd4j.python4j.PythonTypes; import org.junit.Assert; import org.junit.Test; diff --git a/python4j/python4j-core/src/test/java/PythonContextManagerTest.java b/python4j/python4j-core/src/test/java/PythonContextManagerTest.java index a4451764c..4961f94d8 100644 --- a/python4j/python4j-core/src/test/java/PythonContextManagerTest.java +++ b/python4j/python4j-core/src/test/java/PythonContextManagerTest.java @@ -16,9 +16,9 @@ ******************************************************************************/ -import org.eclipse.python4j.Python; -import org.eclipse.python4j.PythonContextManager; -import org.eclipse.python4j.PythonExecutioner; +import org.nd4j.python4j.Python; +import org.nd4j.python4j.PythonContextManager; +import org.nd4j.python4j.PythonExecutioner; import org.junit.Assert; import org.junit.Test; import javax.annotation.concurrent.NotThreadSafe; diff --git a/python4j/python4j-core/src/test/java/PythonGCTest.java b/python4j/python4j-core/src/test/java/PythonGCTest.java index f8c6ecba5..11dd8e93a 100644 --- a/python4j/python4j-core/src/test/java/PythonGCTest.java +++ b/python4j/python4j-core/src/test/java/PythonGCTest.java @@ -14,9 +14,9 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -import org.eclipse.python4j.Python; -import org.eclipse.python4j.PythonGC; -import org.eclipse.python4j.PythonObject; +import 
org.nd4j.python4j.Python;
+import org.nd4j.python4j.PythonGC;
+import org.nd4j.python4j.PythonObject;
 import org.junit.Assert;
 import org.junit.Test;
@@ -49,6 +49,6 @@ public class PythonGCTest {
         PythonObject pyObjCount3 = Python.len(getObjects.call());
         long objCount3 = pyObjCount3.toLong();
         diff = objCount3 - objCount2;
-        Assert.assertEquals(2, diff);// 2 objects created during function call
+        Assert.assertTrue(diff <= 2); // at most 2 objects are created during the call
     }
 }
diff --git a/python4j/python4j-core/src/test/java/PythonJobTest.java b/python4j/python4j-core/src/test/java/PythonJobTest.java
index 016045a25..4dad7f24f 100644
--- a/python4j/python4j-core/src/test/java/PythonJobTest.java
+++ b/python4j/python4j-core/src/test/java/PythonJobTest.java
@@ -14,10 +14,10 @@
  * SPDX-License-Identifier: Apache-2.0
  ******************************************************************************/
-import org.eclipse.python4j.PythonContextManager;
-import org.eclipse.python4j.PythonJob;
-import org.eclipse.python4j.PythonTypes;
-import org.eclipse.python4j.PythonVariable;
+import org.nd4j.python4j.PythonContextManager;
+import org.nd4j.python4j.PythonJob;
+import org.nd4j.python4j.PythonTypes;
+import org.nd4j.python4j.PythonVariable;
 import org.junit.Test;
 import java.util.ArrayList;
@@ -30,7 +30,7 @@ import static org.junit.Assert.assertEquals;
 public class PythonJobTest {
     @Test
-    public void testPythonJobBasic() throws Exception{
+    public void testPythonJobBasic(){
         PythonContextManager.deleteNonMainContexts();
         String code = "c = a + b";
@@ -65,7 +65,7 @@
     }
     @Test
-    public void testPythonJobReturnAllVariables()throws Exception{
+    public void testPythonJobReturnAllVariables(){
         PythonContextManager.deleteNonMainContexts();
         String code = "c = a + b";
@@ -101,7 +101,7 @@
     @Test
-    public void testMultiplePythonJobsParallel()throws Exception{
+    public void testMultiplePythonJobsParallel(){
         PythonContextManager.deleteNonMainContexts();
         String code1 = "c = a + b";
         PythonJob job1 = new PythonJob("job1", code1, false);
@@ -150,7 +150,7 @@
     @Test
-    public void testPythonJobSetupRun()throws Exception{
+    public void testPythonJobSetupRun(){
         PythonContextManager.deleteNonMainContexts();
         String code = "five=None\n" +
@@ -189,7 +189,7 @@
     }
     @Test
-    public void testPythonJobSetupRunAndReturnAllVariables()throws Exception{
+    public void testPythonJobSetupRunAndReturnAllVariables(){
         PythonContextManager.deleteNonMainContexts();
         String code = "five=None\n" +
                 "c=None\n"+
@@ -225,7 +225,7 @@
     }
     @Test
-    public void testMultiplePythonJobsSetupRunParallel()throws Exception{
+    public void testMultiplePythonJobsSetupRunParallel(){
         PythonContextManager.deleteNonMainContexts();
         String code1 = "five=None\n" +
diff --git a/python4j/python4j-core/src/test/java/PythonMultiThreadTest.java b/python4j/python4j-core/src/test/java/PythonMultiThreadTest.java
index ec544b65f..b2f9089fa 100644
--- a/python4j/python4j-core/src/test/java/PythonMultiThreadTest.java
+++ b/python4j/python4j-core/src/test/java/PythonMultiThreadTest.java
@@ -14,10 +14,9 @@
  * SPDX-License-Identifier: Apache-2.0
  ******************************************************************************/
-import org.eclipse.python4j.*;
+import org.nd4j.python4j.*;
 import org.junit.Assert;
 import org.junit.Test;
-
 import javax.annotation.concurrent.NotThreadSafe;
 import java.util.ArrayList;
 import java.util.Arrays;
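For context, a minimal sketch of the setup/run contract these tests exercise; the job name, variables and values are illustrative only, and the imports are the same ones PythonJobTest uses:

    // setup() runs once per context when the third constructor argument is true;
    // run(...) is invoked on every exec with the declared inputs
    String code =
            "five = None\n" +
            "def setup():\n" +
            "    global five\n" +
            "    five = 5\n" +
            "def run(a, b):\n" +
            "    return {'c': a + b + five}\n";
    PythonJob job = new PythonJob("example", code, true);
    List<PythonVariable> inputs = new ArrayList<>();
    inputs.add(new PythonVariable<>("a", PythonTypes.INT, 2));
    inputs.add(new PythonVariable<>("b", PythonTypes.INT, 3));
    List<PythonVariable> outputs = new ArrayList<>();
    outputs.add(new PythonVariable<>("c", PythonTypes.INT));
    job.exec(inputs, outputs);   // run(2, 3) -> c == 10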
diff --git a/python4j/python4j-core/src/test/java/PythonPrimitiveTypesTest.java b/python4j/python4j-core/src/test/java/PythonPrimitiveTypesTest.java
index ae10ed8dc..94423f7de 100644
--- a/python4j/python4j-core/src/test/java/PythonPrimitiveTypesTest.java
+++ b/python4j/python4j-core/src/test/java/PythonPrimitiveTypesTest.java
@@ -15,12 +15,13 @@
  ******************************************************************************/
-import org.eclipse.python4j.PythonException;
-import org.eclipse.python4j.PythonObject;
-import org.eclipse.python4j.PythonTypes;
+import org.nd4j.python4j.*;
 import org.junit.Assert;
 import org.junit.Test;
+import java.util.ArrayList;
+import java.util.List;
+
 public class PythonPrimitiveTypesTest {
     @Test
@@ -78,5 +79,18 @@
         Assert.assertEquals(b, b3);
     }
+    @Test
+    public void testBytes() {
+        byte[] bytes = new byte[]{97, 98, 99};
+        List inputs = new ArrayList<>();
+        inputs.add(new PythonVariable<>("buff", PythonTypes.BYTES, bytes));
+        List outputs = new ArrayList<>();
+        outputs.add(new PythonVariable<>("s1", PythonTypes.STR));
+        outputs.add(new PythonVariable<>("buff2", PythonTypes.BYTES));
+        String code = "s1 = ''.join(chr(c) for c in buff)\nbuff2=b'def'";
+        PythonExecutioner.exec(code, inputs, outputs);
+        Assert.assertEquals("abc", outputs.get(0).getValue());
+        Assert.assertArrayEquals(new byte[]{100, 101, 102}, (byte[]) outputs.get(1).getValue());
+    }
 }
diff --git a/python4j/python4j-numpy/pom.xml b/python4j/python4j-numpy/pom.xml
index 527a9343f..c631f67e3 100644
--- a/python4j/python4j-numpy/pom.xml
+++ b/python4j/python4j-numpy/pom.xml
@@ -4,7 +4,7 @@
          xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
     <parent>
         <artifactId>python4j-parent</artifactId>
-        <groupId>org.eclipse</groupId>
+        <groupId>org.nd4j</groupId>
         <version>1.0.0-SNAPSHOT</version>
     </parent>
     <modelVersion>4.0.0</modelVersion>
@@ -28,15 +28,50 @@
             <version>${nd4j.version}</version>
             <scope>test</scope>
         </dependency>
+        <dependency>
+            <groupId>org.nd4j</groupId>
+            <artifactId>python4j-core</artifactId>
+            <version>1.0.0-SNAPSHOT</version>
+        </dependency>
+    </dependencies>
+    <profiles>
+        <profile>
+            <id>test-nd4j-native</id>
+            <dependencies>
+                <dependency>
+                    <groupId>org.nd4j</groupId>
+                    <artifactId>nd4j-native</artifactId>
+                    <version>${nd4j.version}</version>
+                    <scope>test</scope>
+                </dependency>
+                <dependency>
+                    <groupId>org.deeplearning4j</groupId>
+                    <artifactId>dl4j-test-resources</artifactId>
+                    <version>${nd4j.version}</version>
+                    <scope>test</scope>
+                </dependency>
+            </dependencies>
+        </profile>
+        <profile>
+            <id>test-nd4j-cuda-10.2</id>
+            <dependencies>
+                <dependency>
+                    <groupId>org.nd4j</groupId>
+                    <artifactId>nd4j-cuda-10.1</artifactId>
+                    <version>${nd4j.version}</version>
+                    <scope>test</scope>
+                </dependency>
+                <dependency>
+                    <groupId>org.deeplearning4j</groupId>
+                    <artifactId>dl4j-test-resources</artifactId>
+                    <version>${nd4j.version}</version>
+                    <scope>test</scope>
+                </dependency>
+            </dependencies>
+        </profile>
+    </profiles>
+</project>
\ No newline at end of file
diff --git a/python4j/python4j-numpy/src/main/java/org/nd4j/python4j/NumpyArray.java b/python4j/python4j-numpy/src/main/java/org/nd4j/python4j/NumpyArray.java
new file mode 100644
index 000000000..b21dabd7c
--- /dev/null
+++ b/python4j/python4j-numpy/src/main/java/org/nd4j/python4j/NumpyArray.java
@@ -0,0 +1,299 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +package org.nd4j.python4j; + +import lombok.extern.slf4j.Slf4j; +import org.bytedeco.cpython.PyObject; +import org.bytedeco.cpython.PyTypeObject; +import org.bytedeco.javacpp.Pointer; +import org.bytedeco.javacpp.SizeTPointer; +import org.bytedeco.numpy.PyArrayObject; +import org.bytedeco.numpy.global.numpy; +import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.concurrency.AffinityManager; +import org.nd4j.linalg.api.memory.MemoryWorkspace; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.shape.Shape; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.nativeblas.NativeOpsHolder; + +import java.io.File; +import java.util.*; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.bytedeco.cpython.global.python.*; +import static org.bytedeco.cpython.global.python.Py_DecRef; +import static org.bytedeco.numpy.global.numpy.*; +import static org.bytedeco.numpy.global.numpy.NPY_ARRAY_CARRAY; +import static org.bytedeco.numpy.global.numpy.PyArray_Type; + +@Slf4j +public class NumpyArray extends PythonType { + + public static final NumpyArray INSTANCE; + private static final AtomicBoolean init = new AtomicBoolean(false); + private static final Map cache = new HashMap<>(); + + static { + new PythonExecutioner(); + INSTANCE = new NumpyArray(); + } + + @Override + public File[] packages(){ + try{ + return new File[]{numpy.cachePackage()}; + }catch(Exception e){ + throw new PythonException(e); + } + + } + + public synchronized void init() { + if (init.get()) return; + init.set(true); + if (PythonGIL.locked()) { + throw new PythonException("Can not initialize numpy - GIL already acquired."); + } + int err = numpy._import_array(); + if (err < 0){ + System.out.println("Numpy import failed!"); + throw new PythonException("Numpy import failed!"); + } + } + + public NumpyArray() { + super("numpy.ndarray", INDArray.class); + + } + + @Override + public INDArray toJava(PythonObject pythonObject) { + log.info("Converting PythonObject to INDArray..."); + PyObject np = PyImport_ImportModule("numpy"); + PyObject ndarray = PyObject_GetAttrString(np, "ndarray"); + if (PyObject_IsInstance(pythonObject.getNativePythonObject(), ndarray) != 1) { + Py_DecRef(ndarray); + Py_DecRef(np); + throw new PythonException("Object is not a numpy array! 
Use Python.ndarray() to convert object to a numpy array.");
+        }
+        Py_DecRef(ndarray);
+        Py_DecRef(np);
+        PyArrayObject npArr = new PyArrayObject(pythonObject.getNativePythonObject());
+        long[] shape = new long[PyArray_NDIM(npArr)];
+        SizeTPointer shapePtr = PyArray_SHAPE(npArr);
+        if (shapePtr != null)
+            shapePtr.get(shape, 0, shape.length);
+        long[] strides = new long[shape.length];
+        SizeTPointer stridesPtr = PyArray_STRIDES(npArr);
+        if (stridesPtr != null)
+            stridesPtr.get(strides, 0, strides.length);
+        int npdtype = PyArray_TYPE(npArr);
+
+        DataType dtype;
+        switch (npdtype) {
+            case NPY_DOUBLE:
+                dtype = DataType.DOUBLE;
+                break;
+            case NPY_FLOAT:
+                dtype = DataType.FLOAT;
+                break;
+            case NPY_SHORT:
+                dtype = DataType.SHORT;
+                break;
+            case NPY_INT:
+                dtype = DataType.INT32;
+                break;
+            case NPY_LONG:
+                dtype = DataType.INT64;
+                break;
+            case NPY_UINT:
+                dtype = DataType.UINT32;
+                break;
+            case NPY_BYTE:
+                dtype = DataType.INT8;
+                break;
+            case NPY_UBYTE:
+                dtype = DataType.UINT8;
+                break;
+            case NPY_BOOL:
+                dtype = DataType.BOOL;
+                break;
+            case NPY_HALF:
+                dtype = DataType.FLOAT16;
+                break;
+            case NPY_LONGLONG:
+                dtype = DataType.INT64;
+                break;
+            case NPY_USHORT:
+                dtype = DataType.UINT16;
+                break;
+            case NPY_ULONG:
+            case NPY_ULONGLONG:
+                dtype = DataType.UINT64;
+                break;
+            default:
+                throw new PythonException("Unsupported array data type: " + npdtype);
+        }
+        // total element count of the array
+        long size = 1;
+        for (long dim : shape) {
+            size *= dim;
+        }
+
+        INDArray ret;
+        long address = PyArray_DATA(npArr).address();
+        String key = address + "_" + size + "_" + dtype;
+        DataBuffer buff = cache.get(key);
+        if (buff == null) {
+            try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
+                Pointer ptr = NativeOpsHolder.getInstance().getDeviceNativeOps().pointerForAddress(address);
+                ptr = ptr.limit(size);
+                ptr = ptr.capacity(size);
+                buff = Nd4j.createBuffer(ptr, size, dtype);
+                cache.put(key, buff);
+            }
+        }
+        int elemSize = buff.getElementSize();
+        // numpy strides are in bytes; nd4j strides are in elements
+        long[] nd4jStrides = new long[strides.length];
+        for (int i = 0; i < strides.length; i++) {
+            nd4jStrides[i] = strides[i] / elemSize;
+        }
+        ret = Nd4j.create(buff, shape, nd4jStrides, 0, Shape.getOrder(shape, nd4jStrides, 1), dtype);
+        Nd4j.getAffinityManager().tagLocation(ret, AffinityManager.Location.HOST);
+        log.info("Done.");
+        return ret;
+    }
+
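+    // Buffer-lifetime note: toJava() above and toPython() below share memory with
+    // numpy rather than copying it. The address-keyed 'cache' map pins a strong
+    // reference to each DataBuffer handed across the boundary so the underlying
+    // native allocation is not freed while the other side may still use it;
+    // entries are intentionally never evicted.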
+    @Override
+    public PythonObject toPython(INDArray indArray) {
+        log.info("Converting INDArray to PythonObject...");
+        DataType dataType = indArray.dataType();
+        DataBuffer buff = indArray.data();
+        String key = buff.pointer().address() + "_" + buff.length() + "_" + dataType;
+        cache.put(key, buff);
+        int numpyType;
+        String ctype;
+        switch (dataType) {
+            case DOUBLE:
+                numpyType = NPY_DOUBLE;
+                ctype = "c_double";
+                break;
+            case FLOAT:
+            case BFLOAT16:
+                numpyType = NPY_FLOAT;
+                ctype = "c_float";
+                break;
+            case SHORT:
+                numpyType = NPY_SHORT;
+                ctype = "c_short";
+                break;
+            case INT:
+                numpyType = NPY_INT;
+                ctype = "c_int";
+                break;
+            case LONG:
+                numpyType = NPY_INT64;
+                ctype = "c_int64";
+                break;
+            case UINT16:
+                numpyType = NPY_USHORT;
+                ctype = "c_uint16";
+                break;
+            case UINT32:
+                numpyType = NPY_UINT;
+                ctype = "c_uint";
+                break;
+            case UINT64:
+                numpyType = NPY_UINT64;
+                ctype = "c_uint64";
+                break;
+            case BOOL:
+                numpyType = NPY_BOOL;
+                ctype = "c_bool";
+                break;
+            case BYTE:
+                numpyType = NPY_BYTE;
+                ctype = "c_byte";
+                break;
+            case UBYTE:
+                numpyType = NPY_UBYTE;
+                ctype = "c_ubyte";
+                break;
+            case HALF:
+                numpyType = NPY_HALF;
+                ctype = "c_short";
+                break;
+            default:
+                throw new PythonException("Unsupported dtype: " + dataType);
+        }
+
+        long[] shape = indArray.shape();
+        INDArray inputArray = indArray;
+        if (dataType == DataType.BFLOAT16) {
+            log.warn("Creating copy of array as bfloat16 is not supported by numpy.");
+            inputArray = indArray.castTo(DataType.FLOAT);
+        }
+
+        // Sync to host memory in the case of CUDA, before passing the host memory pointer to Python
+        Nd4j.getAffinityManager().ensureLocation(inputArray, AffinityManager.Location.HOST);
+
+        // The PyArray_Type() call causes a JVM crash on Linux (CPU backend) if the GIL
+        // is acquired by a non-main thread. Using the interpreter for now:
+//        try(PythonContextManager.Context context = new PythonContextManager.Context("__np_array_converter")){
+//            log.info("Starting exec...");
+//            String code = "import ctypes\nimport numpy as np\n" +
+//                    "cArr = (ctypes." + ctype + "*" + indArray.length() + ")"+
+//                    ".from_address(" + indArray.data().pointer().address() + ")\n"+
+//                    "npArr = np.frombuffer(cArr, dtype=" + ((numpyType == NPY_HALF) ? "'half'" : "ctypes." + ctype)+
+//                    ").reshape(" + Arrays.toString(indArray.shape()) + ")";
+//            PythonExecutioner.exec(code);
+//            log.info("exec done.");
+//            PythonObject ret = PythonExecutioner.getVariable("npArr");
+//            Py_IncRef(ret.getNativePythonObject());
+//            return ret;
+//        }
+        log.info("NUMPY: PyArray_Type()");
+        PyTypeObject pyTypeObject = PyArray_Type();
+
+        log.info("NUMPY: PyArray_New()");
+        PyObject npArr = PyArray_New(pyTypeObject, shape.length, new SizeTPointer(shape),
+                numpyType, null,
+                inputArray.data().addressPointer(),
+                0, NPY_ARRAY_CARRAY, null);
+        log.info("Done.");
+        return new PythonObject(npArr);
+    }
+
+    @Override
+    public boolean accepts(Object javaObject) {
+        return javaObject instanceof INDArray;
+    }
+
+    @Override
+    public INDArray adapt(Object javaObject) {
+        if (javaObject instanceof INDArray) {
+            return (INDArray) javaObject;
+        }
+        throw new PythonException("Cannot cast object of type " + javaObject.getClass().getName() + " to INDArray");
+    }
+}
diff --git a/python4j/python4j-numpy/src/main/resources/META-INF/services/org.nd4j.python4j.PythonType b/python4j/python4j-numpy/src/main/resources/META-INF/services/org.nd4j.python4j.PythonType
new file mode 100644
index 000000000..b0d2f1256
--- /dev/null
+++ b/python4j/python4j-numpy/src/main/resources/META-INF/services/org.nd4j.python4j.PythonType
@@ -0,0 +1 @@
+org.nd4j.python4j.NumpyArray
\ No newline at end of file
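The one-line service file above is what lets PythonTypes.getExternalTypes() (see the PythonTypes diff earlier) discover NumpyArray at runtime; it is the standard java.util.ServiceLoader mechanism. A minimal discovery sketch, for reference:

    // mirrors getExternalTypes(): any PythonType named in a provider's
    // META-INF/services/org.nd4j.python4j.PythonType file (one fully-qualified
    // class name per line, e.g. org.nd4j.python4j.NumpyArray) is picked up here
    ServiceLoader<PythonType> loader = ServiceLoader.load(PythonType.class);
    for (PythonType type : loader) {
        System.out.println("registered python type: " + type.getName());
    }

diff --git a/python4j/python4j-numpy/src/test/java/PythonNumpyBasicTest.java b/python4j/python4j-numpy/src/test/java/PythonNumpyBasicTest.java
new file mode 100644
index 000000000..d76f759a6
--- /dev/null
+++ b/python4j/python4j-numpy/src/test/java/PythonNumpyBasicTest.java
@@ -0,0 +1,169 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.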
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +import org.nd4j.python4j.*; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.nativeblas.OpaqueDataBuffer; + +import javax.annotation.concurrent.NotThreadSafe; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; + +@NotThreadSafe +@RunWith(Parameterized.class) +public class PythonNumpyBasicTest { + private DataType dataType; + private long[] shape; + + public PythonNumpyBasicTest(DataType dataType, long[] shape, String dummyArg) { + this.dataType = dataType; + this.shape = shape; + } + + @Parameterized.Parameters(name = "{index}: Testing with DataType={0}, shape={2}") + public static Collection params() { + DataType[] types = new DataType[] { + DataType.BOOL, + DataType.FLOAT16, + DataType.BFLOAT16, + DataType.FLOAT, + DataType.DOUBLE, + DataType.INT8, + DataType.INT16, + DataType.INT32, + DataType.INT64, + DataType.UINT8, + DataType.UINT16, + DataType.UINT32, + DataType.UINT64 + }; + + long[][] shapes = new long[][]{ + new long[]{2, 3}, + new long[]{3}, + new long[]{1}, + new long[]{} // scalar + }; + + + List ret = new ArrayList<>(); + for (DataType type: types){ + for (long[] shape: shapes){ + ret.add(new Object[]{type, shape, Arrays.toString(shape)}); + } + } + return ret; + } + + @Test + public void testConversion(){ + INDArray arr = Nd4j.zeros(dataType, shape); + PythonObject npArr = PythonTypes.convert(arr); + INDArray arr2 = PythonTypes.getPythonTypeForPythonObject(npArr).toJava(npArr); + if (dataType == DataType.BFLOAT16){ + arr = arr.castTo(DataType.FLOAT); + } + Assert.assertEquals(arr,arr2); + } + + + @Test + public void testExecution(){ + List inputs = new ArrayList<>(); + INDArray x = Nd4j.ones(dataType, shape); + INDArray y = Nd4j.zeros(dataType, shape); + INDArray z = (dataType == DataType.BOOL)?x:x.mul(y.add(2)); + z = (dataType == DataType.BFLOAT16)? 
z.castTo(DataType.FLOAT): z;
+        PythonType arrType = PythonTypes.get("numpy.ndarray");
+        inputs.add(new PythonVariable<>("x", arrType, x));
+        inputs.add(new PythonVariable<>("y", arrType, y));
+        List outputs = new ArrayList<>();
+        PythonVariable output = new PythonVariable<>("z", arrType);
+        outputs.add(output);
+        String code = (dataType == DataType.BOOL)?"z = x":"z = x * (y + 2)";
+        if (shape.length == 0){ // scalar special case
+            code += "\nimport numpy as np\nz = np.asarray(float(z), dtype=x.dtype)";
+        }
+        PythonExecutioner.exec(code, inputs, outputs);
+        INDArray z2 = output.getValue();
+
+        Assert.assertEquals(z.dataType(), z2.dataType());
+        Assert.assertEquals(z, z2);
+    }
+
+    @Test
+    public void testInplaceExecution(){
+        if (dataType == DataType.BOOL || dataType == DataType.BFLOAT16) return;
+        if (shape.length == 0) return;
+        List inputs = new ArrayList<>();
+        INDArray x = Nd4j.ones(dataType, shape);
+        INDArray y = Nd4j.zeros(dataType, shape);
+        INDArray z = x.mul(y.add(2));
+        // Nd4j.getAffinityManager().ensureLocation(z, AffinityManager.Location.HOST);
+        PythonType arrType = PythonTypes.get("numpy.ndarray");
+        inputs.add(new PythonVariable<>("x", arrType, x));
+        inputs.add(new PythonVariable<>("y", arrType, y));
+        List outputs = new ArrayList<>();
+        PythonVariable output = new PythonVariable<>("x", arrType);
+        outputs.add(output);
+        String code = "x *= y + 2";
+        PythonExecutioner.exec(code, inputs, outputs);
+        INDArray z2 = output.getValue();
+        Assert.assertEquals(x.dataType(), z2.dataType());
+        Assert.assertEquals(z.dataType(), z2.dataType());
+        Assert.assertEquals(x, z2);
+        Assert.assertEquals(z, z2);
+        Assert.assertEquals(x.data().pointer().address(), z2.data().pointer().address());
+        if("CUDA".equalsIgnoreCase(Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"))){
+            Assert.assertEquals(getDeviceAddress(x), getDeviceAddress(z2));
+        }
+    }
+
+    private static long getDeviceAddress(INDArray array){
+        if(!"CUDA".equalsIgnoreCase(Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"))){
+            throw new IllegalStateException("Cannot get device pointer for non-CUDA device");
+        }
+
+        // Use reflection here as OpaqueDataBuffer is only available on BaseCudaDataBuffer and BaseCpuDataBuffer - not DataBuffer/BaseDataBuffer
+        // due to it being defined in nd4j-native-api, not nd4j-api
+        try {
+            Class c = Class.forName("org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer");
+            Method m = c.getMethod("getOpaqueDataBuffer");
+            OpaqueDataBuffer db = (OpaqueDataBuffer) m.invoke(array.data());
+            long address = db.specialBuffer().address();
+            return address;
+        } catch (Throwable t){
+            throw new RuntimeException("Error getting OpaqueDataBuffer", t);
+        }
+    }
+}
diff --git a/python4j/python4j-numpy/src/test/java/PythonNumpyCollectionsTest.java b/python4j/python4j-numpy/src/test/java/PythonNumpyCollectionsTest.java
new file mode 100644
index 000000000..64c417905
--- /dev/null
+++ b/python4j/python4j-numpy/src/test/java/PythonNumpyCollectionsTest.java
@@ -0,0 +1,96 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +import org.nd4j.python4j.PythonException; +import org.nd4j.python4j.PythonObject; +import org.nd4j.python4j.PythonTypes; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.factory.Nd4j; + +import javax.annotation.concurrent.NotThreadSafe; +import java.util.*; + + +@NotThreadSafe +@RunWith(Parameterized.class) +public class PythonNumpyCollectionsTest { + private DataType dataType; + + public PythonNumpyCollectionsTest(DataType dataType){ + this.dataType = dataType; + } + + @Parameterized.Parameters(name = "{index}: Testing with DataType={0}") + public static DataType[] params() { + return new DataType[]{ + DataType.BOOL, + DataType.FLOAT16, + //DataType.BFLOAT16, + DataType.FLOAT, + DataType.DOUBLE, + DataType.INT8, + DataType.INT16, + DataType.INT32, + DataType.INT64, + DataType.UINT8, + DataType.UINT16, + DataType.UINT32, + DataType.UINT64 + }; + } + @Test + public void testPythonDictFromMap() throws PythonException { + Map map = new HashMap(); + map.put("a", 1); + map.put(1, "a"); + map.put("arr", Nd4j.ones(dataType, 2, 3)); + map.put("list1", Arrays.asList(1, 2.0, 3, 4f, Nd4j.zeros(dataType,3,2))); + Map innerMap = new HashMap(); + innerMap.put("b", 2); + innerMap.put(2, "b"); + innerMap.put(5, Nd4j.ones(dataType, 5)); + map.put("innermap", innerMap); + map.put("list2", Arrays.asList(4, "5", innerMap, false, true)); + PythonObject dict = PythonTypes.convert(map); + Map map2 = PythonTypes.DICT.toJava(dict); + Assert.assertEquals(map.toString(), map2.toString()); + } + + @Test + public void testPythonListFromList() throws PythonException{ + List list = new ArrayList<>(); + list.add(1); + list.add("2"); + list.add(Nd4j.ones(dataType, 2, 3)); + list.add(Arrays.asList("a", + Nd4j.ones(dataType, 1, 2),1.0, 2f, 10, true, false, + Nd4j.zeros(dataType, 3, 2))); + Map map = new HashMap(); + map.put("a", 1); + map.put(1, "a"); + map.put(5, Nd4j.ones(dataType,4, 5)); + map.put("list1", Arrays.asList(1, 2.0, 3, 4f, Nd4j.zeros(dataType, 3, 1))); + list.add(map); + PythonObject dict = PythonTypes.convert(list); + List list2 = PythonTypes.LIST.toJava(dict); + Assert.assertEquals(list.toString(), list2.toString()); + } +} diff --git a/python4j/python4j-numpy/src/test/java/PythonNumpyGCTest.java b/python4j/python4j-numpy/src/test/java/PythonNumpyGCTest.java new file mode 100644 index 000000000..96dd7274c --- /dev/null +++ b/python4j/python4j-numpy/src/test/java/PythonNumpyGCTest.java @@ -0,0 +1,55 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+import org.nd4j.python4j.Python;
+import org.nd4j.python4j.PythonGC;
+import org.nd4j.python4j.PythonObject;
+import org.junit.Assert;
+import org.junit.Test;
+import org.nd4j.linalg.factory.Nd4j;
+
+import javax.annotation.concurrent.NotThreadSafe;
+
+
+@NotThreadSafe
+public class PythonNumpyGCTest {
+
+    @Test
+    public void testGC(){
+        PythonObject gcModule = Python.importModule("gc");
+        PythonObject getObjects = gcModule.attr("get_objects");
+        PythonObject pyObjCount1 = Python.len(getObjects.call());
+        long objCount1 = pyObjCount1.toLong();
+        PythonObject pyList = Python.list();
+        pyList.attr("append").call(new PythonObject(Nd4j.linspace(1, 10, 10)));
+        pyList.attr("append").call(1.0);
+        pyList.attr("append").call(true);
+        PythonObject pyObjCount2 = Python.len(getObjects.call());
+        long objCount2 = pyObjCount2.toLong();
+        long diff = objCount2 - objCount1;
+        Assert.assertTrue(diff > 2);
+        try(PythonGC gc = PythonGC.watch()){
+            PythonObject pyList2 = Python.list();
+            pyList2.attr("append").call(new PythonObject(Nd4j.linspace(1, 10, 10)));
+            pyList2.attr("append").call(1.0);
+            pyList2.attr("append").call(true);
+        }
+        PythonObject pyObjCount3 = Python.len(getObjects.call());
+        long objCount3 = pyObjCount3.toLong();
+        diff = objCount3 - objCount2;
+        Assert.assertTrue(diff <= 2); // at most 2 objects are created inside the watched scope
+    }
+}
diff --git a/python4j/python4j-numpy/src/test/java/PythonNumpyImportTest.java b/python4j/python4j-numpy/src/test/java/PythonNumpyImportTest.java
new file mode 100644
index 000000000..941072e45
--- /dev/null
+++ b/python4j/python4j-numpy/src/test/java/PythonNumpyImportTest.java
@@ -0,0 +1,22 @@
+import org.nd4j.python4j.NumpyArray;
+import org.nd4j.python4j.Python;
+import org.nd4j.python4j.PythonGC;
+import org.nd4j.python4j.PythonObject;
+import org.junit.Assert;
+import org.junit.Test;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+
+public class PythonNumpyImportTest {
+
+    @Test
+    public void testNumpyImport(){
+        try(PythonGC gc = PythonGC.watch()){
+            PythonObject np = Python.importModule("numpy");
+            PythonObject zeros = np.attr("zeros").call(5);
+            INDArray arr = NumpyArray.INSTANCE.toJava(zeros);
+            Assert.assertEquals(arr, Nd4j.zeros(DataType.DOUBLE, 5));
+        }
+    }
+}
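Aside: a minimal sketch of the PythonGC contract that both GC tests rely on (all names as used in the tests):

    try (PythonGC gc = PythonGC.watch()) {
        PythonObject scratch = Python.list();  // reclaimed when the watched scope closes
        PythonObject result = Python.list();
        PythonGC.keep(result);                 // explicitly retained beyond the scope
    }

diff --git a/python4j/python4j-numpy/src/test/java/PythonNumpyJobTest.java b/python4j/python4j-numpy/src/test/java/PythonNumpyJobTest.java
new file mode 100644
index 000000000..dc087d0f8
--- /dev/null
+++ b/python4j/python4j-numpy/src/test/java/PythonNumpyJobTest.java
@@ -0,0 +1,303 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.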
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.python4j.*; + +import java.util.ArrayList; +import java.util.List; + +import static org.junit.Assert.assertEquals; + + +@javax.annotation.concurrent.NotThreadSafe +@RunWith(Parameterized.class) +public class PythonNumpyJobTest { + private DataType dataType; + + public PythonNumpyJobTest(DataType dataType){ + this.dataType = dataType; + } + + @Parameterized.Parameters(name = "{index}: Testing with DataType={0}") + public static DataType[] params() { + return new DataType[]{ + DataType.BOOL, + DataType.FLOAT16, + DataType.BFLOAT16, + DataType.FLOAT, + DataType.DOUBLE, + DataType.INT8, + DataType.INT16, + DataType.INT32, + DataType.INT64, + DataType.UINT8, + DataType.UINT16, + DataType.UINT32, + DataType.UINT64 + }; + } + + @Test + public void testNumpyJobBasic(){ + PythonContextManager.deleteNonMainContexts(); + List inputs = new ArrayList<>(); + INDArray x = Nd4j.ones(dataType, 2, 3); + INDArray y = Nd4j.zeros(dataType, 2, 3); + INDArray z = (dataType == DataType.BOOL)?x:x.mul(y.add(2)); + z = (dataType == DataType.BFLOAT16)? z.castTo(DataType.FLOAT): z; + PythonType arrType = PythonTypes.get("numpy.ndarray"); + inputs.add(new PythonVariable<>("x", arrType, x)); + inputs.add(new PythonVariable<>("y", arrType, y)); + List outputs = new ArrayList<>(); + PythonVariable output = new PythonVariable<>("z", arrType); + outputs.add(output); + String code = (dataType == DataType.BOOL)?"z = x":"z = x * (y + 2)"; + + PythonJob job = new PythonJob("job1", code, false); + + job.exec(inputs, outputs); + + INDArray z2 = output.getValue(); + + if (dataType == DataType.BFLOAT16){ + z2 = z2.castTo(DataType.FLOAT); + } + + Assert.assertEquals(z, z2); + + } + + @Test + public void testNumpyJobReturnAllVariables(){ + PythonContextManager.deleteNonMainContexts(); + List inputs = new ArrayList<>(); + INDArray x = Nd4j.ones(dataType, 2, 3); + INDArray y = Nd4j.zeros(dataType, 2, 3); + INDArray z = (dataType == DataType.BOOL)?x:x.mul(y.add(2)); + PythonType arrType = PythonTypes.get("numpy.ndarray"); + inputs.add(new PythonVariable<>("x", arrType, x)); + inputs.add(new PythonVariable<>("y", arrType, y)); + String code = (dataType == DataType.BOOL)?"z = x":"z = x * (y + 2)"; + + PythonJob job = new PythonJob("job1", code, false); + List outputs = job.execAndReturnAllVariables(inputs); + + INDArray x2 = (INDArray) outputs.get(0).getValue(); + INDArray y2 = (INDArray) outputs.get(1).getValue(); + INDArray z2 = (INDArray) outputs.get(2).getValue(); + + if (dataType == DataType.BFLOAT16){ + x = x.castTo(DataType.FLOAT); + y = y.castTo(DataType.FLOAT); + z = z.castTo(DataType.FLOAT); + } + Assert.assertEquals(x, x2); + Assert.assertEquals(y, y2); + Assert.assertEquals(z, z2); + + } + + + @Test + public void testMultipleNumpyJobsParallel(){ + PythonContextManager.deleteNonMainContexts(); + String code1 
=(dataType == DataType.BOOL)?"z = x":"z = x + y"; + PythonJob job1 = new PythonJob("job1", code1, false); + + String code2 =(dataType == DataType.BOOL)?"z = y":"z = x - y"; + PythonJob job2 = new PythonJob("job2", code2, false); + + List inputs = new ArrayList<>(); + INDArray x = Nd4j.ones(dataType, 2, 3); + INDArray y = Nd4j.zeros(dataType, 2, 3); + INDArray z1 = (dataType == DataType.BOOL)?x:x.add(y); + z1 = (dataType == DataType.BFLOAT16)? z1.castTo(DataType.FLOAT): z1; + INDArray z2 = (dataType == DataType.BOOL)?y:x.sub(y); + z2 = (dataType == DataType.BFLOAT16)? z2.castTo(DataType.FLOAT): z2; + PythonType arrType = PythonTypes.get("numpy.ndarray"); + inputs.add(new PythonVariable<>("x", arrType, x)); + inputs.add(new PythonVariable<>("y", arrType, y)); + + + List outputs = new ArrayList<>(); + + outputs.add(new PythonVariable<>("z", arrType)); + + job1.exec(inputs, outputs); + + assertEquals(z1, outputs.get(0).getValue()); + + + job2.exec(inputs, outputs); + + assertEquals(z2, outputs.get(0).getValue()); + + } + + + @Test + public synchronized void testNumpyJobSetupRun(){ + if (dataType == DataType.BOOL)return; + PythonContextManager.deleteNonMainContexts(); + String code = "five=None\n" + + "def setup():\n" + + " global five\n"+ + " five = 5\n\n" + + "def run(a, b):\n" + + " c = a + b + five\n"+ + " return {'c':c}\n\n"; + + PythonJob job = new PythonJob("job1", code, true); + + List inputs = new ArrayList<>(); + inputs.add(new PythonVariable<>("a", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(2))); + inputs.add(new PythonVariable<>("b", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3))); + + List outputs = new ArrayList<>(); + outputs.add(new PythonVariable<>("c", NumpyArray.INSTANCE)); + job.exec(inputs, outputs); + + + assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16)? DataType.FLOAT: dataType, 2, 3).mul(10), + outputs.get(0).getValue()); + + + inputs = new ArrayList<>(); + inputs.add(new PythonVariable<>("a", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3))); + inputs.add(new PythonVariable<>("b", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(4))); + + + outputs = new ArrayList<>(); + outputs.add(new PythonVariable<>("c", NumpyArray.INSTANCE)); + + job.exec(inputs, outputs); + + assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16)? DataType.FLOAT: dataType, 2, 3).mul(12), + outputs.get(0).getValue()); + + + } + @Test + public void testNumpyJobSetupRunAndReturnAllVariables(){ + if (dataType == DataType.BOOL)return; + PythonContextManager.deleteNonMainContexts(); + String code = "five=None\n" + + "c=None\n"+ + "def setup():\n" + + " global five\n"+ + " five = 5\n\n" + + "def run(a, b):\n" + + " global c\n" + + " c = a + b + five\n"; + PythonJob job = new PythonJob("job1", code, true); + + List inputs = new ArrayList<>(); + inputs.add(new PythonVariable<>("a", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(2))); + inputs.add(new PythonVariable<>("b", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3))); + + List outputs = job.execAndReturnAllVariables(inputs); + + assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16)? DataType.FLOAT: dataType, 2, 3).mul(10), + outputs.get(1).getValue()); + + + inputs = new ArrayList<>(); + inputs.add(new PythonVariable<>("a", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3))); + inputs.add(new PythonVariable<>("b", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(4))); + + + outputs = job.execAndReturnAllVariables(inputs); + + + assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16)? 
DataType.FLOAT: dataType, 2, 3).mul(12), + outputs.get(1).getValue()); + + + } + + @Test + public void testMultipleNumpyJobsSetupRunParallel(){ + if (dataType == DataType.BOOL)return; + PythonContextManager.deleteNonMainContexts(); + + String code1 = "five=None\n" + + "def setup():\n" + + " global five\n"+ + " five = 5\n\n" + + "def run(a, b):\n" + + " c = a + b + five\n"+ + " return {'c':c}\n\n"; + PythonJob job1 = new PythonJob("job1", code1, true); + + String code2 = "five=None\n" + + "def setup():\n" + + " global five\n"+ + " five = 5\n\n" + + "def run(a, b):\n" + + " c = a + b - five\n"+ + " return {'c':c}\n\n"; + PythonJob job2 = new PythonJob("job2", code2, true); + + List inputs = new ArrayList<>(); + inputs.add(new PythonVariable<>("a", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(2))); + inputs.add(new PythonVariable<>("b", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3))); + + + List outputs = new ArrayList<>(); + outputs.add(new PythonVariable<>("c", NumpyArray.INSTANCE)); + + job1.exec(inputs, outputs); + + assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16)? DataType.FLOAT: dataType, 2, 3).mul(10), + outputs.get(0).getValue()); + + + job2.exec(inputs, outputs); + + assertEquals(Nd4j.zeros((dataType == DataType.BFLOAT16)? DataType.FLOAT: dataType, 2, 3), + outputs.get(0).getValue()); + + + inputs = new ArrayList<>(); + inputs.add(new PythonVariable<>("a", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3))); + inputs.add(new PythonVariable<>("b", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(4))); + + outputs = new ArrayList<>(); + outputs.add(new PythonVariable<>("c", NumpyArray.INSTANCE)); + + + job1.exec(inputs, outputs); + + assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16)? DataType.FLOAT: dataType, 2, 3).mul(12), + outputs.get(0).getValue()); + + + job2.exec(inputs, outputs); + + assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16)? DataType.FLOAT: dataType, 2, 3).mul(2), + outputs.get(0).getValue()); + + + } + +} diff --git a/python4j/python4j-numpy/src/test/java/PythonNumpyMultiThreadTest.java b/python4j/python4j-numpy/src/test/java/PythonNumpyMultiThreadTest.java new file mode 100644 index 000000000..02eb99551 --- /dev/null +++ b/python4j/python4j-numpy/src/test/java/PythonNumpyMultiThreadTest.java @@ -0,0 +1,194 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +import org.nd4j.python4j.*; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import javax.annotation.concurrent.NotThreadSafe; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + + +@NotThreadSafe +@RunWith(Parameterized.class) +public class PythonNumpyMultiThreadTest { + private DataType dataType; + + public PythonNumpyMultiThreadTest(DataType dataType) { + this.dataType = dataType; + } + + @Parameterized.Parameters(name = "{index}: Testing with DataType={0}") + public static DataType[] params() { + return new DataType[]{ +// DataType.BOOL, +// DataType.FLOAT16, +// DataType.BFLOAT16, + DataType.FLOAT, + DataType.DOUBLE, +// DataType.INT8, +// DataType.INT16, + DataType.INT32, + DataType.INT64, +// DataType.UINT8, +// DataType.UINT16, +// DataType.UINT32, +// DataType.UINT64 + }; + } + + + @Test + public void testMultiThreading1() throws Throwable { + final List exceptions = Collections.synchronizedList(new ArrayList()); + Runnable runnable = new Runnable() { + @Override + public void run() { + try (PythonGIL gil = PythonGIL.lock()) { + try (PythonGC gc = PythonGC.watch()) { + List inputs = new ArrayList<>(); + inputs.add(new PythonVariable<>("x", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3))); + inputs.add(new PythonVariable<>("y", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(4))); + PythonVariable out = new PythonVariable<>("z", NumpyArray.INSTANCE); + String code = "z = x + y"; + PythonExecutioner.exec(code, inputs, Collections.singletonList(out)); + Assert.assertEquals(Nd4j.ones(dataType, 2, 3).mul(7), out.getValue()); + } + } catch (Throwable e) { + exceptions.add(e); + } + } + }; + + int numThreads = 10; + Thread[] threads = new Thread[numThreads]; + for (int i = 0; i < threads.length; i++) { + threads[i] = new Thread(runnable); + } + for (int i = 0; i < threads.length; i++) { + threads[i].start(); + } + Thread.sleep(100); + for (int i = 0; i < threads.length; i++) { + threads[i].join(); + } + if (!exceptions.isEmpty()) { + throw (exceptions.get(0)); + } + + } + + @Test + public void testMultiThreading2() throws Throwable { + final List exceptions = Collections.synchronizedList(new ArrayList()); + Runnable runnable = new Runnable() { + @Override + public void run() { + try (PythonGIL gil = PythonGIL.lock()) { + try (PythonGC gc = PythonGC.watch()) { + PythonContextManager.reset(); + List inputs = new ArrayList<>(); + inputs.add(new PythonVariable<>("x", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3))); + inputs.add(new PythonVariable<>("y", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(4))); + String code = "z = x + y"; + List outputs = PythonExecutioner.execAndReturnAllVariables(code, inputs); + Assert.assertEquals(Nd4j.ones(dataType, 2, 3).mul(3), outputs.get(0).getValue()); + Assert.assertEquals(Nd4j.ones(dataType, 2, 3).mul(4), outputs.get(1).getValue()); + Assert.assertEquals(Nd4j.ones(dataType, 2, 3).mul(7), outputs.get(2).getValue()); + } + } catch (Throwable e) { + exceptions.add(e); + } + } + }; + + int numThreads = 10; + Thread[] threads = new Thread[numThreads]; + for (int i = 0; i < threads.length; i++) { + threads[i] = new Thread(runnable); + } + for 
(int i = 0; i < threads.length; i++) { + threads[i].start(); + } + Thread.sleep(100); + for (int i = 0; i < threads.length; i++) { + threads[i].join(); + } + if (!exceptions.isEmpty()) { + throw (exceptions.get(0)); + } + } + + @Test + public void testMultiThreading3() throws Throwable { + PythonContextManager.deleteNonMainContexts(); + + String code = "c = a + b"; + final PythonJob job = new PythonJob("job1", code, false); + + final List exceptions = Collections.synchronizedList(new ArrayList()); + + class JobThread extends Thread { + private INDArray a, b, c; + + public JobThread(INDArray a, INDArray b, INDArray c) { + this.a = a; + this.b = b; + this.c = c; + } + + @Override + public void run() { + try { + PythonVariable out = new PythonVariable<>("c", NumpyArray.INSTANCE); + job.exec(Arrays.asList(new PythonVariable<>("a", NumpyArray.INSTANCE, a), + new PythonVariable<>("b", NumpyArray.INSTANCE, b)), + Collections.singletonList(out)); + Assert.assertEquals(c, out.getValue()); + } catch (Exception e) { + exceptions.add(e); + } + + } + } + int numThreads = 10; + JobThread[] threads = new JobThread[numThreads]; + for (int i = 0; i < threads.length; i++) { + threads[i] = new JobThread(Nd4j.zeros(dataType, 2, 3).add(i), Nd4j.zeros(dataType, 2, 3).add(i + 3), + Nd4j.zeros(dataType, 2, 3).add(2 * i + 3)); + } + + for (int i = 0; i < threads.length; i++) { + threads[i].start(); + } + Thread.sleep(100); + for (int i = 0; i < threads.length; i++) { + threads[i].join(); + } + + if (!exceptions.isEmpty()) { + throw (exceptions.get(0)); + } + } +} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/TargetQNetworkSource.java b/python4j/python4j-numpy/src/test/java/PythonNumpyServiceLoaderTest.java similarity index 54% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/TargetQNetworkSource.java rename to python4j/python4j-numpy/src/test/java/PythonNumpyServiceLoaderTest.java index 34fd9c06e..bd13a99d9 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/TargetQNetworkSource.java +++ b/python4j/python4j-numpy/src/test/java/PythonNumpyServiceLoaderTest.java @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. 
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/TargetQNetworkSource.java b/python4j/python4j-numpy/src/test/java/PythonNumpyServiceLoaderTest.java similarity index 54% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/TargetQNetworkSource.java rename to python4j/python4j-numpy/src/test/java/PythonNumpyServiceLoaderTest.java index 34fd9c06e..bd13a99d9 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/TargetQNetworkSource.java +++ b/python4j/python4j-numpy/src/test/java/PythonNumpyServiceLoaderTest.java @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. + * Copyright (c) 2020 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -14,15 +14,22 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.deeplearning4j.rl4j.learning.sync.qlearning; -import org.deeplearning4j.rl4j.network.dqn.IDQN; +import org.junit.Assert; +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.python4j.NumpyArray; +import org.nd4j.python4j.PythonTypes; -/** - * An interface that is an extension of {@link QNetworkSource} for all implementations capable of supplying a target Q-Network - * - * @author Alexandre Boulanger - */ -public interface TargetQNetworkSource extends QNetworkSource { - IDQN getTargetQNetwork(); +import javax.annotation.concurrent.NotThreadSafe; + +@NotThreadSafe +public class PythonNumpyServiceLoaderTest { + + @Test + public void testServiceLoader(){ + Assert.assertEquals(NumpyArray.INSTANCE, PythonTypes.get("numpy.ndarray")); + Assert.assertEquals(NumpyArray.INSTANCE, PythonTypes.getPythonTypeForJavaObject(Nd4j.zeros(1))); + } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/update/DQNNeuralNetUpdateRule.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/update/DQNNeuralNetUpdateRule.java index 46123d645..98873b827 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/update/DQNNeuralNetUpdateRule.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/update/DQNNeuralNetUpdateRule.java @@ -17,7 +17,6 @@ package org.deeplearning4j.rl4j.agent.update; import lombok.Getter; import org.deeplearning4j.rl4j.learning.sync.Transition; -import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource; import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN; import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.ITDTargetAlgorithm; import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.StandardDQN; @@ -28,13 +27,10 @@ import java.util.List; // Temporary class that will be replaced with a more generic class that delegates gradient computation // and network update to sub components. -public class DQNNeuralNetUpdateRule implements IUpdateRule<Transition<Integer>>, TargetQNetworkSource { +public class DQNNeuralNetUpdateRule implements IUpdateRule<Transition<Integer>> { - @Getter private final IDQN qNetwork; - - @Getter - private IDQN targetQNetwork; + private final IDQN targetQNetwork; private final int targetUpdateFrequency; private final ITDTargetAlgorithm<Integer> tdTargetAlgorithm; @@ -47,8 +43,8 @@ public class DQNNeuralNetUpdateRule implements IUpdateRule<Transition<Integer>>, this.targetQNetwork = qNetwork.clone(); this.targetUpdateFrequency = targetUpdateFrequency; tdTargetAlgorithm = isDoubleDQN - ? new DoubleDQN(this, gamma, errorClamp) - : new StandardDQN(this, gamma, errorClamp); + ? 
new DoubleDQN(qNetwork, targetQNetwork, gamma, errorClamp) + : new StandardDQN(qNetwork, targetQNetwork, gamma, errorClamp); } @Override @@ -56,7 +52,7 @@ public class DQNNeuralNetUpdateRule implements IUpdateRule<Transition<Integer>>, DataSet targets = tdTargetAlgorithm.computeTDTargets(trainingBatch); qNetwork.fit(targets.getFeatures(), targets.getLabels()); if(++updateCount % targetUpdateFrequency == 0) { - targetQNetwork = qNetwork.clone(); + targetQNetwork.copy(qNetwork); } } }
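The update() change above alters how the target network is kept in sync with the online network. Presumably clone() gives way to copy() because the TD-target algorithms now hold a direct reference to the target network: an in-place parameter copy keeps that reference valid, whereas re-cloning produced a fresh object each time, which is why consumers previously had to re-fetch it through TargetQNetworkSource.getTargetQNetwork(). A sketch of the resulting cadence, using the names from the diff:

    // Every targetUpdateFrequency updates, overwrite the target network's
    // parameters in place; references held by DoubleDQN/StandardDQN stay valid.
    if (++updateCount % targetUpdateFrequency == 0) {
        targetQNetwork.copy(qNetwork);
    }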
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/BaseDQNAlgorithm.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/BaseDQNAlgorithm.java index 3f27f954c..6cae384d5 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/BaseDQNAlgorithm.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/BaseDQNAlgorithm.java @@ -16,8 +16,7 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm; -import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource; -import org.deeplearning4j.rl4j.network.dqn.IDQN; +import org.deeplearning4j.rl4j.network.IOutputNeuralNet; import org.nd4j.linalg.api.ndarray.INDArray; /** @@ -28,7 +27,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; */ public abstract class BaseDQNAlgorithm extends BaseTDTargetAlgorithm { - private final TargetQNetworkSource qTargetNetworkSource; + private final IOutputNeuralNet targetQNetwork; /** * In the literature, this corresponds to Q{net}(s(t+1), a) @@ -40,23 +39,21 @@ */ protected INDArray targetQNetworkNextObservation; - protected BaseDQNAlgorithm(TargetQNetworkSource qTargetNetworkSource, double gamma) { - super(qTargetNetworkSource, gamma); - this.qTargetNetworkSource = qTargetNetworkSource; + protected BaseDQNAlgorithm(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma) { + super(qNetwork, gamma); + this.targetQNetwork = targetQNetwork; } - protected BaseDQNAlgorithm(TargetQNetworkSource qTargetNetworkSource, double gamma, double errorClamp) { - super(qTargetNetworkSource, gamma, errorClamp); - this.qTargetNetworkSource = qTargetNetworkSource; + protected BaseDQNAlgorithm(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma, double errorClamp) { + super(qNetwork, gamma, errorClamp); + this.targetQNetwork = targetQNetwork; } @Override protected void initComputation(INDArray observations, INDArray nextObservations) { super.initComputation(observations, nextObservations); - qNetworkNextObservation = qNetworkSource.getQNetwork().output(nextObservations); - - IDQN targetQNetwork = qTargetNetworkSource.getTargetQNetwork(); + qNetworkNextObservation = qNetwork.output(nextObservations); targetQNetworkNextObservation = targetQNetwork.output(nextObservations); } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/BaseTDTargetAlgorithm.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/BaseTDTargetAlgorithm.java index ca4beb47e..e0ede18d7 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/BaseTDTargetAlgorithm.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/BaseTDTargetAlgorithm.java @@ -17,7 +17,7 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm; import org.deeplearning4j.rl4j.learning.sync.Transition; -import org.deeplearning4j.rl4j.learning.sync.qlearning.QNetworkSource; +import org.deeplearning4j.rl4j.network.IOutputNeuralNet; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.DataSet; @@ -30,7 +30,7 @@ import java.util.List; */ public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Integer> { - protected final QNetworkSource qNetworkSource; + protected final IOutputNeuralNet qNetwork; protected final double gamma; private final double errorClamp; @@ -38,12 +38,12 @@ public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Integer> ... errorClamp away from the previous value. Double.NaN will disable the clamping. */ - protected BaseTDTargetAlgorithm(QNetworkSource qNetworkSource, double gamma, double errorClamp) { - this.qNetworkSource = qNetworkSource; + protected BaseTDTargetAlgorithm(IOutputNeuralNet qNetwork, double gamma, double errorClamp) { + this.qNetwork = qNetwork; this.gamma = gamma; this.errorClamp = errorClamp; @@ -52,12 +52,12 @@ public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Integer> Transition<Integer> transition = transitions.get(i); double yTarget = computeTarget(i, transition.getReward(), transition.isTerminal()); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQN.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQN.java index 3203af1b8..caeb85fb6 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQN.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQN.java @@ -16,7 +16,7 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm; -import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource; +import org.deeplearning4j.rl4j.network.IOutputNeuralNet; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -32,12 +32,12 @@ public class DoubleDQN extends BaseDQNAlgorithm { // In the literature, this corresponds to: max_{a}Q(s_{t+1}, a) private INDArray maxActionsFromQNetworkNextObservation; - public DoubleDQN(TargetQNetworkSource qTargetNetworkSource, double gamma) { - super(qTargetNetworkSource, gamma); + public DoubleDQN(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma) { + super(qNetwork, targetQNetwork, gamma); } - public DoubleDQN(TargetQNetworkSource qTargetNetworkSource, double gamma, double errorClamp) { - super(qTargetNetworkSource, gamma, errorClamp); + public DoubleDQN(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma, double errorClamp) { + super(qNetwork, targetQNetwork, gamma, errorClamp); } @Override diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQN.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQN.java index 8c03c8de9..6cd047c74 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQN.java +++ 
b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQN.java @@ -16,7 +16,7 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm; -import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource; +import org.deeplearning4j.rl4j.network.IOutputNeuralNet; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -32,12 +32,12 @@ public class StandardDQN extends BaseDQNAlgorithm { // In the literature, this corresponds to: max_{a}Q_{tar}(s_{t+1}, a) private INDArray maxActionsFromQTargetNextObservation; - public StandardDQN(TargetQNetworkSource qTargetNetworkSource, double gamma) { - super(qTargetNetworkSource, gamma); + public StandardDQN(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma) { + super(qNetwork, targetQNetwork, gamma); } - public StandardDQN(TargetQNetworkSource qTargetNetworkSource, double gamma, double errorClamp) { - super(qTargetNetworkSource, gamma, errorClamp); + public StandardDQN(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma, double errorClamp) { + super(qNetwork, targetQNetwork, gamma, errorClamp); } @Override
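For reference, the TD targets these two classes compute are the usual ones from the DQN literature, written in the Q_{net}/Q_{tar} notation of the comments above (for a terminal transition both reduce to y_t = r_t):

    % Standard DQN: select and evaluate the next action with the target network
    y_t = r_t + \gamma \max_{a} Q_{tar}(s_{t+1}, a)

    % Double DQN: select with the online network, evaluate with the target network
    y_t = r_t + \gamma \, Q_{tar}\big(s_{t+1}, \arg\max_{a} Q_{net}(s_{t+1}, a)\big)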
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QNetworkSource.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/IOutputNeuralNet.java similarity index 51% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QNetworkSource.java rename to rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/IOutputNeuralNet.java index e22d368e4..58e219ea0 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QNetworkSource.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/IOutputNeuralNet.java @@ -1,28 +1,38 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.rl4j.learning.sync.qlearning; - -import org.deeplearning4j.rl4j.network.dqn.IDQN; - -/** - * An interface for all implementations capable of supplying a Q-Network - * - * @author Alexandre Boulanger - */ -public interface QNetworkSource { - IDQN getQNetwork(); -} +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.network; + +import org.deeplearning4j.rl4j.observation.Observation; +import org.nd4j.linalg.api.ndarray.INDArray; + +/** + * An interface defining the output aspect of a {@link NeuralNet}. + */ +public interface IOutputNeuralNet { + /** + * Compute the output for the supplied observation. + * @param observation An {@link Observation} + * @return The output of the network + */ + INDArray output(Observation observation); + + /** + * Compute the output for the supplied batch. + * @param batch The input batch + * @return The output of the network + */ + INDArray output(INDArray batch); +} \ No newline at end of file
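Since IOutputNeuralNet exposes only the two output methods, anything that maps observations or batches to activations can now stand in for a full NeuralNet in the TD-target code. As a hypothetical illustration (GraphOutputAdapter is not part of this patch, and Observation#getData() is assumed to expose the underlying INDArray), an adapter over a DL4J ComputationGraph might look like:

    import org.deeplearning4j.nn.graph.ComputationGraph;
    import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
    import org.deeplearning4j.rl4j.observation.Observation;
    import org.nd4j.linalg.api.ndarray.INDArray;

    // Hypothetical adapter, not part of the patch: wraps a ComputationGraph
    // so it can be consumed wherever an IOutputNeuralNet is expected.
    public class GraphOutputAdapter implements IOutputNeuralNet {
        private final ComputationGraph graph;

        public GraphOutputAdapter(ComputationGraph graph) {
            this.graph = graph;
        }

        @Override
        public INDArray output(Observation observation) {
            // Delegate to the batch overload; getData() is assumed to return
            // the observation's INDArray payload.
            return output(observation.getData());
        }

        @Override
        public INDArray output(INDArray batch) {
            return graph.outputSingle(batch); // assumes a single-output graph
        }
    }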
diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/IDQN.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/IDQN.java index af295d202..daed646c5 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/IDQN.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/IDQN.java @@ -17,6 +17,7 @@ package org.deeplearning4j.rl4j.network.dqn; import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.rl4j.network.IOutputNeuralNet; import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.observation.Observation; import org.nd4j.linalg.api.ndarray.INDArray; @@ -27,7 +28,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; * This neural net quantifies the value of each action given a state * */ -public interface IDQN<NN extends IDQN> extends NeuralNet<NN> { +public interface IDQN<NN extends IDQN> extends NeuralNet<NN>, IOutputNeuralNet { boolean isRecurrent(); @@ -37,9 +38,6 @@ public interface IDQN<NN extends IDQN> extends NeuralNet<NN> { void fit(INDArray input, INDArray[] labels); - INDArray output(INDArray batch); - INDArray output(Observation observation); - INDArray[] outputAll(INDArray batch); NN clone(); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQNTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQNTest.java index 798bddf0d..0f03a5370 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQNTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQNTest.java @@ -1,10 +1,13 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm; import org.deeplearning4j.rl4j.learning.sync.Transition; -import org.deeplearning4j.rl4j.learning.sync.support.MockDQN; -import org.deeplearning4j.rl4j.learning.sync.support.MockTargetQNetworkSource; +import org.deeplearning4j.rl4j.network.IOutputNeuralNet; import org.deeplearning4j.rl4j.observation.Observation; +import org.junit.Before; import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.factory.Nd4j; @@ -13,16 +16,29 @@ import java.util.ArrayList; import java.util.List; import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; +@RunWith(MockitoJUnitRunner.class) public class DoubleDQNTest { + @Mock + IOutputNeuralNet qNetworkMock; + + @Mock + IOutputNeuralNet targetQNetworkMock; + + + @Before + public void setup() { + when(qNetworkMock.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]); + } + @Test public void when_isTerminal_expect_rewardValueAtIdx0() { // Assemble - MockDQN qNetwork = new MockDQN(); - MockDQN targetQNetwork = new MockDQN(); - MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork); + when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]); List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() { { @@ -31,7 +47,7 @@ } }; - DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5); + DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5); // Act DataSet result = sut.computeTDTargets(transitions); @@ -46,9 +62,7 @@ public void when_isNotTerminal_expect_rewardPlusEstimatedQValue() { // Assemble - MockDQN qNetwork = new MockDQN(); - MockDQN targetQNetwork = new MockDQN(-1.0); - MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork); + when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> ((INDArray)i.getArguments()[0]).mul(-1.0)); List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() { { @@ -57,7 +71,7 @@ } }; - DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5); + DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5); // Act DataSet result = sut.computeTDTargets(transitions); @@ -72,9 +86,7 @@ public void when_batchHasMoreThanOne_expect_everySampleEvaluated() { // Assemble - MockDQN qNetwork = new MockDQN(); - MockDQN targetQNetwork = new MockDQN(-1.0); - MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork); + when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> ((INDArray)i.getArguments()[0]).mul(-1.0)); List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() { { @@ -87,7 +99,7 @@ } }; - DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5); + DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5); // Act DataSet result = sut.computeTDTargets(transitions); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQNTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQNTest.java index 3e3701669..6aead9e76 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQNTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQNTest.java @@ -1,10 +1,13 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm; import org.deeplearning4j.rl4j.learning.sync.Transition; -import org.deeplearning4j.rl4j.learning.sync.support.MockDQN; -import org.deeplearning4j.rl4j.learning.sync.support.MockTargetQNetworkSource; +import org.deeplearning4j.rl4j.network.IOutputNeuralNet; import org.deeplearning4j.rl4j.observation.Observation; +import org.junit.Before; import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.DataSet; import 
org.nd4j.linalg.factory.Nd4j; @@ -12,17 +15,31 @@ import org.nd4j.linalg.factory.Nd4j; import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.*; +import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; +@RunWith(MockitoJUnitRunner.class) public class StandardDQNTest { + + @Mock + IOutputNeuralNet qNetworkMock; + + @Mock + IOutputNeuralNet targetQNetworkMock; + + + @Before + public void setup() { + when(qNetworkMock.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]); + when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]); + } + + @Test public void when_isTerminal_expect_rewardValueAtIdx0() { // Assemble - MockDQN qNetwork = new MockDQN(); - MockDQN targetQNetwork = new MockDQN(); - MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork); - List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() { { add(buildTransition(buildObservation(new double[]{1.1, 2.2}), @@ -30,7 +47,7 @@ } }; - StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5); + StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5); // Act DataSet result = sut.computeTDTargets(transitions); @@ -45,10 +62,6 @@ public void when_isNotTerminal_expect_rewardPlusEstimatedQValue() { // Assemble - MockDQN qNetwork = new MockDQN(); - MockDQN targetQNetwork = new MockDQN(); - MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork); - List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() { { add(buildTransition(buildObservation(new double[]{1.1, 2.2}), @@ -56,7 +69,7 @@ } }; - StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5); + StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5); // Act DataSet result = sut.computeTDTargets(transitions); @@ -71,10 +84,6 @@ public void when_batchHasMoreThanOne_expect_everySampleEvaluated() { // Assemble - MockDQN qNetwork = new MockDQN(); - MockDQN targetQNetwork = new MockDQN(); - MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork); - List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() { { add(buildTransition(buildObservation(new double[]{1.1, 2.2}), @@ -86,7 +95,7 @@ } }; - StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5); + StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5); // Act DataSet result = sut.computeTDTargets(transitions); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockTargetQNetworkSource.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockTargetQNetworkSource.java deleted file mode 100644 index ce756aa88..000000000 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockTargetQNetworkSource.java +++ /dev/null @@ -1,26 +0,0 @@ -package org.deeplearning4j.rl4j.learning.sync.support; - -import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource; -import org.deeplearning4j.rl4j.network.dqn.IDQN; - -public class MockTargetQNetworkSource implements TargetQNetworkSource { - - - private final IDQN qNetwork; - private final IDQN targetQNetwork; - - public MockTargetQNetworkSource(IDQN qNetwork, IDQN targetQNetwork) { - this.qNetwork = qNetwork; - this.targetQNetwork = targetQNetwork; - }
- - @Override - public IDQN getTargetQNetwork() { - return targetQNetwork; - } - - @Override - public IDQN getQNetwork() { - return qNetwork; - } -}
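With MockTargetQNetworkSource and MockDQN gone, the tests above stand in for networks with plain Mockito stubs. A minimal sketch of the stubbing pattern, assuming the same static imports the tests use plus org.mockito.Mockito.mock:

    // Echo stub: the online "network" returns its input unchanged, so expected
    // TD targets can be computed by hand from the transition contents.
    IOutputNeuralNet qNetwork = mock(IOutputNeuralNet.class);
    when(qNetwork.output(any(INDArray.class))).thenAnswer(inv -> inv.getArguments()[0]);

    // Scaled stub: the target "network" returns its input times -1.0, making the
    // two networks distinguishable (this replaces the old MockDQN(-1.0) helper).
    IOutputNeuralNet targetQNetwork = mock(IOutputNeuralNet.class);
    when(targetQNetwork.output(any(INDArray.class)))
            .thenAnswer(inv -> ((INDArray) inv.getArguments()[0]).mul(-1.0));

    DoubleDQN algorithm = new DoubleDQN(qNetwork, targetQNetwork, 0.5);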