From fadc2d862229aeab61b7905ef39e241ff2ca61a5 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Thu, 11 Jun 2020 12:37:38 +1000 Subject: [PATCH 01/11] DL4J: Fix 2 JSON issues [WIP] (#490) * Fix MergeVertex serialization for NHWC case Signed-off-by: Alex Black * #8999 Dropout JSON field ignore Signed-off-by: Alex Black --- .../nn/graph/TestComputationGraphNetwork.java | 46 +++++++++++++++++-- .../nn/conf/dropout/Dropout.java | 4 +- .../nn/conf/graph/MergeVertex.java | 2 + 3 files changed, 46 insertions(+), 6 deletions(-) diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java index 743e16710..b0cc17376 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/graph/TestComputationGraphNetwork.java @@ -57,10 +57,8 @@ import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.deeplearning4j.util.ModelSerializer; -import org.junit.AfterClass; -import org.junit.Before; -import org.junit.BeforeClass; -import org.junit.Test; +import org.junit.*; +import org.junit.rules.TemporaryFolder; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.impl.ActivationIdentity; import org.nd4j.linalg.api.buffer.DataType; @@ -82,6 +80,7 @@ import org.nd4j.common.resources.Resources; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; +import java.io.File; import java.io.IOException; import java.util.*; @@ -91,6 +90,9 @@ import static org.junit.Assert.*; @Slf4j public class TestComputationGraphNetwork extends BaseDL4JTest { + @Rule + public TemporaryFolder testDir = new TemporaryFolder(); + private static ComputationGraphConfiguration getIrisGraphConfiguration() { return new NeuralNetConfiguration.Builder().seed(12345) .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder() @@ -2177,4 +2179,40 @@ public class TestComputationGraphNetwork extends BaseDL4JTest { INDArray label = Nd4j.createFromArray(1, 0).reshape(1, 2); cg.fit(new DataSet(in, label)); } + + @Test + public void testMergeNchw() throws Exception { + ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder() + .convolutionMode(ConvolutionMode.Same) + .graphBuilder() + .addInputs("in") + .layer("l0", new ConvolutionLayer.Builder() + .nOut(16) + .kernelSize(2,2).stride(1,1) + .build(), "in") + .layer("l1", new ConvolutionLayer.Builder() + .nOut(8) + .kernelSize(2,2).stride(1,1) + .build(), "in") + .addVertex("merge", new MergeVertex(), "l0", "l1") + .layer("out", new CnnLossLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build(), "merge") + .setOutputs("out") + .setInputTypes(InputType.convolutional(32, 32, 3, CNN2DFormat.NHWC)) + .build(); + + ComputationGraph cg = new ComputationGraph(conf); + cg.init(); + + INDArray[] in = new INDArray[]{Nd4j.rand(DataType.FLOAT, 1, 32, 32, 3)}; + INDArray out = cg.outputSingle(in); + + File dir = testDir.newFolder(); + File f = new File(dir, "net.zip"); + cg.save(f); + + ComputationGraph c2 = ComputationGraph.load(f, true); + INDArray out2 = c2.outputSingle(in); + + assertEquals(out, out2); + } } diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/Dropout.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/Dropout.java index 46a872fd8..acb6afa2c 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/Dropout.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/dropout/Dropout.java @@ -66,8 +66,8 @@ import org.nd4j.shade.jackson.annotation.JsonProperty; * @author Alex Black */ @Data -@JsonIgnoreProperties({"mask", "helper", "helperCountFail"}) -@EqualsAndHashCode(exclude = {"mask", "helper", "helperCountFail"}) +@JsonIgnoreProperties({"mask", "helper", "helperCountFail", "initializedHelper"}) +@EqualsAndHashCode(exclude = {"mask", "helper", "helperCountFail", "initializedHelper"}) @Slf4j public class Dropout implements IDropout { diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/MergeVertex.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/MergeVertex.java index 726a68403..c7a4fec63 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/MergeVertex.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/conf/graph/MergeVertex.java @@ -17,6 +17,7 @@ package org.deeplearning4j.nn.conf.graph; +import lombok.Data; import lombok.val; import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.RNNFormat; @@ -38,6 +39,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; * -> [numExamples,depth1 + depth2,width,height]}
* @author Alex Black */ +@Data public class MergeVertex extends GraphVertex { protected int mergeAxis = 1; //default value for backward compatibility (deserialization of old version JSON) - NCHW and NCW format From 8733c0c3ed6543c368b7e583309e5af9ff70b59a Mon Sep 17 00:00:00 2001 From: raver119 Date: Thu, 11 Jun 2020 12:39:14 +0300 Subject: [PATCH 02/11] max_pool with argmax - more data types (#486) Signed-off-by: raver119@gmail.com --- .../generic/nn/pooling/maxpool_with_argmax.cpp | 14 +++++++------- .../ops/declarable/helpers/cpu/max_pooling.cpp | 2 +- .../ops/declarable/helpers/cuda/max_pooling.cu | 2 +- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool_with_argmax.cpp b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool_with_argmax.cpp index eced3c2b4..b03d19451 100644 --- a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool_with_argmax.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool_with_argmax.cpp @@ -45,18 +45,18 @@ namespace sd { DECLARE_TYPES(max_pool_with_argmax) { getOpDescriptor() ->setAllowedInputTypes(sd::DataType::ANY) - ->setAllowedOutputTypes(0, DataType::INHERIT) - ->setAllowedOutputTypes(1, {ALL_INTS}); + ->setAllowedOutputTypes(0, {ALL_FLOATS, ALL_INTS}) + ->setAllowedOutputTypes(1, {ALL_INDICES}); } DECLARE_SHAPE_FN(max_pool_with_argmax) { + auto in = inputShape->at(0); + auto dtype = block.numD() ? D_ARG(0) : sd::DataType::INT64; + auto valuesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in)); + auto indicesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in, dtype)); - auto in = inputShape->at(0); - auto valuesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in)); - auto indicesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in, DataType::INT64)); - - return SHAPELIST(valuesShape, indicesShape); + return SHAPELIST(valuesShape, indicesShape); } } } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/max_pooling.cpp b/libnd4j/include/ops/declarable/helpers/cpu/max_pooling.cpp index a458b5eff..ebb9d53fa 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/max_pooling.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/max_pooling.cpp @@ -73,7 +73,7 @@ namespace helpers { } void maxPoolingFunctor(sd::LaunchContext * context, sd::graph::Context& block, NDArray* input, NDArray* values, std::vector const& params, NDArray* indices) { - BUILD_SINGLE_SELECTOR(input->dataType(), maxPoolingFunctor_, (block, input, values, params, indices), FLOAT_TYPES); + BUILD_SINGLE_SELECTOR(input->dataType(), maxPoolingFunctor_, (block, input, values, params, indices), LIBND4J_TYPES); } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/max_pooling.cu b/libnd4j/include/ops/declarable/helpers/cuda/max_pooling.cu index 6e70d4510..8c30e510f 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/max_pooling.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/max_pooling.cu @@ -88,7 +88,7 @@ namespace helpers { void maxPoolingFunctor(sd::LaunchContext * context, sd::graph::Context& block, NDArray* input, NDArray* values, std::vector const& params, NDArray* indices) { NDArray::prepareSpecialUse({values, indices}, {input}); auto yType = indices == nullptr ? sd::DataType::INT64 : indices->dataType(); - BUILD_DOUBLE_SELECTOR(input->dataType(), yType, maxPoolingFunctor_, (block, input, values, params, indices), FLOAT_TYPES, INDEXING_TYPES); + BUILD_DOUBLE_SELECTOR(input->dataType(), yType, maxPoolingFunctor_, (block, input, values, params, indices), LIBND4J_TYPES, INDEXING_TYPES); NDArray::registerSpecialUse({values, indices}, {input}); } From bb0492f47dc42ded580f3ba49a75e15938ddfc5a Mon Sep 17 00:00:00 2001 From: Yurii Shyrma Date: Thu, 11 Jun 2020 20:15:13 +0300 Subject: [PATCH 03/11] R119 random shuffle (#488) * random_shuffle test for Yurii Signed-off-by: raver119@gmail.com * - implementation and testing random_shuffle for vector case (cpu) Signed-off-by: Yurii * - fix bug in random shuffle for cpu Signed-off-by: Yurii * - correct tests for random shuffle and improve alg when inPlace is false Signed-off-by: Yurii * - implementation of random shuffle algorithm for cuda Signed-off-by: Yurii * - split cuda random shuffle alg into separate launches of 2 kernels Signed-off-by: Yurii * - minor corrections in cuda concat kernel Signed-off-by: Yurii Co-authored-by: raver119@gmail.com --- .../declarable/helpers/cpu/randomShuffle.cpp | 198 ++++++++++----- .../ops/declarable/helpers/cuda/concat.cu | 8 +- .../declarable/helpers/cuda/randomShuffle.cu | 228 ++++++++++++++++++ .../ops/declarable/helpers/cuda/transforms.cu | 123 ---------- .../layers_tests/DeclarableOpsTests19.cpp | 1 + .../layers_tests/DeclarableOpsTests5.cpp | 184 +++++++------- .../layers_tests/DeclarableOpsTests9.cpp | 3 +- .../layers_tests/PlaygroundTests.cpp | 99 ++++---- 8 files changed, 509 insertions(+), 335 deletions(-) create mode 100644 libnd4j/include/ops/declarable/helpers/cuda/randomShuffle.cu diff --git a/libnd4j/include/ops/declarable/helpers/cpu/randomShuffle.cpp b/libnd4j/include/ops/declarable/helpers/cpu/randomShuffle.cpp index ea529112d..2ffbfc95f 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/randomShuffle.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/randomShuffle.cpp @@ -16,7 +16,8 @@ // // @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018 -// +// implementation is based on following article: +// "MergeShuffle: A Very Fast, Parallel Random Permutation Algorithm", https://arxiv.org/abs/1508.03167 @@ -31,96 +32,167 @@ namespace ops { namespace helpers { ////////////////////////////////////////////////////////////////////////// +// Fisher-Yates shuffle template -void randomShuffle_(NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) { +static void fisherYates(sd::graph::RandomGenerator& rng, T* buff, const Nd4jLong& len, const Nd4jLong& ews, Nd4jLong ind) { + + for(Nd4jLong i = len-1; i > 0; --i) { + const Nd4jLong j = rng.relativeLong(ind++) % (i + 1); + if(i != j) + math::nd4j_swap(buff[i*ews], buff[j*ews]); + } +} + +////////////////////////////////////////////////////////////////////////// +// mutual shuffle of two adjacent already shuffled ranges with length len1 and (totLen - len1) correspondingly +template +static void mergeShuffle(sd::graph::RandomGenerator& rng, T* buff, const Nd4jLong& len1, const Nd4jLong& totLen, const Nd4jLong& ews, Nd4jLong ind) { + + Nd4jLong beg = 0; // beginning + Nd4jLong mid = len1; // middle + + while (true) { + if(rng.relativeLong(ind++) % 2) { + if(mid == totLen) + break; + math::nd4j_swap(buff[ews * beg], buff[ews * mid++]); + } else { + if(beg == mid) + break; + } + ++beg; + } + + // fisherYates + while (beg < totLen) { + const Nd4jLong j = rng.relativeLong(ind++) % (beg + 1); + if(beg != j) + math::nd4j_swap(buff[ews * beg], buff[ews * j]); + ++beg; + } +} + +////////////////////////////////////////////////////////////////////////// +template +static void randomShuffle_(NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) { - // check edge cases first - int temp; const int firstDim = input.sizeAt(0); + int temp; + if(input.lengthOf() == 1 || firstDim == 1) { if(!isInplace) output.assign(input); } - else if (input.isVector() || shape::isLikeVector(input.shapeInfo(), temp)) { + else if (shape::isCommonVector(input.shapeInfo(), temp)) { - // apply Fisher-Yates shuffle - if(isInplace) { - //PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance().tadThreshold()) - for(int i = firstDim-1; i > 0; --i) { - int r = rng.relativeInt(i) % i; - if(i == r) - continue; - T t0 = input.t(i); - T t1 = input.t(r); - //math::nd4j_swap(input(i), input(r)); - input.r(i) = t1; - input.r(r) = t0; - } + NDArray* arr = &input; + + if (!isInplace) { + output.assign(input); + arr = &output; } - else { - std::vector indices(firstDim); - std::iota(indices.begin(), indices.end(), 0); - output.p(Nd4jLong(0), input.e(0)); - // FIXME: parallelism!! - for(int i = firstDim-1; i > 0; --i) { - int r = rng.relativeInt(i) % i; - output.r(i) = input.t(indices[r]); - if(i == r) - continue; + const Nd4jLong ews = arr->ews(); - output.r(r) = input.t(indices[i]); - math::nd4j_swap(indices[i], indices[r]); + const Nd4jLong len = arr->lengthOf(); + const Nd4jLong threshold = 1<<22; // this number was deduced from diagram in article + + int power = 0; + while ((len >> power) > threshold) + ++power; + + const Nd4jLong numChunks = 1 << power; + + auto funcFisherYates = PRAGMA_THREADS_FOR { + + for (auto i = start; i < stop; ++i) { + + Nd4jLong offset = (len * i) >> power; + Nd4jLong currLen = ((len * (i + 1)) >> power) - offset; + fisherYates(rng, arr->bufferAsT() + offset*ews, currLen, ews, offset); } - rng.rewindH(firstDim-1); - } + }; + + auto funcMerge = PRAGMA_THREADS_FOR { + + for (int64_t i = start, k = 1; i < stop; i += increment, ++k) { + Nd4jLong offset = len * i >> power; + Nd4jLong len1 = (len * (i + increment/2) >> power) - offset; + Nd4jLong totLen = (len * (i + increment) >> power) - offset; + mergeShuffle(rng, arr->bufferAsT() + offset*ews, len1, totLen, ews, len * k + offset); + } + }; + + samediff::Threads::parallel_for(funcFisherYates, 0, numChunks); + + for (int j = 1; j < numChunks; j += j) + samediff::Threads::parallel_for(funcMerge, 0, numChunks, 2*j); + + // #pragma omp parallel for + // for (uint i = 0; i < numChunks; ++i) { + + // Nd4jLong offset = (len * i) >> power; + // Nd4jLong currLen = ((len * (i + 1)) >> power) - offset; + // fisherYates(rng, arr->bufferAsT() + offset*ews, currLen, ews, offset); + // } + + // for (uint j = 1; j < numChunks; j += j) { + // #pragma omp parallel for + // for (auto i = 0; i < numChunks; i += 2*j) { + // Nd4jLong offset = len * i >> power; + // Nd4jLong len1 = (len * (i + j) >> power) - offset; + // Nd4jLong totLen = (len * (i + 2*j) >> power) - offset; + // mergeShuffle(rng, arr->bufferAsT() + offset*ews, len1, totLen, ews, len * j + offset); + // } + // } + + rng.rewindH((len + 1) * power); } else { - // evaluate sub-arrays list of input array through all dimensions excluding first one - std::vector dimensions = ShapeUtils::evalDimsToExclude(input.rankOf(), {0}); - auto subArrsListIn = input.allTensorsAlongDimension(dimensions); + auto dimsToExclude = ShapeUtils::evalDimsToExclude(input.rankOf(), {0}); - // apply Fisher-Yates shuffle if(isInplace) { - //PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance().elementwiseThreshold()) - for(int i = firstDim - 1; i > 0; --i) { - int r = rng.relativeInt(i) % i; - if(i == r) - continue; - subArrsListIn.at(i)->swapUnsafe(*subArrsListIn.at(r)); + auto subArrsList = input.allTensorsAlongDimension(dimsToExclude); + + // Fisher-Yates shuffle + for(int i = firstDim - 1; i > 0; --i) { + const int j = rng.relativeInt(i) % (i + 1); + if(i != j) + subArrsList.at(i)->swapUnsafe(*subArrsList.at(j)); } } else { - // evaluate sub-arrays list of output array through all dimensions excluding first one - auto subArrsListOut = output.allTensorsAlongDimension(dimensions); + + auto subArrsListIn = input.allTensorsAlongDimension(dimsToExclude); + auto subArrsListOut = output.allTensorsAlongDimension(dimsToExclude); + std::vector indices(firstDim); - std::iota(indices.begin(), indices.end(), 0); - bool isZeroShuffled = false; - //PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance().tadThreshold()) - for(int i = firstDim - 1; i > 0; --i) { - int r = rng.relativeInt(i) % i; - subArrsListOut.at(i)->assign(subArrsListIn.at(indices[r])); - if(r == 0) - isZeroShuffled = true; - if(i == r) - continue; - subArrsListOut.at(r)->assign(subArrsListIn.at(indices[i])); - math::nd4j_swap(indices[i], indices[r]); - } - if(!isZeroShuffled) - subArrsListOut.at(0)->assign(subArrsListIn.at(0)); + std::iota(indices.begin(), indices.end(), 0); // 0,1,2,3, ... firstDim-1 + + // shuffle indices + fisherYates(rng, indices.data(), firstDim, 1, 0); + + auto func = PRAGMA_THREADS_FOR { + + for (auto i = start; i < stop; ++i) + subArrsListOut.at(i)->assign(subArrsListIn.at(indices[i])); + }; + + samediff::Threads::parallel_for(func, 0, firstDim); } + rng.rewindH(firstDim-1); } - } - void randomShuffle(sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) { - BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (input, output, rng, isInplace), LIBND4J_TYPES); - } +void randomShuffle(sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) { + BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (input, output, rng, isInplace), LIBND4J_TYPES); +} + } } } + diff --git a/libnd4j/include/ops/declarable/helpers/cuda/concat.cu b/libnd4j/include/ops/declarable/helpers/cuda/concat.cu index cbcd35ffe..400c25f88 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/concat.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/concat.cu @@ -53,7 +53,7 @@ __global__ static void concatCuda(void* pVx, void* pxShapeInfo, void* vz, const int coords[MAX_RANK]; - for (uint64_t i = tid; i < zLen; i += totalThreads) { + for (Nd4jLong i = tid; i < zLen; i += totalThreads) { shape::index2coords(i, zShapeInfo, coords); const auto zOffset = shape::getOffset(zShapeInfo, coords); @@ -162,9 +162,9 @@ void concat(sd::LaunchContext * context, const std::vector& inAr // } // else { // general (slower) case - const int threadsPerBlock = 256; - const int blocksPerGrid = 512; - const int sharedMem = 512; + const int threadsPerBlock = MAX_NUM_THREADS / 2; + const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock; + const int sharedMem = 256; // prepare arrays of pointers on buffers and shapes std::vector hInBuffers(numOfInArrs); diff --git a/libnd4j/include/ops/declarable/helpers/cuda/randomShuffle.cu b/libnd4j/include/ops/declarable/helpers/cuda/randomShuffle.cu new file mode 100644 index 000000000..bb7998e60 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/randomShuffle.cu @@ -0,0 +1,228 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com) +// implemented algorithm is GPU adaptation of algorithm described in following article: +// "MergeShuffle: A Very Fast, Parallel Random Permutation Algorithm", https://arxiv.org/abs/1508.03167 +// + +#include +#include +#include +#include +#include +#include + +namespace sd { +namespace ops { +namespace helpers { + +////////////////////////////////////////////////////////////////////////// +template +static __global__ void fisherYatesCuda(sd::graph::RandomGenerator* rng, void* vx, const Nd4jLong ews, const Nd4jLong len, const int power) { + + T* x = reinterpret_cast(vx); + + __shared__ T* shmem, temp; + __shared__ Nd4jLong ind, blockOffset, lenPerBlock; + + if (threadIdx.x == 0) { + extern __shared__ unsigned char sharedMemory[]; + shmem = reinterpret_cast(sharedMemory); + + blockOffset = (len * blockIdx.x) >> power; + lenPerBlock = ((len * (blockIdx.x + 1)) >> power) - blockOffset; + ind = blockOffset; + } + __syncthreads(); + + // copy from global memory to shared memory + if(threadIdx.x < lenPerBlock) + shmem[threadIdx.x] = x[(blockOffset + threadIdx.x) * ews]; + __syncthreads(); + + // *** apply Fisher-Yates shuffle to lenPerBlock number of elements + if (threadIdx.x == 0) { + for(Nd4jLong i = lenPerBlock - 1; i > 0; --i) { + const Nd4jLong j = rng->relativeLong(ind++) % (i + 1); + if(i != j) { + temp = shmem[i]; + shmem[i] = shmem[j]; + shmem[j] = temp; + } + } + } + __syncthreads(); + + // copy from shared memory to global memory + if(threadIdx.x < lenPerBlock) + x[(blockOffset + threadIdx.x) * ews] = shmem[threadIdx.x]; +} + +template +static __global__ void mergeShuffleCuda(sd::graph::RandomGenerator* rng, void* vx, const Nd4jLong ews, const Nd4jLong len, const int power, const Nd4jLong iterNum) { + + + T* x = reinterpret_cast(vx); + + __shared__ Nd4jLong ind, blockOffset, factor, beg, mid, totLen, iterExp; + + // *** apply mergeShuffle algorithm + if(threadIdx.x == 0) { + + factor = blockIdx.x << iterNum; + iterExp = 1 << (iterNum - 1); + blockOffset = (len * factor) >> power; + mid = ((len * (factor + iterExp)) >> power) - blockOffset; // middle + totLen = ((len * (factor + 2*iterExp)) >> power) - blockOffset; + ind = iterNum * len + blockOffset; + beg = 0; // beginning + + // printf("m %lld, blockIdx.x %lld, factor %lld, blockOffset %lld, mid %lld, totLen %lld \n", m,k,factor,blockOffset,mid,totLen); + + while (true) { + if(rng->relativeLong(ind++) % 2) { + if(mid == totLen) + break; + math::nd4j_swap(x[(blockOffset + beg) * ews], x[(blockOffset + mid++) * ews]); + } else { + if(beg == mid) + break; + } + ++beg; + } + + // Fisher-Yates + while (beg < totLen) { + const Nd4jLong e = rng->relativeLong(ind++) % (beg + 1); + if(beg != e) + math::nd4j_swap(x[(blockOffset + beg) * ews], x[(blockOffset + e) * ews]); + ++beg; + } + } +} + + +////////////////////////////////////////////////////////////////////////// +// Fisher-Yates shuffle +template +static void fisherYates(sd::graph::RandomGenerator& rng, T* buff, const Nd4jLong& len, const Nd4jLong& ews, Nd4jLong ind) { + + for(Nd4jLong i = len-1; i > 0; --i) { + const Nd4jLong j = rng.relativeLong(ind++) % (i + 1); + if(i != j) + math::nd4j_swap(buff[i*ews], buff[j*ews]); + } +} + +////////////////////////////////////////////////////////////////////////// +template +static void randomShuffle_(sd::LaunchContext* context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) { + + const int firstDim = input.sizeAt(0); + int temp; + + if(input.lengthOf() == 1 || firstDim == 1) { + + if(!isInplace) + output.assign(input); + } + else if (shape::isCommonVector(input.shapeInfo(), temp)) { + + NDArray* arr = &input; + + if (!isInplace) { + output.assign(input); + arr = &output; + } + + const Nd4jLong len = arr->lengthOf(); + + const int threadsPerBlock = MAX_NUM_THREADS; + + int power = 0; + while ((len >> power) > threadsPerBlock) + ++power; + + const int blocksPerGrid = 1 << power; + const int sharedMem = threadsPerBlock * input.sizeOfT() + 256; + + PointersManager manager(context, "NDArray::randomShuffle cuda"); + + sd::graph::RandomGenerator* pRng = reinterpret_cast(manager.replicatePointer(&rng, sizeof(sd::graph::RandomGenerator))); + + NDArray::prepareSpecialUse({arr}, {arr}); + fisherYatesCuda<<getCudaStream()>>>(pRng, arr->specialBuffer(), arr->ews(), len, power); + for (Nd4jLong j = 1, i = 1; j < blocksPerGrid; j += j, ++i) + mergeShuffleCuda<<getCudaStream()>>>(pRng, arr->specialBuffer(), arr->ews(), len, power, i); + NDArray::registerSpecialUse({arr}, {arr}); + + manager.synchronize(); + + rng.rewindH((len + 1) * power); + } + else { + + auto dimsToExclude = ShapeUtils::evalDimsToExclude(input.rankOf(), {0}); + + if(isInplace) { + + auto subArrsList = input.allTensorsAlongDimension(dimsToExclude); + + // Fisher-Yates shuffle + for(int i = firstDim - 1; i > 0; --i) { + const int j = rng.relativeInt(i) % (i + 1); + if(i != j) + subArrsList.at(i)->swapUnsafe(*subArrsList.at(j)); + } + } + else { + + auto subArrsListIn = input.allTensorsAlongDimension(dimsToExclude); + auto subArrsListOut = output.allTensorsAlongDimension(dimsToExclude); + + std::vector indices(firstDim); + std::iota(indices.begin(), indices.end(), 0); // 0,1,2,3, ... firstDim-1 + + // shuffle indices + fisherYates(rng, indices.data(), firstDim, 1, 0); + + auto func = PRAGMA_THREADS_FOR { + + for (auto i = start; i < stop; ++i) + subArrsListOut.at(i)->assign(subArrsListIn.at(indices[i])); + }; + + samediff::Threads::parallel_for(func, 0, firstDim); + } + + rng.rewindH(firstDim-1); + } +} + +///////////////////////////////////////////////////////////////////////// +void randomShuffle(sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) { + BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (context, input, output, rng, isInplace), LIBND4J_TYPES); +} + +// BUILD_SINGLE_TEMPLATE(template void randomShuffle_, (sd::LaunchContext* context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace), LIBND4J_TYPES); + + + +} +} +} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu b/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu index 8d7f700dd..80e0e0858 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/transforms.cu @@ -300,129 +300,6 @@ void tileBP(sd::LaunchContext * context, const NDArray& gradO /*input*/, NDArray manager.synchronize(); } - template - static __global__ void swapShuffleKernel(T* input, Nd4jLong const* shape, Nd4jLong firstDim, sd::graph::RandomGenerator* rng) { - auto tid = blockIdx.x * blockDim.x; - auto step = blockDim.x * gridDim.x; - - for (int i = firstDim - 1 - tid - threadIdx.x; i > 0; i -= step) { - int r = rng->relativeInt(i) % i; - if (i != r) { - const auto iOffset = shape::getIndexOffset(i, shape); - const auto rOffset = shape::getIndexOffset(r, shape); - T e0 = input[iOffset]; - T e1 = input[rOffset]; - //math::nd4j_swap(input(i), input(r)); - input[iOffset] = e1; - input[rOffset] = e0; - } - } - } - template - static __global__ void fillShuffleKernel(T* input, Nd4jLong const* inputShape, T* output, Nd4jLong const* outputShape, Nd4jLong firstDim, int* indices, sd::graph::RandomGenerator* rng) { - -// PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance().tadThreshold()) - auto tid = blockIdx.x * blockDim.x; - auto step = blockDim.x * gridDim.x; - - for(int i = firstDim - 1 - tid - threadIdx.x; i > 0; i -= step) { - int r = rng->relativeInt(i) % i; - output[shape::getIndexOffset(i, outputShape)] = input[shape::getIndexOffset(indices[r], inputShape)]; - if(i != r) { - output[shape::getIndexOffset(r, outputShape)] = input[shape::getIndexOffset(indices[i], inputShape)]; -// output.p(r, input.e(indices[i])); -// math::nd4j_swap(indices[i], indices[r]); - atomicExch(&indices[i], indices[r]); - } - } - - } - ////////////////////////////////////////////////////////////////////////// - template - void randomShuffle_(sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) { - - // check edge cases first - int temp; - const int firstDim = input.sizeAt(0); - auto stream = context->getCudaStream(); - NDArray::prepareSpecialUse({&output}, {&input}); - if(input.lengthOf() == 1 || firstDim == 1) { - if(!isInplace) - output.assign(input); - } - else if (input.isVector() || shape::isLikeVector(input.shapeInfo(), temp)) { - - // apply Fisher-Yates shuffle - sd::graph::RandomGenerator* dRandom = nullptr; - cudaMalloc(&dRandom, sizeof(sd::graph::RandomGenerator)); - cudaMemcpy(dRandom, &rng, sizeof(sd::graph::RandomGenerator), cudaMemcpyHostToDevice); - T* inputBuf = reinterpret_cast(input.specialBuffer()); - if(isInplace) { - swapShuffleKernel<<<128, 256, 1024, *stream>>>(inputBuf, input.specialShapeInfo(), firstDim, dRandom); - } - else { - std::vector indices(firstDim); - std::iota(indices.begin(), indices.end(), 0); - cudaMemcpy(output.specialBuffer(), input.specialBuffer(), sizeof(T), cudaMemcpyDeviceToDevice); - //output.p(Nd4jLong(0), input.e(0)); - PointersManager pointersManager(context, "helper::randomShuffle_"); - int* indicesDev = reinterpret_cast(pointersManager.replicatePointer(indices.data(), indices.size() * sizeof(int))); - T* outputBuf = reinterpret_cast(output.specialBuffer()); - fillShuffleKernel<<<128, 256, 1024, *stream>>>(inputBuf, input.specialShapeInfo(), outputBuf, output.specialShapeInfo(), firstDim, indicesDev, dRandom); - pointersManager.synchronize(); - } -// rng.rewindH(firstDim - 1); - cudaFree(dRandom); - } - else { - - // evaluate sub-arrays list of input array through all dimensions excluding first one - std::vector dimensions = ShapeUtils::evalDimsToExclude(input.rankOf(), {0}); - auto subArrsListIn = input.allTensorsAlongDimension(dimensions); - - // apply Fisher-Yates shuffle - if(isInplace) { - for(int i = firstDim - 1; i > 0; --i) { - int r = rng.relativeInt(i) % i; - - if(i != r) - subArrsListIn.at(i)->swapUnsafe(*subArrsListIn.at(r)); - } - } - else { - // evaluate sub-arrays list of output array through all dimensions excluding first one - auto subArrsListOut = output.allTensorsAlongDimension(dimensions); - std::vector indices(firstDim); - std::iota(indices.begin(), indices.end(), 0); - bool isZeroShuffled = false; - - for(int i = firstDim - 1; i > 0; --i) { - int r = rng.relativeInt(i) % i; - subArrsListOut.at(i)->assign(subArrsListIn.at(indices[r])); - if(r == 0) - isZeroShuffled = true; - - if(i != r) { - subArrsListOut.at(r)->assign(subArrsListIn.at(indices[i])); - math::nd4j_swap(indices[i], indices[r]); - } - } - if(!isZeroShuffled) - subArrsListOut.at(0)->assign(subArrsListIn.at(0)); - } - rng.rewindH(firstDim-1); - } - NDArray::registerSpecialUse({&output}, {&input}); - - } - - void randomShuffle(sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) { - BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (context, input, output, rng, isInplace), LIBND4J_TYPES); - } - - BUILD_SINGLE_TEMPLATE(template void randomShuffle_, (sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace), LIBND4J_TYPES); - - ////////////////////////////////////////////////////////////////////////// void eye(sd::LaunchContext * context, NDArray& output) { diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp index 5f1aefe36..d3d1deed8 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp @@ -419,3 +419,4 @@ TEST_F(DeclarableOpsTests19, test_squeeze_1) { auto status = op.execute({&x}, {&e}, {axis}); ASSERT_EQ(Status::OK(), status); } + diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp index 04bb54a61..c68392da1 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests5.cpp @@ -1557,8 +1557,6 @@ TEST_F(DeclarableOpsTests5, trace_test1) { // exp.printIndexedBuffer("EXP TRACE"); // output->printIndexedBuffer("OUT TRACE"); ASSERT_TRUE(exp.equalsTo(output)); - - } ////////////////////////////////////////////////////////////////////// @@ -1575,8 +1573,6 @@ TEST_F(DeclarableOpsTests5, trace_test2) { ASSERT_EQ(Status::OK(), results.status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - - } ////////////////////////////////////////////////////////////////////// @@ -1593,8 +1589,6 @@ TEST_F(DeclarableOpsTests5, trace_test3) { ASSERT_EQ(Status::OK(), results.status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - - } ////////////////////////////////////////////////////////////////////// @@ -1611,8 +1605,6 @@ TEST_F(DeclarableOpsTests5, trace_test4) { ASSERT_EQ(Status::OK(), results.status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - - } ////////////////////////////////////////////////////////////////////// @@ -1629,8 +1621,6 @@ TEST_F(DeclarableOpsTests5, trace_test5) { ASSERT_EQ(Status::OK(), results.status()); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - - } ////////////////////////////////////////////////////////////////////// @@ -1638,22 +1628,15 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test1) { auto input = NDArrayFactory::create('c', {2, 2, 2}); input.linspace(1); + NDArray exp1 = input.dup(); + NDArray exp2('c',{2,2,2}, {5,6,7,8, 1,2,3,4}, sd::DataType::DOUBLE); sd::ops::random_shuffle op; auto results = op.evaluate({&input}); auto output = results.at(0); - bool haveZeros = false; - for(int i = 0; i < output->lengthOf(); ++i) - if(output->e(i) == (float)0.) - haveZeros = true; - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(input.isSameShape(output)); - ASSERT_TRUE(!input.equalsTo(output)); - ASSERT_TRUE(!haveZeros); - - + ASSERT_TRUE(output->equalsTo(exp1) || output->equalsTo(exp2)); } ////////////////////////////////////////////////////////////////////// @@ -1661,16 +1644,14 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test2) { auto input = NDArrayFactory::create('c', {1, 3, 2}); input.linspace(1); + NDArray exp1 = input.dup(); sd::ops::random_shuffle op; auto results = op.evaluate({&input}); auto output = results.at(0); ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(input.isSameShape(output)); - ASSERT_TRUE(input.equalsTo(output)); - - + ASSERT_TRUE(output->equalsTo(exp1)); } ////////////////////////////////////////////////////////////////////// @@ -1678,129 +1659,132 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test3) { auto input = NDArrayFactory::create('c', {3, 2, 1}); input.linspace(1); + NDArray exp1 = input.dup(); + NDArray exp2('c',{3,2,1}, {1,2, 5,6, 3,4}, sd::DataType::DOUBLE); + NDArray exp3('c',{3,2,1}, {3,4, 1,2, 5,6}, sd::DataType::DOUBLE); + NDArray exp4('c',{3,2,1}, {3,4, 5,6, 1,2}, sd::DataType::DOUBLE); + NDArray exp5('c',{3,2,1}, {5,6, 1,2, 3,4}, sd::DataType::DOUBLE); + NDArray exp6('c',{3,2,1}, {5,6, 3,4, 1,2}, sd::DataType::DOUBLE); sd::ops::random_shuffle op; - auto results = op.evaluate({&input}); - auto output = results.at(0); - - bool haveZeros = false; - for(int i = 0; i < output->lengthOf(); ++i) - if(output->e(i) == (float)0.) - haveZeros = true; - - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(input.isSameShape(output)); - ASSERT_TRUE(!input.equalsTo(output)); - ASSERT_TRUE(!haveZeros); - - -} -////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests5, random_shuffle_test04) { - auto input = NDArrayFactory::create('c', {4}); - input.linspace(1); - - sd::ops::random_shuffle op; - //NDArray* output; auto results = op.evaluate({&input}, {}, {}, {}, {}, true); + ASSERT_EQ(Status::OK(), results.status()); - auto output = &input; //results.at(0); - bool haveZeros = false; - for(int i = 0; i < output->lengthOf(); ++i) - if(output->e(i) == (float)0.) - haveZeros = true; - - ASSERT_TRUE(input.isSameShape(output)); - //ASSERT_TRUE(!input.equalsTo(output)); - ASSERT_TRUE(!haveZeros); - - + ASSERT_TRUE(input.equalsTo(exp1) || input.equalsTo(exp2) || input.equalsTo(exp3) + || input.equalsTo(exp4) || input.equalsTo(exp5) || input.equalsTo(exp6)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, random_shuffle_test4) { - auto input = NDArrayFactory::create('c', {4}); + + auto input = NDArrayFactory::create('c', {3, 2, 1}); input.linspace(1); + NDArray exp1 = input.dup(); + NDArray exp2('c',{3,2,1}, {1,2, 5,6, 3,4}, sd::DataType::DOUBLE); + NDArray exp3('c',{3,2,1}, {3,4, 1,2, 5,6}, sd::DataType::DOUBLE); + NDArray exp4('c',{3,2,1}, {3,4, 5,6, 1,2}, sd::DataType::DOUBLE); + NDArray exp5('c',{3,2,1}, {5,6, 1,2, 3,4}, sd::DataType::DOUBLE); + NDArray exp6('c',{3,2,1}, {5,6, 3,4, 1,2}, sd::DataType::DOUBLE); sd::ops::random_shuffle op; - //NDArray* output; auto results = op.evaluate({&input}); - ASSERT_EQ(Status::OK(), results.status()); auto output = results.at(0); - bool haveZeros = false; - for(int i = 0; i < output->lengthOf(); ++i) - if(output->e(i) == (float)0.) - haveZeros = true; - - ASSERT_TRUE(input.isSameShape(output)); - //ASSERT_TRUE(!input.equalsTo(output)); - ASSERT_TRUE(!haveZeros); - + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(output->equalsTo(exp1) || output->equalsTo(exp2) || output->equalsTo(exp3) + || output->equalsTo(exp4) || output->equalsTo(exp5) || output->equalsTo(exp6)); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, random_shuffle_test5) { - - auto input = NDArrayFactory::create('c', {4,1}); + auto input = NDArrayFactory::create('c', {4}); input.linspace(1); sd::ops::random_shuffle op; - auto results = op.evaluate({&input}); + auto results = op.evaluate({&input}, {}, {}, {}, {}, false); auto output = results.at(0); - - bool haveZeros = false; - for(int i = 0; i < output->lengthOf(); ++i) - if(output->e(i) == (float)0.) - haveZeros = true; + // output->printBuffer(); ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(input.isSameShape(output)); - ASSERT_TRUE(!input.equalsTo(output)); - ASSERT_TRUE(!haveZeros); - + // ASSERT_TRUE(!output->equalsTo(input)); + bool hasDublicates = false; + for(int i = 0; i < output->lengthOf() - 1; ++i) + for(int j = i+1; j < output->lengthOf(); ++j) + if(output->t(i) == output->t(j)) { + hasDublicates = true; + i = output->lengthOf(); + break; + } + ASSERT_TRUE(!hasDublicates); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, random_shuffle_test6) { - - auto input = NDArrayFactory::create('c', {4,1,1}); + auto input = NDArrayFactory::create('c', {4,1,1}); input.linspace(1); sd::ops::random_shuffle op; - auto results = op.evaluate({&input}); + auto results = op.evaluate({&input}, {}, {}, {}, {}, false); auto output = results.at(0); - bool haveZeros = false; - for(int i = 0; i < output->lengthOf(); ++i) - if(output->e(i) == (float)0.) - haveZeros = true; - ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(input.isSameShape(output)); - ASSERT_TRUE(!input.equalsTo(output)); - ASSERT_TRUE(!haveZeros); - + // ASSERT_TRUE(!output->equalsTo(input)); + bool hasDublicates = false; + for(int i = 0; i < output->lengthOf() - 1; ++i) + for(int j = i+1; j < output->lengthOf(); ++j) + if(output->t(i) == output->t(j)) { + hasDublicates = true; + i = output->lengthOf(); + break; + } + ASSERT_TRUE(!hasDublicates); } ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests5, random_shuffle_test7) { - - auto input = NDArrayFactory::create('c', {1,4}); + auto input = NDArrayFactory::create('c', {16010}); input.linspace(1); - auto exp = NDArrayFactory::create('c', {1,4}, {1, 2, 3, 4}); sd::ops::random_shuffle op; - auto results = op.evaluate({&input}); + auto results = op.evaluate({&input}, {}, {}, {}, {}, false); auto output = results.at(0); - + // output->printBuffer(); ASSERT_EQ(Status::OK(), results.status()); - ASSERT_TRUE(input.isSameShape(output)); - ASSERT_TRUE(input.equalsTo(output)); + ASSERT_TRUE(!output->equalsTo(input)); + auto vec1 = input.getBufferAsVector(); + auto vec2 = output->getBufferAsVector(); + std::sort(vec2.begin(), vec2.end()); + ASSERT_TRUE(std::equal(vec1.begin(), vec1.end(), vec2.begin())); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, random_shuffle_test8) { + auto input = NDArrayFactory::create('c', {1,4,1}); + input.linspace(1); + NDArray inCopy = input.dup(); + + sd::ops::random_shuffle op; + auto results = op.evaluate({&input}, {}, {}, {}, {}, false); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(input.equalsTo(inCopy)); + +} + +TEST_F(DeclarableOpsTests5, random_shuffle_test9) { + + auto x = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + auto z = x.ulike(); + + sd::ops::random_shuffle op; + auto status = op.execute({&x}, {&z}); + ASSERT_EQ(Status::OK(), status); + + auto vec = z.getBufferAsVector(); + std::sort(vec.begin(), vec.end()); + ASSERT_EQ(std::vector({1, 2, 3, 4}), vec); } //////////////////////////////////////////////////////////////////////////////////////// diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp index 949b43d25..f2bd393e4 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests9.cpp @@ -251,11 +251,10 @@ TEST_F(DeclarableOpsTests9, concat_test1) { auto result = op.evaluate({&x0, &x1, &x2}, {}, {1}); ASSERT_EQ(ND4J_STATUS_OK, result.status()); auto output = result.at(0); + // output->printCurrentBuffer(false); ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.equalsTo(output)); - - } //////////////////////////////////////////////////////////////////////////////// diff --git a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp index a8f45cc48..e07a0496d 100644 --- a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp @@ -317,7 +317,7 @@ void fill_random(sd::NDArray& arr) { } } - + void testLegacy(bool random) { #if 0 int bases[] = { 3, 2, 4, 5, 7 }; @@ -364,7 +364,7 @@ int k = 4; #endif auto dim = NDArrayFactory::create(dimension); -#if 1 +#if 1 nd4j_printf("C(N:%d K:%d) \n", N, k); dim.printIndexedBuffer("Dimension"); for (int xind : dimension) { @@ -385,7 +385,7 @@ for (int e = 0; e < Loop; e++) { auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); values.emplace_back(outerTime); } - + std::sort(values.begin(), values.end()); nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); @@ -411,7 +411,7 @@ void testNewReduction(bool random, bool checkCorrectness = false , char order =' constexpr int N = 5; #endif - + for (int i = 0; i < N; i++) { arr_dimensions.push_back(bases[i]); } @@ -451,7 +451,7 @@ void testNewReduction(bool random, bool checkCorrectness = false , char order =' #endif auto dim = NDArrayFactory::create(dimension); -#if 1 +#if 1 nd4j_printf("C(N:%d K:%d) \n", N, k); dim.printIndexedBuffer("Dimension"); for (int xind : dimension) { @@ -477,14 +477,14 @@ void testNewReduction(bool random, bool checkCorrectness = false , char order =' //check for the correctness NDArray exp = output_bases.size() > 0 ? NDArrayFactory::create('c', output_bases) : NDArrayFactory::create(0); original_argmax(x, dimension, exp); - + #if 0// defined(DEBUG) x.printIndexedBuffer("X"); exp.printIndexedBuffer("Expected"); z->printIndexedBuffer("Z"); #endif - + ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); } @@ -505,7 +505,7 @@ TEST_F(PlaygroundTests, ArgMaxPerfLinspace) { testNewReduction(false, test_corr); } #endif - + TEST_F(PlaygroundTests, ArgMaxPerfRandom) { testNewReduction(true, test_corr); } @@ -513,7 +513,7 @@ TEST_F(PlaygroundTests, ArgMaxPerfRandom) { TEST_F(PlaygroundTests, ArgMaxPerfRandomOrderF) { testNewReduction(true, test_corr, 'f'); } - + #if !defined(DEBUG) TEST_F(PlaygroundTests, ArgMaxPerfLegacyLinspace) { testLegacy(false); @@ -1062,39 +1062,6 @@ TEST_F(PlaygroundTests, my) { delete variableSpace; } -TEST_F(PlaygroundTests, my) { - - int N = 100; - int bS=16, iH=128,iW=128, iC=32,oC=64, kH=4,kW=4, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; - int oH=128,oW=128; - - int paddingMode = 1; // 1-SAME, 0-VALID; - int dataFormat = 1; // 1-NHWC, 0-NCHW - - // NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); - // NDArray output('c', {bS, oC, oH, oW}, sd::DataType::FLOAT32); - NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); - NDArray output('c', {bS, oH, oW, oC}, sd::DataType::FLOAT32); - // NDArray weights('c', {kH, kW, iC, oC}, sd::DataType::FLOAT32); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW] - NDArray weights('c', {oC, iC, kH, kW}, sd::DataType::FLOAT32); - NDArray bias('c', {oC}, sd::DataType::FLOAT32); - - input = 5.; - weights = 3.; - bias = 1.; - - sd::ops::conv2d op; - auto err = op.execute({&input, &weights, &bias}, {&output}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - - auto timeStart = std::chrono::system_clock::now(); - for (int i = 0; i < N; ++i) - err = op.execute({&input, &weights, &bias}, {&output}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); - auto timeEnd = std::chrono::system_clock::now(); - auto time = std::chrono::duration_cast ((timeEnd - timeStart) / N).count(); - - printf("time: %i \n", time); -} - /////////////////////////////////////////////////////////////////// TEST_F(PlaygroundTests, lstmLayerCellBp_1) { @@ -1690,6 +1657,52 @@ TEST_F(DeclarableOpsTests15, gru_bp_1) { const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); } +#include +////////////////////////////////////////////////////////////////////// +TEST_F(PlaygroundTests, my) { + + const int N = 10; + + NDArray input('c', {8000000}, sd::DataType::INT32); + input.linspace(1); + NDArray output = input.dup(); + + + sd::graph::RandomGenerator rng; + + sd::ops::helpers::randomShuffle(input.getContext(), input, output, rng, true); + + // auto timeStart = std::chrono::system_clock::now(); + // for (int i = 0; i < N; ++i) + // sd::ops::helpers::randomShuffle(input.getContext(), input, output, rng, true); + // auto timeEnd = std::chrono::system_clock::now(); + // auto time = std::chrono::duration_cast ((timeEnd - timeStart) / N).count(); + // printf("time: %i \n", time); + + // bool hasDublicates = false; + // for(int i = 0; i < output.lengthOf() - 1; ++i) + // for(int j = i+1; j < output.lengthOf(); ++j) + // if(output.t(i) == output.t(j)) { + // hasDublicates = true; + // i = output.lengthOf(); + // break; + // } + + ASSERT_TRUE(!input.equalsTo(output)); + + bool hasDublicates = false; + for(int i = 0; i < input.lengthOf() - 1; ++i) + for(int j = i+1; j < input.lengthOf(); ++j) + if(input.t(i) == input.t(j)) { + hasDublicates = true; + i = input.lengthOf(); + break; + } + ASSERT_TRUE(!hasDublicates); +} + + +} + */ - From 9ca679e080c3eadbeadff2ab4f9e8ec6118c70a1 Mon Sep 17 00:00:00 2001 From: Fariz Rahman Date: Tue, 16 Jun 2020 05:43:10 +0400 Subject: [PATCH 04/11] python4j-numpy (#475) * 'initial' * 'impl' * tests * * more tests * scalar fixes * lazy setup jobs * more tests * multithreading wip * multithreading fix * bytebuffer working * nits * inplace exec fixes * attempt linux cpu fix * rollback * list fixes * disable gc * log * bump jcpp + fixes * #8985 GradientSharingTrainingTest ignore for logged issue Signed-off-by: Alex Black * memview fixes * fix? Co-authored-by: Alex Black --- .../train/GradientSharingTrainingTest.java | 12 +- python4j/pom.xml | 11 +- python4j/python4j-core/pom.xml | 1 - .../python4j/PythonContextManager.java | 28 ++ .../eclipse/python4j/PythonExecutioner.java | 44 +-- .../java/org/eclipse/python4j/PythonGIL.java | 4 + .../java/org/eclipse/python4j/PythonJob.java | 23 +- .../org/eclipse/python4j/PythonObject.java | 50 ++- .../org/eclipse/python4j/PythonProcess.java | 127 ++++++++ .../java/org/eclipse/python4j/PythonType.java | 22 ++ .../org/eclipse/python4j/PythonTypes.java | 164 +++++++++- .../org/eclipse/python4j/PythonVariables.java | 47 +++ .../src/test/java/PythonBufferTest.java | 113 +++++++ .../src/test/java/PythonGCTest.java | 2 +- .../src/test/java/PythonJobTest.java | 12 +- python4j/python4j-numpy/pom.xml | 35 ++ .../java/org/eclipse/python4j/NumpyArray.java | 303 ++++++++++++++++++ .../services/org.eclipse.python4j.PythonType | 1 + .../src/test/java/PythonNumpyBasicTest.java | 170 ++++++++++ .../test/java/PythonNumpyCollectionsTest.java | 96 ++++++ .../src/test/java/PythonNumpyGCTest.java | 55 ++++ .../src/test/java/PythonNumpyImportTest.java | 22 ++ .../src/test/java/PythonNumpyJobTest.java | 303 ++++++++++++++++++ .../test/java/PythonNumpyMultiThreadTest.java | 194 +++++++++++ .../java/PythonNumpyServiceLoaderTest.java | 41 +++ 25 files changed, 1828 insertions(+), 52 deletions(-) create mode 100644 python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonProcess.java create mode 100644 python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonVariables.java create mode 100644 python4j/python4j-core/src/test/java/PythonBufferTest.java create mode 100644 python4j/python4j-numpy/src/main/java/org/eclipse/python4j/NumpyArray.java create mode 100644 python4j/python4j-numpy/src/main/resources/META-INF/services/org.eclipse.python4j.PythonType create mode 100644 python4j/python4j-numpy/src/test/java/PythonNumpyBasicTest.java create mode 100644 python4j/python4j-numpy/src/test/java/PythonNumpyCollectionsTest.java create mode 100644 python4j/python4j-numpy/src/test/java/PythonNumpyGCTest.java create mode 100644 python4j/python4j-numpy/src/test/java/PythonNumpyImportTest.java create mode 100644 python4j/python4j-numpy/src/test/java/PythonNumpyJobTest.java create mode 100644 python4j/python4j-numpy/src/test/java/PythonNumpyMultiThreadTest.java create mode 100644 python4j/python4j-numpy/src/test/java/PythonNumpyServiceLoaderTest.java diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/train/GradientSharingTrainingTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/train/GradientSharingTrainingTest.java index 68a012b72..c1eff1dce 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/train/GradientSharingTrainingTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/train/GradientSharingTrainingTest.java @@ -141,7 +141,7 @@ public class GradientSharingTrainingTest extends BaseSparkTest { SparkComputationGraph sparkNet = new SparkComputationGraph(sc, conf, tm); sparkNet.setCollectTrainingStats(tm.getIsCollectTrainingStats()); - System.out.println(Arrays.toString(sparkNet.getNetwork().params().get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat())); +// System.out.println(Arrays.toString(sparkNet.getNetwork().params().get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat())); File f = testDir.newFolder(); DataSetIterator iter = new MnistDataSetIterator(16, true, 12345); int count = 0; @@ -208,10 +208,10 @@ public class GradientSharingTrainingTest extends BaseSparkTest { } INDArray paramsAfter = after.params(); - System.out.println(Arrays.toString(paramsBefore.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat())); - System.out.println(Arrays.toString(paramsAfter.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat())); - System.out.println(Arrays.toString( - Transforms.abs(paramsAfter.sub(paramsBefore)).get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat())); +// System.out.println(Arrays.toString(paramsBefore.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat())); +// System.out.println(Arrays.toString(paramsAfter.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat())); +// System.out.println(Arrays.toString( +// Transforms.abs(paramsAfter.sub(paramsBefore)).get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat())); assertNotEquals(paramsBefore, paramsAfter); @@ -235,7 +235,7 @@ public class GradientSharingTrainingTest extends BaseSparkTest { } - @Test + @Test @Ignore //AB https://github.com/eclipse/deeplearning4j/issues/8985 public void differentNetsTrainingTest() throws Exception { int batch = 3; diff --git a/python4j/pom.xml b/python4j/pom.xml index 57af8f1bb..1fe50344f 100644 --- a/python4j/pom.xml +++ b/python4j/pom.xml @@ -41,10 +41,14 @@ provided + org.slf4j + slf4j-api + 1.6.6 + ch.qos.logback logback-classic ${logback.version} - test + test junit @@ -62,5 +66,10 @@ jsr305 3.0.2 + + org.slf4j + slf4j-api + 1.6.6 + \ No newline at end of file diff --git a/python4j/python4j-core/pom.xml b/python4j/python4j-core/pom.xml index b429d8272..e74d32392 100644 --- a/python4j/python4j-core/pom.xml +++ b/python4j/python4j-core/pom.xml @@ -39,6 +39,5 @@ cpython-platform ${cpython-platform.version} - \ No newline at end of file diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonContextManager.java b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonContextManager.java index a34d8a239..5675d0864 100644 --- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonContextManager.java +++ b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonContextManager.java @@ -19,8 +19,10 @@ package org.eclipse.python4j; import javax.lang.model.SourceVersion; +import java.io.Closeable; import java.util.HashSet; import java.util.Set; +import java.util.UUID; import java.util.concurrent.atomic.AtomicBoolean; /** @@ -46,6 +48,31 @@ public class PythonContextManager { init(); } + + public static class Context implements Closeable{ + private final String name; + private final String previous; + private final boolean temp; + public Context(){ + name = "temp_" + UUID.randomUUID().toString().replace("-", "_"); + temp = true; + previous = getCurrentContext(); + setContext(name); + } + public Context(String name){ + this.name = name; + temp = false; + previous = getCurrentContext(); + setContext(name); + } + + @Override + public void close(){ + setContext(previous); + if (temp) deleteContext(name); + } + } + private static void init() { if (init.get()) return; new PythonExecutioner(); @@ -190,6 +217,7 @@ public class PythonContextManager { setContext(tempContext); deleteContext(currContext); setContext(currContext); + deleteContext(tempContext); } /** diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonExecutioner.java b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonExecutioner.java index 57e1a22ae..542778f76 100644 --- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonExecutioner.java +++ b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonExecutioner.java @@ -25,6 +25,7 @@ import java.io.InputStream; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; @@ -42,7 +43,6 @@ public class PythonExecutioner { private final static String DEFAULT_PYTHON_PATH_PROPERTY = "org.eclipse.python4j.path"; private final static String JAVACPP_PYTHON_APPEND_TYPE = "org.eclipse.python4j.path.append"; private final static String DEFAULT_APPEND_TYPE = "before"; - static { init(); } @@ -55,6 +55,11 @@ public class PythonExecutioner { initPythonPath(); PyEval_InitThreads(); Py_InitializeEx(0); + for (PythonType type: PythonTypes.get()){ + type.init(); + } + // Constructors of custom types may contain initialization code that should + // run on the main the thread. } /** @@ -110,6 +115,8 @@ public class PythonExecutioner { getVariables(Arrays.asList(pyVars)); } + + /** * Gets the variable with the given name from the interpreter. * @@ -205,9 +212,9 @@ public class PythonExecutioner { * * @return */ - public static List getAllVariables() { + public static PythonVariables getAllVariables() { PythonGIL.assertThreadSafe(); - List ret = new ArrayList<>(); + PythonVariables ret = new PythonVariables(); PyObject main = PyImport_ImportModule("__main__"); PyObject globals = PyModule_GetDict(main); PyObject keys = PyDict_Keys(globals); @@ -259,7 +266,7 @@ public class PythonExecutioner { * @param inputs * @return */ - public static List execAndReturnAllVariables(String code, List inputs) { + public static PythonVariables execAndReturnAllVariables(String code, List inputs) { setVariables(inputs); simpleExec(getWrappedCode(code)); return getAllVariables(); @@ -271,7 +278,7 @@ public class PythonExecutioner { * @param code * @return */ - public static List execAndReturnAllVariables(String code) { + public static PythonVariables execAndReturnAllVariables(String code) { simpleExec(getWrappedCode(code)); return getAllVariables(); } @@ -279,25 +286,22 @@ public class PythonExecutioner { private static synchronized void initPythonPath() { try { String path = System.getProperty(DEFAULT_PYTHON_PATH_PROPERTY); + + List packagesList = new ArrayList<>(); + packagesList.addAll(Arrays.asList(cachePackages())); + for (PythonType type: PythonTypes.get()){ + packagesList.addAll(Arrays.asList(type.packages())); + } + //// TODO: fix in javacpp + packagesList.add(new File(python.cachePackage(), "site-packages")); + + File[] packages = packagesList.toArray(new File[0]); + if (path == null) { - File[] packages = cachePackages(); - - //// TODO: fix in javacpp - File sitePackagesWindows = new File(python.cachePackage(), "site-packages"); - File[] packages2 = new File[packages.length + 1]; - for (int i = 0; i < packages.length; i++) { - //System.out.println(packages[i].getAbsolutePath()); - packages2[i] = packages[i]; - } - packages2[packages.length] = sitePackagesWindows; - //System.out.println(sitePackagesWindows.getAbsolutePath()); - packages = packages2; - ////////// - Py_SetPath(packages); } else { StringBuffer sb = new StringBuffer(); - File[] packages = cachePackages(); + JavaCppPathType pathAppendValue = JavaCppPathType.valueOf(System.getProperty(JAVACPP_PYTHON_APPEND_TYPE, DEFAULT_APPEND_TYPE).toUpperCase()); switch (pathAppendValue) { case BEFORE: diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonGIL.java b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonGIL.java index 46b3db431..074be294a 100644 --- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonGIL.java +++ b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonGIL.java @@ -90,4 +90,8 @@ public class PythonGIL implements AutoCloseable { PyEval_SaveThread(); PyEval_RestoreThread(mainThreadState); } + + public static boolean locked(){ + return acquired.get(); + } } diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonJob.java b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonJob.java index cdbb1b81d..0818de890 100644 --- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonJob.java +++ b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonJob.java @@ -20,25 +20,29 @@ package org.eclipse.python4j; import lombok.Builder; import lombok.Data; import lombok.NoArgsConstructor; +import lombok.extern.slf4j.Slf4j; import javax.annotation.Nonnull; import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; -@Data -@NoArgsConstructor /** * PythonJob is the right abstraction for executing multiple python scripts * in a multi thread stateful environment. The setup-and-run mode allows your * "setup" code (imports, model loading etc) to be executed only once. */ +@Data +@Slf4j public class PythonJob { + private String code; private String name; private String context; - private boolean setupRunMode; + private final boolean setupRunMode; private PythonObject runF; + private final AtomicBoolean setupDone = new AtomicBoolean(false); static { new PythonExecutioner(); @@ -63,7 +67,6 @@ public class PythonJob { if (PythonContextManager.hasContext(context)) { throw new PythonException("Unable to create python job " + name + ". Context " + context + " already exists!"); } - if (setupRunMode) setup(); } @@ -71,17 +74,18 @@ public class PythonJob { * Clears all variables in current context and calls setup() */ public void clearState(){ - String context = this.context; - PythonContextManager.setContext("main"); - PythonContextManager.deleteContext(context); - this.context = context; + PythonContextManager.setContext(this.context); + PythonContextManager.reset(); + setupDone.set(false); setup(); } public void setup(){ + if (setupDone.get()) return; try (PythonGIL gil = PythonGIL.lock()) { PythonContextManager.setContext(context); PythonObject runF = PythonExecutioner.getVariable("run"); + if (runF == null || runF.isNone() || !Python.callable(runF)) { PythonExecutioner.exec(code); runF = PythonExecutioner.getVariable("run"); @@ -98,10 +102,12 @@ public class PythonJob { if (!setupF.isNone()) { setupF.call(); } + setupDone.set(true); } } public void exec(List inputs, List outputs) { + if (setupRunMode)setup(); try (PythonGIL gil = PythonGIL.lock()) { try (PythonGC _ = PythonGC.watch()) { PythonContextManager.setContext(context); @@ -139,6 +145,7 @@ public class PythonJob { } public List execAndReturnAllVariables(List inputs){ + if (setupRunMode)setup(); try (PythonGIL gil = PythonGIL.lock()) { try (PythonGC _ = PythonGC.watch()) { PythonContextManager.setContext(context); diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonObject.java b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonObject.java index f8ec17ed9..69252a5f7 100644 --- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonObject.java +++ b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonObject.java @@ -147,7 +147,8 @@ public class PythonObject { } PythonObject pyArgs; PythonObject pyKwargs; - if (args == null) { + + if (args == null || args.isEmpty()) { pyArgs = new PythonObject(PyTuple_New(0)); } else { PythonObject argsList = PythonTypes.convert(args); @@ -158,6 +159,7 @@ public class PythonObject { } else { pyKwargs = PythonTypes.convert(kwargs); } + PythonObject ret = new PythonObject( PyObject_Call( nativePythonObject, @@ -165,7 +167,9 @@ public class PythonObject { pyKwargs == null ? null : pyKwargs.nativePythonObject ) ); + PythonGC.keep(ret); + return ret; } @@ -241,4 +245,48 @@ public class PythonObject { PyObject_SetItem(nativePythonObject, key.nativePythonObject, value.nativePythonObject); } + + public PythonObject abs(){ + return new PythonObject(PyNumber_Absolute(nativePythonObject)); + } + public PythonObject add(PythonObject pythonObject){ + return new PythonObject(PyNumber_Add(nativePythonObject, pythonObject.nativePythonObject)); + } + public PythonObject sub(PythonObject pythonObject){ + return new PythonObject(PyNumber_Subtract(nativePythonObject, pythonObject.nativePythonObject)); + } + public PythonObject mod(PythonObject pythonObject){ + return new PythonObject(PyNumber_Divmod(nativePythonObject, pythonObject.nativePythonObject)); + } + public PythonObject mul(PythonObject pythonObject){ + return new PythonObject(PyNumber_Multiply(nativePythonObject, pythonObject.nativePythonObject)); + } + public PythonObject trueDiv(PythonObject pythonObject){ + return new PythonObject(PyNumber_TrueDivide(nativePythonObject, pythonObject.nativePythonObject)); + } + public PythonObject floorDiv(PythonObject pythonObject){ + return new PythonObject(PyNumber_FloorDivide(nativePythonObject, pythonObject.nativePythonObject)); + } + public PythonObject matMul(PythonObject pythonObject){ + return new PythonObject(PyNumber_MatrixMultiply(nativePythonObject, pythonObject.nativePythonObject)); + } + + public void addi(PythonObject pythonObject){ + PyNumber_InPlaceAdd(nativePythonObject, pythonObject.nativePythonObject); + } + public void subi(PythonObject pythonObject){ + PyNumber_InPlaceSubtract(nativePythonObject, pythonObject.nativePythonObject); + } + public void muli(PythonObject pythonObject){ + PyNumber_InPlaceMultiply(nativePythonObject, pythonObject.nativePythonObject); + } + public void trueDivi(PythonObject pythonObject){ + PyNumber_InPlaceTrueDivide(nativePythonObject, pythonObject.nativePythonObject); + } + public void floorDivi(PythonObject pythonObject){ + PyNumber_InPlaceFloorDivide(nativePythonObject, pythonObject.nativePythonObject); + } + public void matMuli(PythonObject pythonObject){ + PyNumber_InPlaceMatrixMultiply(nativePythonObject, pythonObject.nativePythonObject); + } } diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonProcess.java b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonProcess.java new file mode 100644 index 000000000..0ca17fb49 --- /dev/null +++ b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonProcess.java @@ -0,0 +1,127 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +package org.eclipse.python4j; + +import org.apache.commons.io.IOUtils; +import org.bytedeco.javacpp.Loader; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; + +public class PythonProcess { + private static String pythonExecutable = Loader.load(org.bytedeco.cpython.python.class); + public static String runAndReturn(String... arguments)throws IOException, InterruptedException{ + String[] allArgs = new String[arguments.length + 1]; + for (int i = 0; i < arguments.length; i++){ + allArgs[i + 1] = arguments[i]; + } + allArgs[0] = pythonExecutable; + ProcessBuilder pb = new ProcessBuilder(allArgs); + Process process = pb.start(); + String out = IOUtils.toString(process.getInputStream(), StandardCharsets.UTF_8); + process.waitFor(); + return out; + + } + + public static void run(String... arguments)throws IOException, InterruptedException{ + String[] allArgs = new String[arguments.length + 1]; + for (int i = 0; i < arguments.length; i++){ + allArgs[i + 1] = arguments[i]; + } + allArgs[0] = pythonExecutable; + ProcessBuilder pb = new ProcessBuilder(allArgs); + pb.inheritIO().start().waitFor(); + } + public static void pipInstall(String packageName) throws PythonException{ + try{ + run("-m", "pip", "install", packageName); + }catch(Exception e){ + throw new PythonException("Error installing package " + packageName, e); + } + + } + + public static void pipInstall(String packageName, String version){ + pipInstall(packageName + "==" + version); + } + + public static void pipUninstall(String packageName) throws PythonException{ + try{ + run("-m", "pip", "uninstall", packageName); + }catch(Exception e){ + throw new PythonException("Error uninstalling package " + packageName, e); + } + + } + public static void pipInstallFromGit(String gitRepoUrl){ + if (!gitRepoUrl.contains("://")){ + gitRepoUrl = "git://" + gitRepoUrl; + } + try{ + run("-m", "pip", "install", "git+", gitRepoUrl); + }catch(Exception e){ + throw new PythonException("Error installing package from " + gitRepoUrl, e); + } + + } + + public static String getPackageVersion(String packageName){ + String out; + try{ + out = runAndReturn("-m", "pip", "show", packageName); + } catch (Exception e){ + throw new PythonException("Error finding version for package " + packageName, e); + } + + if (!out.contains("Version: ")){ + throw new PythonException("Can't find package " + packageName); + } + String pkgVersion = out.split("Version: ")[1].split(System.lineSeparator())[0]; + return pkgVersion; + } + + public static boolean isPackageInstalled(String packageName){ + try{ + String out = runAndReturn("-m", "pip", "show", packageName); + return !out.isEmpty(); + }catch (Exception e){ + throw new PythonException("Error checking if package is installed: " +packageName, e); + } + + } + + public static void pipInstallFromRequirementsTxt(String path){ + try{ + run("-m", "pip", "install","-r", path); + }catch (Exception e){ + throw new PythonException("Error installing packages from " + path, e); + } + } + + public static void pipInstallFromSetupScript(String path, boolean inplace){ + + try{ + run(path, inplace?"develop":"install"); + }catch (Exception e){ + throw new PythonException("Error installing package from " + path, e); + } + + } + +} \ No newline at end of file diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonType.java b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonType.java index b4806aa37..47b725cd5 100644 --- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonType.java +++ b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonType.java @@ -17,6 +17,8 @@ package org.eclipse.python4j; +import java.io.File; + public abstract class PythonType { private final String name; @@ -43,5 +45,25 @@ public abstract class PythonType { return name; } + @Override + public boolean equals(Object obj){ + if (!(obj instanceof PythonType)){ + return false; + } + PythonType other = (PythonType)obj; + return this.getClass().equals(other.getClass()) && this.name.equals(other.name); + } + + public PythonObject pythonType(){ + return null; + } + + public File[] packages(){ + return new File[0]; + } + + public void init(){ //not to be called from constructor + + } } diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonTypes.java b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonTypes.java index 0dc20f712..cd7ac7d7c 100644 --- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonTypes.java +++ b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonTypes.java @@ -18,7 +18,16 @@ package org.eclipse.python4j; import org.bytedeco.cpython.PyObject; +import org.bytedeco.javacpp.BytePointer; +import org.bytedeco.javacpp.Loader; +import org.bytedeco.javacpp.Pointer; +import sun.misc.Unsafe; +import sun.nio.ch.DirectBuffer; +import java.lang.reflect.Field; +import java.nio.Buffer; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.util.*; import static org.bytedeco.cpython.global.python.*; @@ -28,7 +37,7 @@ public class PythonTypes { private static List getPrimitiveTypes() { - return Arrays.asList(STR, INT, FLOAT, BOOL); + return Arrays.asList(STR, INT, FLOAT, BOOL, MEMORYVIEW); } private static List getCollectionTypes() { @@ -36,8 +45,13 @@ public class PythonTypes { } private static List getExternalTypes() { - //TODO service loader - return new ArrayList<>(); + List ret = new ArrayList<>(); + ServiceLoader sl = ServiceLoader.load(PythonType.class); + Iterator iter = sl.iterator(); + while (iter.hasNext()) { + ret.add(iter.next()); + } + return ret; } public static List get() { @@ -48,15 +62,17 @@ public class PythonTypes { return ret; } - public static PythonType get(String name) { + public static PythonType get(String name) { for (PythonType pt : get()) { if (pt.getName().equals(name)) { // TODO use map instead? return pt; } + } throw new PythonException("Unknown python type: " + name); } + public static PythonType getPythonTypeForJavaObject(Object javaObject) { for (PythonType pt : get()) { if (pt.accepts(javaObject)) { @@ -66,7 +82,7 @@ public class PythonTypes { throw new PythonException("Unable to find python type for java type: " + javaObject.getClass()); } - public static PythonType getPythonTypeForPythonObject(PythonObject pythonObject) { + public static PythonType getPythonTypeForPythonObject(PythonObject pythonObject) { PyObject pyType = PyObject_Type(pythonObject.getNativePythonObject()); try { String pyTypeStr = PythonTypes.STR.toJava(new PythonObject(pyType, false)); @@ -75,6 +91,14 @@ public class PythonTypes { String pyTypeStr2 = ""; if (pyTypeStr.equals(pyTypeStr2)) { return pt; + } else { + try (PythonGC gc = PythonGC.watch()) { + PythonObject pyType2 = pt.pythonType(); + if (pyType2 != null && Python.isinstance(pythonObject, pyType2)) { + return pt; + } + } + } } throw new PythonException("Unable to find converter for python object of type " + pyTypeStr); @@ -212,12 +236,49 @@ public class PythonTypes { public static final PythonType LIST = new PythonType("list", List.class) { + @Override + public boolean accepts(Object javaObject) { + return (javaObject instanceof List || javaObject.getClass().isArray()); + } + @Override public List adapt(Object javaObject) { if (javaObject instanceof List) { return (List) javaObject; - } else if (javaObject instanceof Object[]) { - return Arrays.asList((Object[]) javaObject); + } else if (javaObject.getClass().isArray()) { + List ret = new ArrayList<>(); + if (javaObject instanceof Object[]) { + Object[] arr = (Object[]) javaObject; + return new ArrayList<>(Arrays.asList(arr)); + } else if (javaObject instanceof short[]) { + short[] arr = (short[]) javaObject; + for (short x : arr) ret.add(x); + return ret; + } else if (javaObject instanceof int[]) { + int[] arr = (int[]) javaObject; + for (int x : arr) ret.add(x); + return ret; + } else if (javaObject instanceof long[]) { + long[] arr = (long[]) javaObject; + for (long x : arr) ret.add(x); + return ret; + } else if (javaObject instanceof float[]) { + float[] arr = (float[]) javaObject; + for (float x : arr) ret.add(x); + return ret; + } else if (javaObject instanceof double[]) { + double[] arr = (double[]) javaObject; + for (double x : arr) ret.add(x); + return ret; + } else if (javaObject instanceof boolean[]) { + boolean[] arr = (boolean[]) javaObject; + for (boolean x : arr) ret.add(x); + return ret; + } else { + throw new PythonException("Unsupported array type: " + javaObject.getClass().toString()); + } + + } else { throw new PythonException("Cannot cast object of type " + javaObject.getClass().getName() + " to List"); } @@ -327,7 +388,13 @@ public class PythonTypes { } Object v = javaObject.get(k); PythonObject pyVal; - pyVal = PythonTypes.convert(v); + if (v instanceof PythonObject) { + pyVal = (PythonObject) v; + } else if (v instanceof PyObject) { + pyVal = new PythonObject((PyObject) v); + } else { + pyVal = PythonTypes.convert(v); + } int errCode = PyDict_SetItem(pyDict, pyKey.getNativePythonObject(), pyVal.getNativePythonObject()); if (errCode != 0) { String keyStr = pyKey.toString(); @@ -341,4 +408,85 @@ public class PythonTypes { return new PythonObject(pyDict); } }; + + + public static final PythonType MEMORYVIEW = new PythonType("memoryview", BytePointer.class) { + @Override + public BytePointer toJava(PythonObject pythonObject) { + try (PythonGC gc = PythonGC.watch()) { + if (!(Python.isinstance(pythonObject, Python.memoryviewType()))) { + throw new PythonException("Expected memoryview. Received: " + pythonObject); + } + PythonObject pySize = Python.len(pythonObject); + PythonObject ctypes = Python.importModule("ctypes"); + PythonObject charType = ctypes.attr("c_char"); + PythonObject charArrayType = new PythonObject(PyNumber_Multiply(charType.getNativePythonObject(), + pySize.getNativePythonObject())); + PythonObject fromBuffer = charArrayType.attr("from_buffer"); + if (pythonObject.attr("readonly").toBoolean()) { + pythonObject = Python.bytearray(pythonObject); + } + PythonObject arr = fromBuffer.call(pythonObject); + PythonObject cast = ctypes.attr("cast"); + PythonObject voidPtrType = ctypes.attr("c_void_p"); + PythonObject voidPtr = cast.call(arr, voidPtrType); + long address = voidPtr.attr("value").toLong(); + long size = pySize.toLong(); + try { + Field addressField = Buffer.class.getDeclaredField("address"); + addressField.setAccessible(true); + Field capacityField = Buffer.class.getDeclaredField("capacity"); + capacityField.setAccessible(true); + ByteBuffer buff = ByteBuffer.allocateDirect(0).order(ByteOrder.nativeOrder()); + addressField.setLong(buff, address); + capacityField.setInt(buff, (int) size); + BytePointer ret = new BytePointer(buff); + ret.limit(size); + return ret; + + } catch (Exception e) { + throw new RuntimeException(e); + } + + } + } + + @Override + public PythonObject toPython(BytePointer javaObject) { + long address = javaObject.address(); + long size = javaObject.limit(); + try (PythonGC gc = PythonGC.watch()) { + PythonObject ctypes = Python.importModule("ctypes"); + PythonObject charType = ctypes.attr("c_char"); + PythonObject pySize = new PythonObject(size); + PythonObject charArrayType = new PythonObject(PyNumber_Multiply(charType.getNativePythonObject(), + pySize.getNativePythonObject())); + PythonObject fromAddress = charArrayType.attr("from_address"); + PythonObject arr = fromAddress.call(new PythonObject(address)); + PythonObject memoryView = Python.memoryview(arr).attr("cast").call("b"); + PythonGC.keep(memoryView); + return memoryView; + } + + } + + @Override + public boolean accepts(Object javaObject) { + return javaObject instanceof Pointer || javaObject instanceof DirectBuffer; + } + + @Override + public BytePointer adapt(Object javaObject) { + if (javaObject instanceof BytePointer) { + return (BytePointer) javaObject; + } else if (javaObject instanceof Pointer) { + return new BytePointer((Pointer) javaObject); + } else if (javaObject instanceof DirectBuffer) { + return new BytePointer((ByteBuffer) javaObject); + } else { + throw new PythonException("Cannot cast object of type " + javaObject.getClass().getName() + " to BytePointer"); + } + } + }; + } diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonVariables.java b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonVariables.java new file mode 100644 index 000000000..32ae0b2f5 --- /dev/null +++ b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonVariables.java @@ -0,0 +1,47 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.eclipse.python4j; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * Some syntax sugar for lookup by name + */ +public class PythonVariables extends ArrayList { + public PythonVariable get(String variableName) { + for (PythonVariable pyVar: this){ + if (pyVar.getName().equals(variableName)){ + return pyVar; + } + } + return null; + } + + public boolean add(String variableName, PythonType variableType, Object value){ + return this.add(new PythonVariable<>(variableName, variableType, value)); + } + + public PythonVariables(PythonVariable... variables){ + this(Arrays.asList(variables)); + } + public PythonVariables(List list){ + super(); + addAll(list); + } +} diff --git a/python4j/python4j-core/src/test/java/PythonBufferTest.java b/python4j/python4j-core/src/test/java/PythonBufferTest.java new file mode 100644 index 000000000..c59b86c15 --- /dev/null +++ b/python4j/python4j-core/src/test/java/PythonBufferTest.java @@ -0,0 +1,113 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +import org.bytedeco.javacpp.BytePointer; +import org.bytedeco.javacpp.Loader; +import org.eclipse.python4j.*; +import org.junit.Assert; +import org.junit.Test; +import sun.nio.ch.DirectBuffer; + +import javax.annotation.concurrent.NotThreadSafe; +import java.nio.Buffer; +import java.nio.ByteBuffer; +import java.util.*; + +@NotThreadSafe +public class PythonBufferTest { + + @Test + public void testBuffer() { + ByteBuffer buff = ByteBuffer.allocateDirect(3); + buff.put((byte) 97); + buff.put((byte) 98); + buff.put((byte) 99); + buff.rewind(); + + BytePointer bp = new BytePointer(buff); + + List inputs = new ArrayList<>(); + inputs.add(new PythonVariable<>("buff", PythonTypes.MEMORYVIEW, buff)); + + List outputs = new ArrayList<>(); + outputs.add(new PythonVariable<>("s1", PythonTypes.STR)); + outputs.add(new PythonVariable<>("s2", PythonTypes.STR)); + + String code = "s1 = ''.join(chr(c) for c in buff)\nbuff[2] += 2\ns2 = ''.join(chr(c) for c in buff)"; + + PythonExecutioner.exec(code, inputs, outputs); + Assert.assertEquals("abc", outputs.get(0).getValue()); + Assert.assertEquals("abe", outputs.get(1).getValue()); + Assert.assertEquals(101, buff.get(2)); + + } + @Test + public void testBuffer2() { + ByteBuffer buff = ByteBuffer.allocateDirect(3); + buff.put((byte) 97); + buff.put((byte) 98); + buff.put((byte) 99); + buff.rewind(); + + BytePointer bp = new BytePointer(buff); + + List inputs = new ArrayList<>(); + inputs.add(new PythonVariable<>("buff", PythonTypes.MEMORYVIEW, bp)); + + List outputs = new ArrayList<>(); + outputs.add(new PythonVariable<>("s1", PythonTypes.STR)); + outputs.add(new PythonVariable<>("s2", PythonTypes.STR)); + + String code = "s1 = ''.join(chr(c) for c in buff)\nbuff[2] += 2\ns2 = ''.join(chr(c) for c in buff)"; + + PythonExecutioner.exec(code, inputs, outputs); + Assert.assertEquals("abc", outputs.get(0).getValue()); + Assert.assertEquals("abe", outputs.get(1).getValue()); + Assert.assertEquals(101, buff.get(2)); + + } + + @Test + public void testBuffer3() { + ByteBuffer buff = ByteBuffer.allocateDirect(3); + buff.put((byte) 97); + buff.put((byte) 98); + buff.put((byte) 99); + buff.rewind(); + + BytePointer bp = new BytePointer(buff); + + List inputs = new ArrayList<>(); + inputs.add(new PythonVariable<>("buff", PythonTypes.MEMORYVIEW, bp)); + + List outputs = new ArrayList<>(); + outputs.add(new PythonVariable<>("s1", PythonTypes.STR)); + outputs.add(new PythonVariable<>("s2", PythonTypes.STR)); + outputs.add(new PythonVariable<>("buff2", PythonTypes.MEMORYVIEW)); + String code = "s1 = ''.join(chr(c) for c in buff)\nbuff[2] += 2\ns2 = ''.join(chr(c) for c in buff)\nbuff2=buff[1:]"; + PythonExecutioner.exec(code, inputs, outputs); + + Assert.assertEquals("abc", outputs.get(0).getValue()); + Assert.assertEquals("abe", outputs.get(1).getValue()); + Assert.assertEquals(101, buff.get(2)); + BytePointer outBuffer = (BytePointer) outputs.get(2).getValue(); + Assert.assertEquals(2, outBuffer.capacity()); + Assert.assertEquals((byte)98, outBuffer.get(0)); + Assert.assertEquals((byte)101, outBuffer.get(1)); + + } +} \ No newline at end of file diff --git a/python4j/python4j-core/src/test/java/PythonGCTest.java b/python4j/python4j-core/src/test/java/PythonGCTest.java index f8c6ecba5..80b2e7f3c 100644 --- a/python4j/python4j-core/src/test/java/PythonGCTest.java +++ b/python4j/python4j-core/src/test/java/PythonGCTest.java @@ -49,6 +49,6 @@ public class PythonGCTest { PythonObject pyObjCount3 = Python.len(getObjects.call()); long objCount3 = pyObjCount3.toLong(); diff = objCount3 - objCount2; - Assert.assertEquals(2, diff);// 2 objects created during function call + Assert.assertTrue(diff <= 2);// 2 objects created during function call } } diff --git a/python4j/python4j-core/src/test/java/PythonJobTest.java b/python4j/python4j-core/src/test/java/PythonJobTest.java index 016045a25..b0f4233c9 100644 --- a/python4j/python4j-core/src/test/java/PythonJobTest.java +++ b/python4j/python4j-core/src/test/java/PythonJobTest.java @@ -30,7 +30,7 @@ import static org.junit.Assert.assertEquals; public class PythonJobTest { @Test - public void testPythonJobBasic() throws Exception{ + public void testPythonJobBasic(){ PythonContextManager.deleteNonMainContexts(); String code = "c = a + b"; @@ -65,7 +65,7 @@ public class PythonJobTest { } @Test - public void testPythonJobReturnAllVariables()throws Exception{ + public void testPythonJobReturnAllVariables(){ PythonContextManager.deleteNonMainContexts(); String code = "c = a + b"; @@ -101,7 +101,7 @@ public class PythonJobTest { @Test - public void testMultiplePythonJobsParallel()throws Exception{ + public void testMultiplePythonJobsParallel(){ PythonContextManager.deleteNonMainContexts(); String code1 = "c = a + b"; PythonJob job1 = new PythonJob("job1", code1, false); @@ -150,7 +150,7 @@ public class PythonJobTest { @Test - public void testPythonJobSetupRun()throws Exception{ + public void testPythonJobSetupRun(){ PythonContextManager.deleteNonMainContexts(); String code = "five=None\n" + @@ -189,7 +189,7 @@ public class PythonJobTest { } @Test - public void testPythonJobSetupRunAndReturnAllVariables()throws Exception{ + public void testPythonJobSetupRunAndReturnAllVariables(){ PythonContextManager.deleteNonMainContexts(); String code = "five=None\n" + "c=None\n"+ @@ -225,7 +225,7 @@ public class PythonJobTest { } @Test - public void testMultiplePythonJobsSetupRunParallel()throws Exception{ + public void testMultiplePythonJobsSetupRunParallel(){ PythonContextManager.deleteNonMainContexts(); String code1 = "five=None\n" + diff --git a/python4j/python4j-numpy/pom.xml b/python4j/python4j-numpy/pom.xml index 527a9343f..bcce739ce 100644 --- a/python4j/python4j-numpy/pom.xml +++ b/python4j/python4j-numpy/pom.xml @@ -28,15 +28,50 @@ ${nd4j.version} test + + org.eclipse + python4j-core + 1.0.0-SNAPSHOT + test-nd4j-native + + + org.nd4j + nd4j-native + ${nd4j.version} + test + + + org.deeplearning4j + dl4j-test-resources + ${nd4j.version} + test + + + test-nd4j-cuda-10.2 + + + org.nd4j + nd4j-cuda-10.1 + ${nd4j.version} + test + + + org.deeplearning4j + dl4j-test-resources + ${nd4j.version} + test + + + \ No newline at end of file diff --git a/python4j/python4j-numpy/src/main/java/org/eclipse/python4j/NumpyArray.java b/python4j/python4j-numpy/src/main/java/org/eclipse/python4j/NumpyArray.java new file mode 100644 index 000000000..66fb76d23 --- /dev/null +++ b/python4j/python4j-numpy/src/main/java/org/eclipse/python4j/NumpyArray.java @@ -0,0 +1,303 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +package org.eclipse.python4j; + +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.ArrayUtils; +import org.bytedeco.cpython.PyObject; +import org.bytedeco.cpython.PyTypeObject; +import org.bytedeco.javacpp.Pointer; +import org.bytedeco.javacpp.SizeTPointer; +import org.bytedeco.numpy.PyArrayObject; +import org.bytedeco.numpy.global.numpy; +import org.nd4j.linalg.api.buffer.BaseDataBuffer; +import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.concurrency.AffinityManager; +import org.nd4j.linalg.api.memory.MemoryWorkspace; +import org.nd4j.linalg.api.memory.MemoryWorkspaceManager; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.shape.Shape; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.nativeblas.NativeOps; +import org.nd4j.nativeblas.NativeOpsHolder; + +import java.io.File; +import java.util.*; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.bytedeco.cpython.global.python.*; +import static org.bytedeco.cpython.global.python.Py_DecRef; +import static org.bytedeco.numpy.global.numpy.*; +import static org.bytedeco.numpy.global.numpy.NPY_ARRAY_CARRAY; +import static org.bytedeco.numpy.global.numpy.PyArray_Type; + +@Slf4j +public class NumpyArray extends PythonType { + + public static final NumpyArray INSTANCE; + private static final AtomicBoolean init = new AtomicBoolean(false); + private static final Map cache = new HashMap<>(); + + static { + new PythonExecutioner(); + INSTANCE = new NumpyArray(); + } + + @Override + public File[] packages(){ + try{ + return new File[]{numpy.cachePackage()}; + }catch(Exception e){ + throw new PythonException(e); + } + + } + + public synchronized void init() { + if (init.get()) return; + init.set(true); + if (PythonGIL.locked()) { + throw new PythonException("Can not initialize numpy - GIL already acquired."); + } + int err = numpy._import_array(); + if (err < 0){ + System.out.println("Numpy import failed!"); + throw new PythonException("Numpy import failed!"); + } + } + + public NumpyArray() { + super("numpy.ndarray", INDArray.class); + + } + + @Override + public INDArray toJava(PythonObject pythonObject) { + log.info("Converting PythonObject to INDArray..."); + PyObject np = PyImport_ImportModule("numpy"); + PyObject ndarray = PyObject_GetAttrString(np, "ndarray"); + if (PyObject_IsInstance(pythonObject.getNativePythonObject(), ndarray) != 1) { + Py_DecRef(ndarray); + Py_DecRef(np); + throw new PythonException("Object is not a numpy array! Use Python.ndarray() to convert object to a numpy array."); + } + Py_DecRef(ndarray); + Py_DecRef(np); + PyArrayObject npArr = new PyArrayObject(pythonObject.getNativePythonObject()); + long[] shape = new long[PyArray_NDIM(npArr)]; + SizeTPointer shapePtr = PyArray_SHAPE(npArr); + if (shapePtr != null) + shapePtr.get(shape, 0, shape.length); + long[] strides = new long[shape.length]; + SizeTPointer stridesPtr = PyArray_STRIDES(npArr); + if (stridesPtr != null) + stridesPtr.get(strides, 0, strides.length); + int npdtype = PyArray_TYPE(npArr); + + DataType dtype; + switch (npdtype) { + case NPY_DOUBLE: + dtype = DataType.DOUBLE; + break; + case NPY_FLOAT: + dtype = DataType.FLOAT; + break; + case NPY_SHORT: + dtype = DataType.SHORT; + break; + case NPY_INT: + dtype = DataType.INT32; + break; + case NPY_LONG: + dtype = DataType.INT64; + break; + case NPY_UINT: + dtype = DataType.UINT32; + break; + case NPY_BYTE: + dtype = DataType.INT8; + break; + case NPY_UBYTE: + dtype = DataType.UINT8; + break; + case NPY_BOOL: + dtype = DataType.BOOL; + break; + case NPY_HALF: + dtype = DataType.FLOAT16; + break; + case NPY_LONGLONG: + dtype = DataType.INT64; + break; + case NPY_USHORT: + dtype = DataType.UINT16; + break; + case NPY_ULONG: + case NPY_ULONGLONG: + dtype = DataType.UINT64; + break; + default: + throw new PythonException("Unsupported array data type: " + npdtype); + } + long size = 1; + for (int i = 0; i < shape.length; size *= shape[i++]) ; + + INDArray ret; + long address = PyArray_DATA(npArr).address(); + String key = address + "_" + size + "_" + dtype; + DataBuffer buff = cache.get(key); + if (buff == null) { + try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { + Pointer ptr = NativeOpsHolder.getInstance().getDeviceNativeOps().pointerForAddress(address); + ptr = ptr.limit(size); + ptr = ptr.capacity(size); + buff = Nd4j.createBuffer(ptr, size, dtype); + cache.put(key, buff); + } + } + int elemSize = buff.getElementSize(); + long[] nd4jStrides = new long[strides.length]; + for (int i = 0; i < strides.length; i++) { + nd4jStrides[i] = strides[i] / elemSize; + } + ret = Nd4j.create(buff, shape, nd4jStrides, 0, Shape.getOrder(shape, nd4jStrides, 1), dtype); + Nd4j.getAffinityManager().tagLocation(ret, AffinityManager.Location.HOST); + log.info("Done."); + return ret; + + + } + + @Override + public PythonObject toPython(INDArray indArray) { + log.info("Converting INDArray to PythonObject..."); + DataType dataType = indArray.dataType(); + DataBuffer buff = indArray.data(); + String key = buff.pointer().address() + "_" + buff.length() + "_" + dataType; + cache.put(key, buff); + int numpyType; + String ctype; + switch (dataType) { + case DOUBLE: + numpyType = NPY_DOUBLE; + ctype = "c_double"; + break; + case FLOAT: + case BFLOAT16: + numpyType = NPY_FLOAT; + ctype = "c_float"; + break; + case SHORT: + numpyType = NPY_SHORT; + ctype = "c_short"; + break; + case INT: + numpyType = NPY_INT; + ctype = "c_int"; + break; + case LONG: + numpyType = NPY_INT64; + ctype = "c_int64"; + break; + case UINT16: + numpyType = NPY_USHORT; + ctype = "c_uint16"; + break; + case UINT32: + numpyType = NPY_UINT; + ctype = "c_uint"; + break; + case UINT64: + numpyType = NPY_UINT64; + ctype = "c_uint64"; + break; + case BOOL: + numpyType = NPY_BOOL; + ctype = "c_bool"; + break; + case BYTE: + numpyType = NPY_BYTE; + ctype = "c_byte"; + break; + case UBYTE: + numpyType = NPY_UBYTE; + ctype = "c_ubyte"; + break; + case HALF: + numpyType = NPY_HALF; + ctype = "c_short"; + break; + default: + throw new RuntimeException("Unsupported dtype: " + dataType); + } + + long[] shape = indArray.shape(); + INDArray inputArray = indArray; + if (dataType == DataType.BFLOAT16) { + log.warn("Creating copy of array as bfloat16 is not supported by numpy."); + inputArray = indArray.castTo(DataType.FLOAT); + } + + //Sync to host memory in the case of CUDA, before passing the host memory pointer to Python + + Nd4j.getAffinityManager().ensureLocation(inputArray, AffinityManager.Location.HOST); + + // PyArray_Type() call causes jvm crash in linux cpu if GIL is acquired by non main thread. + // Using Interpreter for now: + +// try(PythonContextManager.Context context = new PythonContextManager.Context("__np_array_converter")){ +// log.info("Stringing exec..."); +// String code = "import ctypes\nimport numpy as np\n" + +// "cArr = (ctypes." + ctype + "*" + indArray.length() + ")"+ +// ".from_address(" + indArray.data().pointer().address() + ")\n"+ +// "npArr = np.frombuffer(cArr, dtype=" + ((numpyType == NPY_HALF) ? "'half'" : "ctypes." + ctype)+ +// ").reshape(" + Arrays.toString(indArray.shape()) + ")"; +// PythonExecutioner.exec(code); +// log.info("exec done."); +// PythonObject ret = PythonExecutioner.getVariable("npArr"); +// Py_IncRef(ret.getNativePythonObject()); +// return ret; +// +// } + log.info("NUMPY: PyArray_Type()"); + PyTypeObject pyTypeObject = PyArray_Type(); + + + log.info("NUMPY: PyArray_New()"); + PyObject npArr = PyArray_New(pyTypeObject, shape.length, new SizeTPointer(shape), + numpyType, null, + inputArray.data().addressPointer(), + 0, NPY_ARRAY_CARRAY, null); + log.info("Done."); + return new PythonObject(npArr); + } + + @Override + public boolean accepts(Object javaObject) { + return javaObject instanceof INDArray; + } + + @Override + public INDArray adapt(Object javaObject) { + if (javaObject instanceof INDArray) { + return (INDArray) javaObject; + } + throw new PythonException("Cannot cast object of type " + javaObject.getClass().getName() + " to INDArray"); + } +} diff --git a/python4j/python4j-numpy/src/main/resources/META-INF/services/org.eclipse.python4j.PythonType b/python4j/python4j-numpy/src/main/resources/META-INF/services/org.eclipse.python4j.PythonType new file mode 100644 index 000000000..ae4d4640b --- /dev/null +++ b/python4j/python4j-numpy/src/main/resources/META-INF/services/org.eclipse.python4j.PythonType @@ -0,0 +1 @@ +org.eclipse.python4j.NumpyArray \ No newline at end of file diff --git a/python4j/python4j-numpy/src/test/java/PythonNumpyBasicTest.java b/python4j/python4j-numpy/src/test/java/PythonNumpyBasicTest.java new file mode 100644 index 000000000..b7bd838b5 --- /dev/null +++ b/python4j/python4j-numpy/src/test/java/PythonNumpyBasicTest.java @@ -0,0 +1,170 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +import org.eclipse.python4j.*; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.concurrency.AffinityManager; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.nativeblas.OpaqueDataBuffer; + +import javax.annotation.concurrent.NotThreadSafe; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; + +@NotThreadSafe +@RunWith(Parameterized.class) +public class PythonNumpyBasicTest { + private DataType dataType; + private long[] shape; + + public PythonNumpyBasicTest(DataType dataType, long[] shape, String dummyArg) { + this.dataType = dataType; + this.shape = shape; + } + + @Parameterized.Parameters(name = "{index}: Testing with DataType={0}, shape={2}") + public static Collection params() { + DataType[] types = new DataType[] { + DataType.BOOL, + DataType.FLOAT16, + DataType.BFLOAT16, + DataType.FLOAT, + DataType.DOUBLE, + DataType.INT8, + DataType.INT16, + DataType.INT32, + DataType.INT64, + DataType.UINT8, + DataType.UINT16, + DataType.UINT32, + DataType.UINT64 + }; + + long[][] shapes = new long[][]{ + new long[]{2, 3}, + new long[]{3}, + new long[]{1}, + new long[]{} // scalar + }; + + + List ret = new ArrayList<>(); + for (DataType type: types){ + for (long[] shape: shapes){ + ret.add(new Object[]{type, shape, Arrays.toString(shape)}); + } + } + return ret; + } + + @Test + public void testConversion(){ + INDArray arr = Nd4j.zeros(dataType, shape); + PythonObject npArr = PythonTypes.convert(arr); + INDArray arr2 = PythonTypes.getPythonTypeForPythonObject(npArr).toJava(npArr); + if (dataType == DataType.BFLOAT16){ + arr = arr.castTo(DataType.FLOAT); + } + Assert.assertEquals(arr,arr2); + } + + + @Test + public void testExecution(){ + List inputs = new ArrayList<>(); + INDArray x = Nd4j.ones(dataType, shape); + INDArray y = Nd4j.zeros(dataType, shape); + INDArray z = (dataType == DataType.BOOL)?x:x.mul(y.add(2)); + z = (dataType == DataType.BFLOAT16)? z.castTo(DataType.FLOAT): z; + PythonType arrType = PythonTypes.get("numpy.ndarray"); + inputs.add(new PythonVariable<>("x", arrType, x)); + inputs.add(new PythonVariable<>("y", arrType, y)); + List outputs = new ArrayList<>(); + PythonVariable output = new PythonVariable<>("z", arrType); + outputs.add(output); + String code = (dataType == DataType.BOOL)?"z = x":"z = x * (y + 2)"; + if (shape.length == 0){ // scalar special case + code += "\nimport numpy as np\nz = np.asarray(float(z), dtype=x.dtype)"; + } + PythonExecutioner.exec(code, inputs, outputs); + INDArray z2 = output.getValue(); + + Assert.assertEquals(z.dataType(), z2.dataType()); + Assert.assertEquals(z, z2); + + } + + + @Test + public void testInplaceExecution(){ + if (dataType == DataType.BOOL || dataType == DataType.BFLOAT16)return; + if (shape.length == 0) return; + List inputs = new ArrayList<>(); + INDArray x = Nd4j.ones(dataType, shape); + INDArray y = Nd4j.zeros(dataType, shape); + INDArray z = x.mul(y.add(2)); + // Nd4j.getAffinityManager().ensureLocation(z, AffinityManager.Location.HOST); + PythonType arrType = PythonTypes.get("numpy.ndarray"); + inputs.add(new PythonVariable<>("x", arrType, x)); + inputs.add(new PythonVariable<>("y", arrType, y)); + List outputs = new ArrayList<>(); + PythonVariable output = new PythonVariable<>("x", arrType); + outputs.add(output); + String code = "x *= y + 2"; + PythonExecutioner.exec(code, inputs, outputs); + INDArray z2 = output.getValue(); + Assert.assertEquals(x.dataType(), z2.dataType()); + Assert.assertEquals(z.dataType(), z2.dataType()); + Assert.assertEquals(x, z2); + Assert.assertEquals(z, z2); + Assert.assertEquals(x.data().pointer().address(), z2.data().pointer().address()); + if("CUDA".equalsIgnoreCase(Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"))){ + Assert.assertEquals(getDeviceAddress(x), getDeviceAddress(z2)); + } + + + } + private static long getDeviceAddress(INDArray array){ + if(!"CUDA".equalsIgnoreCase(Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"))){ + throw new IllegalStateException("Cannot ge device pointer for non-CUDA device"); + } + + //Use reflection here as OpaqueDataBuffer is only available on BaseCudaDataBuffer and BaseCpuDataBuffer - not DataBuffer/BaseDataBuffer + // due to it being defined in nd4j-native-api, not nd4j-api + try { + Class c = Class.forName("org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer"); + Method m = c.getMethod("getOpaqueDataBuffer"); + OpaqueDataBuffer db = (OpaqueDataBuffer) m.invoke(array.data()); + long address = db.specialBuffer().address(); + return address; + } catch (Throwable t){ + throw new RuntimeException("Error getting OpaqueDataBuffer", t); + } + } + + + + +} diff --git a/python4j/python4j-numpy/src/test/java/PythonNumpyCollectionsTest.java b/python4j/python4j-numpy/src/test/java/PythonNumpyCollectionsTest.java new file mode 100644 index 000000000..99a050f63 --- /dev/null +++ b/python4j/python4j-numpy/src/test/java/PythonNumpyCollectionsTest.java @@ -0,0 +1,96 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +import org.eclipse.python4j.PythonException; +import org.eclipse.python4j.PythonObject; +import org.eclipse.python4j.PythonTypes; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.factory.Nd4j; + +import javax.annotation.concurrent.NotThreadSafe; +import java.util.*; + + +@NotThreadSafe +@RunWith(Parameterized.class) +public class PythonNumpyCollectionsTest { + private DataType dataType; + + public PythonNumpyCollectionsTest(DataType dataType){ + this.dataType = dataType; + } + + @Parameterized.Parameters(name = "{index}: Testing with DataType={0}") + public static DataType[] params() { + return new DataType[]{ + DataType.BOOL, + DataType.FLOAT16, + //DataType.BFLOAT16, + DataType.FLOAT, + DataType.DOUBLE, + DataType.INT8, + DataType.INT16, + DataType.INT32, + DataType.INT64, + DataType.UINT8, + DataType.UINT16, + DataType.UINT32, + DataType.UINT64 + }; + } + @Test + public void testPythonDictFromMap() throws PythonException { + Map map = new HashMap(); + map.put("a", 1); + map.put(1, "a"); + map.put("arr", Nd4j.ones(dataType, 2, 3)); + map.put("list1", Arrays.asList(1, 2.0, 3, 4f, Nd4j.zeros(dataType,3,2))); + Map innerMap = new HashMap(); + innerMap.put("b", 2); + innerMap.put(2, "b"); + innerMap.put(5, Nd4j.ones(dataType, 5)); + map.put("innermap", innerMap); + map.put("list2", Arrays.asList(4, "5", innerMap, false, true)); + PythonObject dict = PythonTypes.convert(map); + Map map2 = PythonTypes.DICT.toJava(dict); + Assert.assertEquals(map.toString(), map2.toString()); + } + + @Test + public void testPythonListFromList() throws PythonException{ + List list = new ArrayList<>(); + list.add(1); + list.add("2"); + list.add(Nd4j.ones(dataType, 2, 3)); + list.add(Arrays.asList("a", + Nd4j.ones(dataType, 1, 2),1.0, 2f, 10, true, false, + Nd4j.zeros(dataType, 3, 2))); + Map map = new HashMap(); + map.put("a", 1); + map.put(1, "a"); + map.put(5, Nd4j.ones(dataType,4, 5)); + map.put("list1", Arrays.asList(1, 2.0, 3, 4f, Nd4j.zeros(dataType, 3, 1))); + list.add(map); + PythonObject dict = PythonTypes.convert(list); + List list2 = PythonTypes.LIST.toJava(dict); + Assert.assertEquals(list.toString(), list2.toString()); + } +} diff --git a/python4j/python4j-numpy/src/test/java/PythonNumpyGCTest.java b/python4j/python4j-numpy/src/test/java/PythonNumpyGCTest.java new file mode 100644 index 000000000..d1c5ba761 --- /dev/null +++ b/python4j/python4j-numpy/src/test/java/PythonNumpyGCTest.java @@ -0,0 +1,55 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +import org.eclipse.python4j.Python; +import org.eclipse.python4j.PythonGC; +import org.eclipse.python4j.PythonObject; +import org.junit.Assert; +import org.junit.Test; +import org.nd4j.linalg.factory.Nd4j; + +import javax.annotation.concurrent.NotThreadSafe; + + +@NotThreadSafe +public class PythonNumpyGCTest { + + @Test + public void testGC(){ + PythonObject gcModule = Python.importModule("gc"); + PythonObject getObjects = gcModule.attr("get_objects"); + PythonObject pyObjCount1 = Python.len(getObjects.call()); + long objCount1 = pyObjCount1.toLong(); + PythonObject pyList = Python.list(); + pyList.attr("append").call(new PythonObject(Nd4j.linspace(1, 10, 10))); + pyList.attr("append").call(1.0); + pyList.attr("append").call(true); + PythonObject pyObjCount2 = Python.len(getObjects.call()); + long objCount2 = pyObjCount2.toLong(); + long diff = objCount2 - objCount1; + Assert.assertTrue(diff > 2); + try(PythonGC gc = PythonGC.watch()){ + PythonObject pyList2 = Python.list(); + pyList2.attr("append").call(new PythonObject(Nd4j.linspace(1, 10, 10))); + pyList2.attr("append").call(1.0); + pyList2.attr("append").call(true); + } + PythonObject pyObjCount3 = Python.len(getObjects.call()); + long objCount3 = pyObjCount3.toLong(); + diff = objCount3 - objCount2; + Assert.assertTrue(diff <= 2);// 2 objects created during function call + } +} diff --git a/python4j/python4j-numpy/src/test/java/PythonNumpyImportTest.java b/python4j/python4j-numpy/src/test/java/PythonNumpyImportTest.java new file mode 100644 index 000000000..580f8643b --- /dev/null +++ b/python4j/python4j-numpy/src/test/java/PythonNumpyImportTest.java @@ -0,0 +1,22 @@ +import org.eclipse.python4j.NumpyArray; +import org.eclipse.python4j.Python; +import org.eclipse.python4j.PythonGC; +import org.eclipse.python4j.PythonObject; +import org.junit.Assert; +import org.junit.Test; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +public class PythonNumpyImportTest { + + @Test + public void testNumpyImport(){ + try(PythonGC gc = PythonGC.watch()){ + PythonObject np = Python.importModule("numpy"); + PythonObject zeros = np.attr("zeros").call(5); + INDArray arr = NumpyArray.INSTANCE.toJava(zeros); + Assert.assertEquals(arr, Nd4j.zeros(DataType.DOUBLE, 5)); + } + } +} diff --git a/python4j/python4j-numpy/src/test/java/PythonNumpyJobTest.java b/python4j/python4j-numpy/src/test/java/PythonNumpyJobTest.java new file mode 100644 index 000000000..399b87fb1 --- /dev/null +++ b/python4j/python4j-numpy/src/test/java/PythonNumpyJobTest.java @@ -0,0 +1,303 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +import org.eclipse.python4j.*; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.ArrayList; +import java.util.List; + +import static org.junit.Assert.assertEquals; + + +@javax.annotation.concurrent.NotThreadSafe +@RunWith(Parameterized.class) +public class PythonNumpyJobTest { + private DataType dataType; + + public PythonNumpyJobTest(DataType dataType){ + this.dataType = dataType; + } + + @Parameterized.Parameters(name = "{index}: Testing with DataType={0}") + public static DataType[] params() { + return new DataType[]{ + DataType.BOOL, + DataType.FLOAT16, + DataType.BFLOAT16, + DataType.FLOAT, + DataType.DOUBLE, + DataType.INT8, + DataType.INT16, + DataType.INT32, + DataType.INT64, + DataType.UINT8, + DataType.UINT16, + DataType.UINT32, + DataType.UINT64 + }; + } + + @Test + public void testNumpyJobBasic(){ + PythonContextManager.deleteNonMainContexts(); + List inputs = new ArrayList<>(); + INDArray x = Nd4j.ones(dataType, 2, 3); + INDArray y = Nd4j.zeros(dataType, 2, 3); + INDArray z = (dataType == DataType.BOOL)?x:x.mul(y.add(2)); + z = (dataType == DataType.BFLOAT16)? z.castTo(DataType.FLOAT): z; + PythonType arrType = PythonTypes.get("numpy.ndarray"); + inputs.add(new PythonVariable<>("x", arrType, x)); + inputs.add(new PythonVariable<>("y", arrType, y)); + List outputs = new ArrayList<>(); + PythonVariable output = new PythonVariable<>("z", arrType); + outputs.add(output); + String code = (dataType == DataType.BOOL)?"z = x":"z = x * (y + 2)"; + + PythonJob job = new PythonJob("job1", code, false); + + job.exec(inputs, outputs); + + INDArray z2 = output.getValue(); + + if (dataType == DataType.BFLOAT16){ + z2 = z2.castTo(DataType.FLOAT); + } + + Assert.assertEquals(z, z2); + + } + + @Test + public void testNumpyJobReturnAllVariables(){ + PythonContextManager.deleteNonMainContexts(); + List inputs = new ArrayList<>(); + INDArray x = Nd4j.ones(dataType, 2, 3); + INDArray y = Nd4j.zeros(dataType, 2, 3); + INDArray z = (dataType == DataType.BOOL)?x:x.mul(y.add(2)); + PythonType arrType = PythonTypes.get("numpy.ndarray"); + inputs.add(new PythonVariable<>("x", arrType, x)); + inputs.add(new PythonVariable<>("y", arrType, y)); + String code = (dataType == DataType.BOOL)?"z = x":"z = x * (y + 2)"; + + PythonJob job = new PythonJob("job1", code, false); + List outputs = job.execAndReturnAllVariables(inputs); + + INDArray x2 = (INDArray) outputs.get(0).getValue(); + INDArray y2 = (INDArray) outputs.get(1).getValue(); + INDArray z2 = (INDArray) outputs.get(2).getValue(); + + if (dataType == DataType.BFLOAT16){ + x = x.castTo(DataType.FLOAT); + y = y.castTo(DataType.FLOAT); + z = z.castTo(DataType.FLOAT); + } + Assert.assertEquals(x, x2); + Assert.assertEquals(y, y2); + Assert.assertEquals(z, z2); + + } + + + @Test + public void testMultipleNumpyJobsParallel(){ + PythonContextManager.deleteNonMainContexts(); + String code1 =(dataType == DataType.BOOL)?"z = x":"z = x + y"; + PythonJob job1 = new PythonJob("job1", code1, false); + + String code2 =(dataType == DataType.BOOL)?"z = y":"z = x - y"; + PythonJob job2 = new PythonJob("job2", code2, false); + + List inputs = new ArrayList<>(); + INDArray x = Nd4j.ones(dataType, 2, 3); + INDArray y = Nd4j.zeros(dataType, 2, 3); + INDArray z1 = (dataType == DataType.BOOL)?x:x.add(y); + z1 = (dataType == DataType.BFLOAT16)? z1.castTo(DataType.FLOAT): z1; + INDArray z2 = (dataType == DataType.BOOL)?y:x.sub(y); + z2 = (dataType == DataType.BFLOAT16)? z2.castTo(DataType.FLOAT): z2; + PythonType arrType = PythonTypes.get("numpy.ndarray"); + inputs.add(new PythonVariable<>("x", arrType, x)); + inputs.add(new PythonVariable<>("y", arrType, y)); + + + List outputs = new ArrayList<>(); + + outputs.add(new PythonVariable<>("z", arrType)); + + job1.exec(inputs, outputs); + + assertEquals(z1, outputs.get(0).getValue()); + + + job2.exec(inputs, outputs); + + assertEquals(z2, outputs.get(0).getValue()); + + } + + + @Test + public synchronized void testNumpyJobSetupRun(){ + if (dataType == DataType.BOOL)return; + PythonContextManager.deleteNonMainContexts(); + String code = "five=None\n" + + "def setup():\n" + + " global five\n"+ + " five = 5\n\n" + + "def run(a, b):\n" + + " c = a + b + five\n"+ + " return {'c':c}\n\n"; + + PythonJob job = new PythonJob("job1", code, true); + + List inputs = new ArrayList<>(); + inputs.add(new PythonVariable<>("a", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(2))); + inputs.add(new PythonVariable<>("b", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3))); + + List outputs = new ArrayList<>(); + outputs.add(new PythonVariable<>("c", NumpyArray.INSTANCE)); + job.exec(inputs, outputs); + + + assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16)? DataType.FLOAT: dataType, 2, 3).mul(10), + outputs.get(0).getValue()); + + + inputs = new ArrayList<>(); + inputs.add(new PythonVariable<>("a", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3))); + inputs.add(new PythonVariable<>("b", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(4))); + + + outputs = new ArrayList<>(); + outputs.add(new PythonVariable<>("c", NumpyArray.INSTANCE)); + + job.exec(inputs, outputs); + + assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16)? DataType.FLOAT: dataType, 2, 3).mul(12), + outputs.get(0).getValue()); + + + } + @Test + public void testNumpyJobSetupRunAndReturnAllVariables(){ + if (dataType == DataType.BOOL)return; + PythonContextManager.deleteNonMainContexts(); + String code = "five=None\n" + + "c=None\n"+ + "def setup():\n" + + " global five\n"+ + " five = 5\n\n" + + "def run(a, b):\n" + + " global c\n" + + " c = a + b + five\n"; + PythonJob job = new PythonJob("job1", code, true); + + List inputs = new ArrayList<>(); + inputs.add(new PythonVariable<>("a", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(2))); + inputs.add(new PythonVariable<>("b", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3))); + + List outputs = job.execAndReturnAllVariables(inputs); + + assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16)? DataType.FLOAT: dataType, 2, 3).mul(10), + outputs.get(1).getValue()); + + + inputs = new ArrayList<>(); + inputs.add(new PythonVariable<>("a", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3))); + inputs.add(new PythonVariable<>("b", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(4))); + + + outputs = job.execAndReturnAllVariables(inputs); + + + assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16)? DataType.FLOAT: dataType, 2, 3).mul(12), + outputs.get(1).getValue()); + + + } + + @Test + public void testMultipleNumpyJobsSetupRunParallel(){ + if (dataType == DataType.BOOL)return; + PythonContextManager.deleteNonMainContexts(); + + String code1 = "five=None\n" + + "def setup():\n" + + " global five\n"+ + " five = 5\n\n" + + "def run(a, b):\n" + + " c = a + b + five\n"+ + " return {'c':c}\n\n"; + PythonJob job1 = new PythonJob("job1", code1, true); + + String code2 = "five=None\n" + + "def setup():\n" + + " global five\n"+ + " five = 5\n\n" + + "def run(a, b):\n" + + " c = a + b - five\n"+ + " return {'c':c}\n\n"; + PythonJob job2 = new PythonJob("job2", code2, true); + + List inputs = new ArrayList<>(); + inputs.add(new PythonVariable<>("a", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(2))); + inputs.add(new PythonVariable<>("b", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3))); + + + List outputs = new ArrayList<>(); + outputs.add(new PythonVariable<>("c", NumpyArray.INSTANCE)); + + job1.exec(inputs, outputs); + + assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16)? DataType.FLOAT: dataType, 2, 3).mul(10), + outputs.get(0).getValue()); + + + job2.exec(inputs, outputs); + + assertEquals(Nd4j.zeros((dataType == DataType.BFLOAT16)? DataType.FLOAT: dataType, 2, 3), + outputs.get(0).getValue()); + + + inputs = new ArrayList<>(); + inputs.add(new PythonVariable<>("a", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3))); + inputs.add(new PythonVariable<>("b", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(4))); + + outputs = new ArrayList<>(); + outputs.add(new PythonVariable<>("c", NumpyArray.INSTANCE)); + + + job1.exec(inputs, outputs); + + assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16)? DataType.FLOAT: dataType, 2, 3).mul(12), + outputs.get(0).getValue()); + + + job2.exec(inputs, outputs); + + assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16)? DataType.FLOAT: dataType, 2, 3).mul(2), + outputs.get(0).getValue()); + + + } + +} diff --git a/python4j/python4j-numpy/src/test/java/PythonNumpyMultiThreadTest.java b/python4j/python4j-numpy/src/test/java/PythonNumpyMultiThreadTest.java new file mode 100644 index 000000000..52ccd1fd0 --- /dev/null +++ b/python4j/python4j-numpy/src/test/java/PythonNumpyMultiThreadTest.java @@ -0,0 +1,194 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +import org.eclipse.python4j.*; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import javax.annotation.concurrent.NotThreadSafe; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + + +@NotThreadSafe +@RunWith(Parameterized.class) +public class PythonNumpyMultiThreadTest { + private DataType dataType; + + public PythonNumpyMultiThreadTest(DataType dataType) { + this.dataType = dataType; + } + + @Parameterized.Parameters(name = "{index}: Testing with DataType={0}") + public static DataType[] params() { + return new DataType[]{ +// DataType.BOOL, +// DataType.FLOAT16, +// DataType.BFLOAT16, + DataType.FLOAT, + DataType.DOUBLE, +// DataType.INT8, +// DataType.INT16, + DataType.INT32, + DataType.INT64, +// DataType.UINT8, +// DataType.UINT16, +// DataType.UINT32, +// DataType.UINT64 + }; + } + + + @Test + public void testMultiThreading1() throws Throwable { + final List exceptions = Collections.synchronizedList(new ArrayList()); + Runnable runnable = new Runnable() { + @Override + public void run() { + try (PythonGIL gil = PythonGIL.lock()) { + try (PythonGC gc = PythonGC.watch()) { + List inputs = new ArrayList<>(); + inputs.add(new PythonVariable<>("x", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3))); + inputs.add(new PythonVariable<>("y", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(4))); + PythonVariable out = new PythonVariable<>("z", NumpyArray.INSTANCE); + String code = "z = x + y"; + PythonExecutioner.exec(code, inputs, Collections.singletonList(out)); + Assert.assertEquals(Nd4j.ones(dataType, 2, 3).mul(7), out.getValue()); + } + } catch (Throwable e) { + exceptions.add(e); + } + } + }; + + int numThreads = 10; + Thread[] threads = new Thread[numThreads]; + for (int i = 0; i < threads.length; i++) { + threads[i] = new Thread(runnable); + } + for (int i = 0; i < threads.length; i++) { + threads[i].start(); + } + Thread.sleep(100); + for (int i = 0; i < threads.length; i++) { + threads[i].join(); + } + if (!exceptions.isEmpty()) { + throw (exceptions.get(0)); + } + + } + + @Test + public void testMultiThreading2() throws Throwable { + final List exceptions = Collections.synchronizedList(new ArrayList()); + Runnable runnable = new Runnable() { + @Override + public void run() { + try (PythonGIL gil = PythonGIL.lock()) { + try (PythonGC gc = PythonGC.watch()) { + PythonContextManager.reset(); + List inputs = new ArrayList<>(); + inputs.add(new PythonVariable<>("x", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3))); + inputs.add(new PythonVariable<>("y", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(4))); + String code = "z = x + y"; + List outputs = PythonExecutioner.execAndReturnAllVariables(code, inputs); + Assert.assertEquals(Nd4j.ones(dataType, 2, 3).mul(3), outputs.get(0).getValue()); + Assert.assertEquals(Nd4j.ones(dataType, 2, 3).mul(4), outputs.get(1).getValue()); + Assert.assertEquals(Nd4j.ones(dataType, 2, 3).mul(7), outputs.get(2).getValue()); + } + } catch (Throwable e) { + exceptions.add(e); + } + } + }; + + int numThreads = 10; + Thread[] threads = new Thread[numThreads]; + for (int i = 0; i < threads.length; i++) { + threads[i] = new Thread(runnable); + } + for (int i = 0; i < threads.length; i++) { + threads[i].start(); + } + Thread.sleep(100); + for (int i = 0; i < threads.length; i++) { + threads[i].join(); + } + if (!exceptions.isEmpty()) { + throw (exceptions.get(0)); + } + } + + @Test + public void testMultiThreading3() throws Throwable { + PythonContextManager.deleteNonMainContexts(); + + String code = "c = a + b"; + final PythonJob job = new PythonJob("job1", code, false); + + final List exceptions = Collections.synchronizedList(new ArrayList()); + + class JobThread extends Thread { + private INDArray a, b, c; + + public JobThread(INDArray a, INDArray b, INDArray c) { + this.a = a; + this.b = b; + this.c = c; + } + + @Override + public void run() { + try { + PythonVariable out = new PythonVariable<>("c", NumpyArray.INSTANCE); + job.exec(Arrays.asList(new PythonVariable<>("a", NumpyArray.INSTANCE, a), + new PythonVariable<>("b", NumpyArray.INSTANCE, b)), + Collections.singletonList(out)); + Assert.assertEquals(c, out.getValue()); + } catch (Exception e) { + exceptions.add(e); + } + + } + } + int numThreads = 10; + JobThread[] threads = new JobThread[numThreads]; + for (int i = 0; i < threads.length; i++) { + threads[i] = new JobThread(Nd4j.zeros(dataType, 2, 3).add(i), Nd4j.zeros(dataType, 2, 3).add(i + 3), + Nd4j.zeros(dataType, 2, 3).add(2 * i + 3)); + } + + for (int i = 0; i < threads.length; i++) { + threads[i].start(); + } + Thread.sleep(100); + for (int i = 0; i < threads.length; i++) { + threads[i].join(); + } + + if (!exceptions.isEmpty()) { + throw (exceptions.get(0)); + } + } +} diff --git a/python4j/python4j-numpy/src/test/java/PythonNumpyServiceLoaderTest.java b/python4j/python4j-numpy/src/test/java/PythonNumpyServiceLoaderTest.java new file mode 100644 index 000000000..d3c649c8d --- /dev/null +++ b/python4j/python4j-numpy/src/test/java/PythonNumpyServiceLoaderTest.java @@ -0,0 +1,41 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +import org.eclipse.python4j.*; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import javax.annotation.concurrent.NotThreadSafe; +import java.util.ArrayList; +import java.util.List; + +@NotThreadSafe +public class PythonNumpyServiceLoaderTest { + + @Test + public void testServiceLoader(){ + Assert.assertEquals(NumpyArray.INSTANCE, PythonTypes.get("numpy.ndarray")); + Assert.assertEquals(NumpyArray.INSTANCE, PythonTypes.getPythonTypeForJavaObject(Nd4j.zeros(1))); + } + + +} From 08a77d929b229813abdcb82744d41ac6b8c4c707 Mon Sep 17 00:00:00 2001 From: raver119 Date: Tue, 16 Jun 2020 09:07:20 +0300 Subject: [PATCH 05/11] few minor tweaks for recent MSVC update Signed-off-by: raver119 --- libnd4j/include/cnpy/cnpy.h | 2 +- libnd4j/include/ops/ops.h | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/libnd4j/include/cnpy/cnpy.h b/libnd4j/include/cnpy/cnpy.h index ea847c3e7..c84623599 100644 --- a/libnd4j/include/cnpy/cnpy.h +++ b/libnd4j/include/cnpy/cnpy.h @@ -69,7 +69,7 @@ namespace cnpy { } }; - struct ND4J_EXPORT npz_t : public std::unordered_map { + struct ND4J_EXPORT npz_t : public std::map { void destruct() { npz_t::iterator it = this->begin(); for(; it != this->end(); ++it) (*it).second.destruct(); diff --git a/libnd4j/include/ops/ops.h b/libnd4j/include/ops/ops.h index ea52e9ba0..aca6fec6f 100644 --- a/libnd4j/include/ops/ops.h +++ b/libnd4j/include/ops/ops.h @@ -3963,9 +3963,6 @@ namespace simdOps { } #endif -#ifndef __clang__ -#pragma omp declare simd uniform(extraParamsRef) -#endif op_def static Y merge(X old, X opOutput, X *extraParamsRef) { return update(old, opOutput, extraParamsRef); } From de625baea68e35b892c4045d237a0f5e44343d6b Mon Sep 17 00:00:00 2001 From: raver119 Date: Wed, 17 Jun 2020 09:41:42 +0300 Subject: [PATCH 06/11] nested while graph Signed-off-by: raver119 --- .../tests_cpu/resources/simplewhile_nested.fb | Bin 0 -> 26808 bytes 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 libnd4j/tests_cpu/resources/simplewhile_nested.fb diff --git a/libnd4j/tests_cpu/resources/simplewhile_nested.fb b/libnd4j/tests_cpu/resources/simplewhile_nested.fb new file mode 100644 index 0000000000000000000000000000000000000000..9404b98c4c2dd30c5a37448879cd984042a146b3 GIT binary patch literal 26808 zcmchg4RBo7RmZPv*;ZV&6|UU4j*@I#C#a6vD2|$@8PRSXB{&&1BN7peD`Lr(Qme?e zkY%i7JaoxuI+~28qbVB37%_;9rX!3JG(iN{m?B&Q!k8dZ0mc|&f(gciAdCqngf74T zd+%!BUcX(fz0bKb|2_BP+;cz9x#!;d_GwjRY*=M(BU)<8Ou4Brm1d1uMY^)Ym=@?| zXce>utOv$?1g{|aW^$49xx{?SU>(11@iY2lN4jb%)bw$bP2S41=>^4$OdK zUcXKzPOTW{x{-NvYlsW!KO{X09m zI_~j5ABXQz%Z3fvP$-|}w^Z$@kD}%JAS?xwfT{EP@cwu@og`idWcRfo0W_v^pnlpA zh_J~8iW$l7yV`biJk-|JWtx87mnm7*Kys@)fltQRLM_uUFnp)P1I> zZCA(RfOyX@r_-9abF`S=S&D() zb7|a|LU?SOma1)BMP7Zsjjbo)(U_$d&1nKuf^u*@`cD2vy!swOUYz>g{wwLU#<}uk zU*Go;mp)H{6W|D#0AoP9mM!!H>AM9Q)^|7QCSXG$`u-Tq5i8t^FpHHhFa{B|c|Q=% zwXWx|?Ih)K>bd%*c=gVd7Z~A&>tNcLy`6p;Er$496hC=x)zopW* z_NaV%KZT7a;fYi43FcEWwT9yG=iSJQQ}3I8EuH=heOvOXr}x$J?efc7*OOor3sP5|{+L23CNwKXyszdq*!nQ`g!_>1YZ_N$Z7PeWh$9bXTUe*hGsher6sAJ)Uo z4G(EW@lbXt4K6VMvydsDt2luzsjvaw@e5{!WnFa!pG^rZFB3{`&;pc3o^9HAG2 z)rD-cqpved4;AEv{olhuxvjidx#hD|ZOMm==Gy0u{Z{;XsbJocw5fTJZ_r%IE;S$J zKx1Dlh{v8^Tai~!`g-NL=huzjPN(li-le1dI+O2{4rjp(I0hzxbbJsDfqo$U%D<@} z8czl220t?Qp;I6Gg`fYZ)omGho$VdFdpmodCAzTAH2$!D9}V${PoWSVpELVYZOd-rN3VGtDjmACqQ_ujF27z z(%At}tUu&&OO{^b55qcHteomuLGveHE}F|9o;ni0J`&%F*B^ExuNeKh{%{lbMd|qB z@A~_l*5MeK2N%F;FaxZv$B9eV(r+&`Y_rn+ZNLWkpk*8ac|2zJ?(gi~vC}^%tS4Ox zWar_v_oT=3p!e{`)vsShg&-_eZu!Dg{gGes&b!bqjG;AHSn5wH@kwZl;@~VcpM=J_ zK5Ln0$<(}61Ffej5VhvXjtuIX-5*oPt0#TE@?3kEjX_>_Pe*&_j<((o`t>6B)K5VV zL)Sx}_^xkfvN_q(6)+FZfmv_@908MH49L#=6d8JU- zo*LnK$Y-I{4~6hpUQ5-c-HUVWUxU9Nzx|vGWNJR*u%FuL;`JSV9=(Knp!#+EkNk6^ z;?nahm;v%7*P`cd{L(F7In@5-vsTY5so&MHcki~2pV`;e%me*3XD_fPnI04ju1*Kbn+V1^>`1L#!$c#(Rhqznv)tY$`dM9-2-+202uDk<4 zx|s%(U<{0aA&dg3e(7`Q18D~nFQ~;RI>)O8tVA2v)&}dN@y^fz2~n)_sg$LgEAQt zn1`GPr@<+30?5{7(?d|%w`{oyRDe-X4-{l)37|FP?q!@a_sWfAz7zF#LNA|PP|a&? z2)jCZzS;p*UX-ja2jvud!EMVew|rqL*>?ir5K%f5$f8UWWl>SG-K`;Og0O@rnXa&tcwxa%MY*j#Z zl2m+=@86cf6aLQ4O;el|q{H<^`%BduD>X zjnK8wGp~8`y?-6?eqh((6!CE|0)~K{f6ce%UHwUbN+5km2R3M~w6+9yeYtDWTR)+b z>0;%!jWAVvvKQ$h-})N+v-tIMC6F1Hem0{I&E@sVbJy3}>2zA(_pgKA4m~;R>-!Gk z()TQw0mr~37z0CKF@3kgV?$WqeB;y6LxJ*2f$J|H4$5ua#mdDeS$ZeQ4a%3@iVo<5 zpf;bbm%iVV59|lCU$UAw5@rm z23iOC)&Rv>p_G*Mdv)kU^Xb~tjrMyJw401*_gd`O^?T2ASJXUgI_uflYWcukAln)T zuHWk>Ph*#zH9$3&s$T(`6kp{0D%#-lId401s^d5I!eFwVaPts|<<9^tGzPev)J_U|}DIgt;fMGBI zQb79A{A<3|p9H7`t)S5NmE`yR`F6jV_xFTcWmL-w(uZtDG}ryG{dD~LI1tz3Y1D4S!u#5$^{xU>byVJ49N#JpiQFW~ln1Iu&3iDD-)x_#gBT{Q5nJ%wqH*KVa9|74~WALjB+H557LM=2Bo9 z$cIdTQ7{a0tuw2SR(J~4$4B7Jo^LXOkgv@0gsJ+kKIGEFslSb14+-{#xb%=>|I!|* z{vY~#|6HTH`~e`l3hQ8ybPDu=E}(VM2-UhsfJ(3h6zcD^?`3=%i6G?4x8;_q&$1cO zTz;;RXW&W7>p%?%_qCj~Y|;kFi3;w%GaAU0tZE>+vH^|Xl}T1)2(N{CCzlv*VH9>I2$K6y#r`-aSf?ukncA*Jt=` z5dUw^vh!gKLT0_$I<_R5st+2EY*g_sl-3$3PJM9SH)?$22IxPK8Jqr*7pMM{=Zcq? z_5KHUwVO5PQ0e&io4#GGAuiiF2Tp^<>`L>M09q$uyCRv>wzRIKSJzJTo{`pC*se$} zgnVTSOreNuNn_NQ3bCs(_N*klar%`jfy}tH@6KgMq_)%f=$ZN_CUeaMT`ULNDlr+K;bmZ#U%iXQ++z$6$2(y8{I6p&Bo z0@7IsZB8evM7*=rouGc8n9m4}u}k52Odp zdsq+Mq}zcFnyXyTt+bAP|LlMAbm1QN*;FAsYD?{WP7x<4Ut<#;)CWQOai7~a|8xBM z8b@X^>qY&v_o2(k(!ADz8er#q4e@dyUIh)Y%8NyRh(A2%;katJXT>gTqJez3UCn8@ zUDZ|F!4xZ*6|}a(&+gy7=>P4PzT@8jz1VmB`j%hc;S!dq553?Bm;|F>7z~0G zP?_dTcB(N{1Jx-5%|LyXZRUF3pYaFTPyO9|$)Sn7&;9j`R{mkJ=l*Be<1~gkVD&70 zs^9K>ljV&;Wo{Vuo(`ZxvMMlcWFUpHHOg55j-C@47DdPujPk5FYVLNB7%GA*y^m^Azn1 z6j@4K-+fO&?S41>cJcG;-lH!dGcNyGhi;nbuWPSS{!@8TLF;cZe>p-sN%-7$a;=MC zO0$}CePz!Z>9o$9Yv+CYl0H&k226t~Fb+n5>@Ed*f&5+*)b17PLn~+pHmLnv`*Y^q zysIrMUzub}ZZ3P$nM39I?CJbF*EWxcyDGT$^uWKzXHWgeYNo|HVEuH0c-;2%Hti&l z>$W3Zy8D6VRB(f9Psh1O%J+Wll5bDaM?aVavZ-T0`{@`M045SY}vOf_YE&&@|9WH zVXAiI+jLgT_kLIZe_q@Ch1;-k*N#pHGHYR|2J$OaAWA1>M}|^z?VX_-JrrY??p@$L z?w_*DSKdvh)pwoUw_NA@xtimH;53*8(_jjWgArhDL_SA8rV-e&Xg<1wG{g(io5rqw z`L^J{nVZjUGQ6JMcyjcSD}R9@6p?)G&5PN~)PKcqFJIldnIbM0jV zcgOqa-$tnRpiA$2eqpudD+Q*(5m1Oy=={DneT4^{sYzaF}gS=o&Qazm6q>p3Vi40YkTgoTH{ZFX>cR; zCi}4KCtGX&TF6yKwX1^kCBK?${mi`=zuu~u>m+T)wSL~?KDrcnFaDo@zR`G6Ulaew z^?@!Rdu)V;b-9Ifq4tX67L-I~+!q=Fsim8uG&Sp&?d*TtU2x*@W%3_LU!QpJYz^ zbZzq@cTx5CRnCQYrzMh80c1bRxj%lE_+r2HdnuysChDr6nP-z%|E}x?&OUCN&rwh3 zNj+cn=P0Y5c{-xb%hb{L1$rkWJg?a`dABK|?nUbI8bRdmC!UPpU(5YXYfAR)x8?u7 z-@6gKkHC8me45X2zwB?&ZMYIKe(HP%e(5w^$G$a4w%Hz0?-lBOnzE1QsHe9L4@K0w zNWB{B=y_wfJ^S|QYoR*olk8%WJ@;zdr@JZTJ-ROS$?_j>-)ho%e{UzMxJ{3_f{FcDQhf-TSi|{vC^ToP*qX>7kzf=3LL|!g}EEsJ%`;6;bEg)VT|O?JM?g-G45E?=|@D zgh%%wN49x6g8v8buZP#|i@ia=8Bu2=KH!to(b-9FkT1tZ+%Y}{&pOI8YuoRQeWLSV zM4cC?b4Ns9?5?ysqRw}zlc0|F6#K2-FGlda1>f!Pxc#xGtdkLS*0Rp^E?j2)YqZ9^ z8^QYsyth)8@%MIB8FQC+PtNwYKMTM1jf}srtKPTMtr2w&Q%C+;_icYRZ5{r0DuVwE z{CaOr`+(2yd!tuFyms9kWUbrxb2EJkuQ+4g3DvV}?hE(~)ywqB*5N%u?_9lXiQ3Ru zZUvcYdeQyowhWDt$v}N~Se>~QcFn*6edtIM_U+)`afBEq@n`Nbc4f0Kp=Z}0JOjV_lQ~oR9%fSn z@3-KU&ctu;r1p7FPXzyQ_|=ar{@~feixK>n;Fn#gj+Jlk!>>p1uVcNd-JiG*{|>icgJU zJ#gncV;}z77zlI}=RChd-(^FYw)|e-=FRW2k|xTUfpnq%l?VA+>#_;qiz%NH^+UoZ zuH7m3X~rPCll|MZwv7Y)H*>UiAfmk_b=B@dNcgd+_Fkhs@n*Hhexg(88b>Sgk*tc6 z7Cu{n=9%X44f?M;TGcDO{aifZ@qZ~%cz>Qop3X^{Pdk31=BPi;rPah`%XYrCrp2#d wsp9eme!K417W$=Usy^pVL3eh^9s#pMt-gNe>n6!$^ZZW literal 0 HcmV?d00001 From b4a4a78f21925a37900897e14e097524e8ced17d Mon Sep 17 00:00:00 2001 From: raver119 Date: Fri, 19 Jun 2020 20:50:57 +0300 Subject: [PATCH 07/11] One more graph for tests Signed-off-by: raver119 --- libnd4j/tests_cpu/resources/while_iter3.fb | Bin 0 -> 9512 bytes 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 libnd4j/tests_cpu/resources/while_iter3.fb diff --git a/libnd4j/tests_cpu/resources/while_iter3.fb b/libnd4j/tests_cpu/resources/while_iter3.fb new file mode 100644 index 0000000000000000000000000000000000000000..4b0e86979ef77dd11422c6701202f8a5b9b2bc60 GIT binary patch literal 9512 zcmchdUuYc19mhwCY$?vQQJr%Tj%crG9z-gz$+*KpYP6& zcJI6=Dbl7t_}uKy{N^`*e!rPn?W)^2Ich%3Wt-_R3Daqkrpp}Vehvm;3%a2fo&aM$ zjaG+ub=b6XCwZWy=8JdBPbaitce2s88a@2sB{)0bfjuEPeb!U`;bd|d(gd5kpd z=UMKD!HMQ*H$N+OBm6x2lruuQ7oCv0cd51&BTAe3`^PqNH)% zW__Eo700?g{p(sybMrK*&VrxZvg=LrScY}bI+G6;VIE$EG894culWv-ZH)U1;Kcbr zyP~4dk;IGc956P z<|d2z;wxNk`&IhZchCE!^}gM$T(%lTZK|G9#ib;?-pY(IOg=J9LWsolmq(jgTz0e8TN5lS9U$x(C?~5m1@%pH4W-rgILl0ssH!>V)xNcaa{t9 zX9bpE9x6}<`Je#u(GaQnk%AVJ5R&0+5#{%m{>y}*7fpPu&pnc#W>Zo8p0Bhf^_iyv0_QG-1%D$&LKSIi@7<23IHm;OchIP;xbQKmsXV9xq zh9YRc8Yb0zDNZ_J44wmtwgsX6v2gL?)adxd$*iAOA5%k=K4S|hmjv;4f0@+?eq9?% zo1H@wj4e%hA83D7pHq+o#j5(_M0TP?u^E0JvhLfItsanF%|nz;vLi+9x@Wxv>w19u z{n~3hhu&rHRo)AC?Q^`YOQ5r-3QI5#6)1!DgaXK~L!{w3I?w%$FIT_RwfUHewtcJq z7jiS+`y9=i?iD|FUl{UB$UvIq+tT1ycC^oFAGikUuSEAFpf>vd5x+kxzRms4Ki%^B z=KT5MU+sSV9a62YU;NvyFO97X8rwRoz!Jz8&DK|nGR4&xC`O!Uu6QKycqupiZJig+ z4?3f*AHJ@Vx<96NDx=zcBAXh6e9*;Bvvv0NKjZhy6aTU{kZDU=eQQWUT)P~x1>+T%(;`Hy6?8|!pdWuwQ?aF_xe>Goa z*n|y`kFUZa>}4Hi(Q~4?a_g8kBe`krm8Wq9YxP-`G+D>@mr>QTp?#xKTzg9A^n<9a z6xZ61M*RDwd-41G4mJ+XLzF8?cI*v?)dL^ zE-8LWumL(V<;NwMhcfK8-(=BoqP{na$6#OaV$M|wy$7n#mxAnRjv7Vz8n@Dleh}TK zxDLl*5A%~IzYlsLj(w+u&A8%l6I=h!IPCxElZiu}0!Q$GI|~$#!~VO*j~S20lD0;$8Ei`~pA~0`VE?QCM*TrVozv9O^UhA+GM{8UG`2NOJ$-MHU+i9M zzoUIKLib1Ls*mcc)eXM&d_O|}E%fysSaqy^{WpZyBkHs=HvOF${H^Kg@V9V;{wefz z76*UFs87Z`%35?|m5(&93!uJh-h)25_iy&S+}Th)=Z_!KhUx`*Me>bhW;NXPw}0mX{;+-#}l_*TGrj>6vebjyzjFGj-x0}~|yTg>71NlV# zDMhqLk=L$LPf|9CYHyPI+E0V_yk{SBjxr1Nv!cC4bCTtE0yIbJgZ@{^YPNeZE1mH8 W)9JnUN9SAzy_YB5JnIbM*8C4P&>=qn literal 0 HcmV?d00001 From e9c13ca9f4ca1f6cecc36cc1445ff40d57556724 Mon Sep 17 00:00:00 2001 From: Fariz Rahman Date: Thu, 25 Jun 2020 07:05:08 +0400 Subject: [PATCH 08/11] Python4j change namespace (org.eclipse->org.nd4j) (#493) * org.eclipse->org.nd4j * Fix parent groupid Signed-off-by: Alex Black * memview->bytes * del test Co-authored-by: Alex Black --- python4j/pom.xml | 2 +- python4j/python4j-core/pom.xml | 2 +- .../{eclipse => nd4j}/python4j/Python.java | 2 +- .../python4j/PythonContextManager.java | 15 +- .../python4j/PythonException.java | 2 +- .../python4j/PythonExecutioner.java | 3 +- .../{eclipse => nd4j}/python4j/PythonGC.java | 2 +- .../{eclipse => nd4j}/python4j/PythonGIL.java | 3 +- .../{eclipse => nd4j}/python4j/PythonJob.java | 3 +- .../python4j/PythonObject.java | 2 +- .../python4j/PythonProcess.java | 2 +- .../python4j/PythonType.java | 2 +- .../python4j/PythonTypes.java | 176 +++++++++++------- .../python4j/PythonVariable.java | 2 +- .../python4j/PythonVariables.java | 2 +- .../org/nd4j/python4j/pythonexec/__init__.py | 0 .../python4j/pythonexec/pythonexec.py | 0 .../test/java/PythonBasicExecutionTest.java | 5 +- .../src/test/java/PythonBufferTest.java | 113 ----------- .../src/test/java/PythonCollectionsTest.java | 6 +- .../test/java/PythonContextManagerTest.java | 6 +- .../src/test/java/PythonGCTest.java | 6 +- .../src/test/java/PythonJobTest.java | 8 +- .../src/test/java/PythonMultiThreadTest.java | 3 +- .../test/java/PythonPrimitiveTypesTest.java | 20 +- python4j/python4j-numpy/pom.xml | 4 +- .../python4j/NumpyArray.java | 6 +- .../services/org.eclipse.python4j.PythonType | 1 - .../services/org.nd4j.python4j.PythonType | 1 + .../src/test/java/PythonNumpyBasicTest.java | 3 +- .../test/java/PythonNumpyCollectionsTest.java | 6 +- .../src/test/java/PythonNumpyGCTest.java | 6 +- .../src/test/java/PythonNumpyImportTest.java | 8 +- .../src/test/java/PythonNumpyJobTest.java | 2 +- .../test/java/PythonNumpyMultiThreadTest.java | 2 +- .../java/PythonNumpyServiceLoaderTest.java | 10 +- 36 files changed, 190 insertions(+), 246 deletions(-) rename python4j/python4j-core/src/main/java/org/{eclipse => nd4j}/python4j/Python.java (99%) rename python4j/python4j-core/src/main/java/org/{eclipse => nd4j}/python4j/PythonContextManager.java (95%) rename python4j/python4j-core/src/main/java/org/{eclipse => nd4j}/python4j/PythonException.java (98%) rename python4j/python4j-core/src/main/java/org/{eclipse => nd4j}/python4j/PythonExecutioner.java (99%) rename python4j/python4j-core/src/main/java/org/{eclipse => nd4j}/python4j/PythonGC.java (99%) rename python4j/python4j-core/src/main/java/org/{eclipse => nd4j}/python4j/PythonGIL.java (97%) rename python4j/python4j-core/src/main/java/org/{eclipse => nd4j}/python4j/PythonJob.java (99%) rename python4j/python4j-core/src/main/java/org/{eclipse => nd4j}/python4j/PythonObject.java (99%) rename python4j/python4j-core/src/main/java/org/{eclipse => nd4j}/python4j/PythonProcess.java (99%) rename python4j/python4j-core/src/main/java/org/{eclipse => nd4j}/python4j/PythonType.java (98%) rename python4j/python4j-core/src/main/java/org/{eclipse => nd4j}/python4j/PythonTypes.java (74%) rename python4j/python4j-core/src/main/java/org/{eclipse => nd4j}/python4j/PythonVariable.java (98%) rename python4j/python4j-core/src/main/java/org/{eclipse => nd4j}/python4j/PythonVariables.java (98%) create mode 100644 python4j/python4j-core/src/main/resources/org/nd4j/python4j/pythonexec/__init__.py rename python4j/python4j-core/src/main/resources/org/{eclipse => nd4j}/python4j/pythonexec/pythonexec.py (100%) delete mode 100644 python4j/python4j-core/src/test/java/PythonBufferTest.java rename python4j/python4j-numpy/src/main/java/org/{eclipse => nd4j}/python4j/NumpyArray.java (97%) delete mode 100644 python4j/python4j-numpy/src/main/resources/META-INF/services/org.eclipse.python4j.PythonType create mode 100644 python4j/python4j-numpy/src/main/resources/META-INF/services/org.nd4j.python4j.PythonType diff --git a/python4j/pom.xml b/python4j/pom.xml index 1fe50344f..3f1d026a5 100644 --- a/python4j/pom.xml +++ b/python4j/pom.xml @@ -25,7 +25,7 @@ 4.0.0 - org.eclipse + org.nd4j python4j-parent pom diff --git a/python4j/python4j-core/pom.xml b/python4j/python4j-core/pom.xml index e74d32392..26e77b8d1 100644 --- a/python4j/python4j-core/pom.xml +++ b/python4j/python4j-core/pom.xml @@ -21,7 +21,7 @@ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> python4j-parent - org.eclipse + org.nd4j 1.0.0-SNAPSHOT jar diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/Python.java b/python4j/python4j-core/src/main/java/org/nd4j/python4j/Python.java similarity index 99% rename from python4j/python4j-core/src/main/java/org/eclipse/python4j/Python.java rename to python4j/python4j-core/src/main/java/org/nd4j/python4j/Python.java index fd6fff112..03c2fdaab 100644 --- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/Python.java +++ b/python4j/python4j-core/src/main/java/org/nd4j/python4j/Python.java @@ -15,7 +15,7 @@ ******************************************************************************/ -package org.eclipse.python4j; +package org.nd4j.python4j; import org.bytedeco.cpython.PyObject; diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonContextManager.java b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonContextManager.java similarity index 95% rename from python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonContextManager.java rename to python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonContextManager.java index 5675d0864..0090e38d4 100644 --- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonContextManager.java +++ b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonContextManager.java @@ -14,7 +14,7 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.eclipse.python4j; +package org.nd4j.python4j; import javax.lang.model.SourceVersion; @@ -103,7 +103,18 @@ public class PythonContextManager { } private static boolean validateContextName(String s) { - return SourceVersion.isIdentifier(s) && !s.startsWith(COLLAPSED_KEY); + for (int i=0; i= '0' && c <= '9'){ + return false; + } + } + if (!(c=='_' || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9'))){ + return false; + } + } + return true; } private static String getContextPrefix(String contextName) { diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonException.java b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonException.java similarity index 98% rename from python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonException.java rename to python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonException.java index a9bbf596c..e8f64f2be 100644 --- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonException.java +++ b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonException.java @@ -14,7 +14,7 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.eclipse.python4j; +package org.nd4j.python4j; /** diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonExecutioner.java b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonExecutioner.java similarity index 99% rename from python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonExecutioner.java rename to python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonExecutioner.java index 542778f76..bc48b0e98 100644 --- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonExecutioner.java +++ b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonExecutioner.java @@ -15,7 +15,7 @@ ******************************************************************************/ -package org.eclipse.python4j; +package org.nd4j.python4j; import org.bytedeco.cpython.PyObject; @@ -25,7 +25,6 @@ import java.io.InputStream; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; -import java.util.Collections; import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonGC.java b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonGC.java similarity index 99% rename from python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonGC.java rename to python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonGC.java index 5531b67d3..e18d2072d 100644 --- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonGC.java +++ b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonGC.java @@ -15,7 +15,7 @@ ******************************************************************************/ -package org.eclipse.python4j; +package org.nd4j.python4j; import org.bytedeco.cpython.PyObject; import org.bytedeco.javacpp.Pointer; diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonGIL.java b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonGIL.java similarity index 97% rename from python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonGIL.java rename to python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonGIL.java index 074be294a..3a88253e0 100644 --- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonGIL.java +++ b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonGIL.java @@ -14,11 +14,10 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.eclipse.python4j; +package org.nd4j.python4j; import org.bytedeco.cpython.PyThreadState; -import org.omg.SendingContext.RunTime; import java.util.concurrent.atomic.AtomicBoolean; diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonJob.java b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonJob.java similarity index 99% rename from python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonJob.java rename to python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonJob.java index 0818de890..f357388f7 100644 --- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonJob.java +++ b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonJob.java @@ -14,12 +14,11 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.eclipse.python4j; +package org.nd4j.python4j; import lombok.Builder; import lombok.Data; -import lombok.NoArgsConstructor; import lombok.extern.slf4j.Slf4j; import javax.annotation.Nonnull; diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonObject.java b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonObject.java similarity index 99% rename from python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonObject.java rename to python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonObject.java index 69252a5f7..94b60d320 100644 --- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonObject.java +++ b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonObject.java @@ -14,7 +14,7 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.eclipse.python4j; +package org.nd4j.python4j; import org.bytedeco.cpython.PyObject; diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonProcess.java b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonProcess.java similarity index 99% rename from python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonProcess.java rename to python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonProcess.java index 0ca17fb49..bce8809f5 100644 --- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonProcess.java +++ b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonProcess.java @@ -15,7 +15,7 @@ ******************************************************************************/ -package org.eclipse.python4j; +package org.nd4j.python4j; import org.apache.commons.io.IOUtils; import org.bytedeco.javacpp.Loader; diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonType.java b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonType.java similarity index 98% rename from python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonType.java rename to python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonType.java index 47b725cd5..79b0ccaab 100644 --- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonType.java +++ b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonType.java @@ -14,7 +14,7 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.eclipse.python4j; +package org.nd4j.python4j; import java.io.File; diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonTypes.java b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonTypes.java similarity index 74% rename from python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonTypes.java rename to python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonTypes.java index cd7ac7d7c..089c8aefe 100644 --- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonTypes.java +++ b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonTypes.java @@ -14,14 +14,12 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.eclipse.python4j; +package org.nd4j.python4j; import org.bytedeco.cpython.PyObject; import org.bytedeco.javacpp.BytePointer; -import org.bytedeco.javacpp.Loader; import org.bytedeco.javacpp.Pointer; -import sun.misc.Unsafe; import sun.nio.ch.DirectBuffer; import java.lang.reflect.Field; @@ -37,7 +35,7 @@ public class PythonTypes { private static List getPrimitiveTypes() { - return Arrays.asList(STR, INT, FLOAT, BOOL, MEMORYVIEW); + return Arrays.asList(STR, INT, FLOAT, BOOL, BYTES); } private static List getCollectionTypes() { @@ -258,6 +256,10 @@ public class PythonTypes { int[] arr = (int[]) javaObject; for (int x : arr) ret.add(x); return ret; + }else if (javaObject instanceof byte[]){ + byte[] arr = (byte[]) javaObject; + for (int x : arr) ret.add(x); + return ret; } else if (javaObject instanceof long[]) { long[] arr = (long[]) javaObject; for (long x : arr) ret.add(x); @@ -410,83 +412,125 @@ public class PythonTypes { }; - public static final PythonType MEMORYVIEW = new PythonType("memoryview", BytePointer.class) { + public static final PythonType BYTES = new PythonType("bytes", byte[].class) { @Override - public BytePointer toJava(PythonObject pythonObject) { + public byte[] toJava(PythonObject pythonObject) { try (PythonGC gc = PythonGC.watch()) { - if (!(Python.isinstance(pythonObject, Python.memoryviewType()))) { - throw new PythonException("Expected memoryview. Received: " + pythonObject); + if (!(Python.isinstance(pythonObject, Python.bytesType()))) { + throw new PythonException("Expected bytes. Received: " + pythonObject); } PythonObject pySize = Python.len(pythonObject); - PythonObject ctypes = Python.importModule("ctypes"); - PythonObject charType = ctypes.attr("c_char"); - PythonObject charArrayType = new PythonObject(PyNumber_Multiply(charType.getNativePythonObject(), - pySize.getNativePythonObject())); - PythonObject fromBuffer = charArrayType.attr("from_buffer"); - if (pythonObject.attr("readonly").toBoolean()) { - pythonObject = Python.bytearray(pythonObject); + byte[] ret = new byte[pySize.toInt()]; + for (int i = 0; i < ret.length; i++) { + ret[i] = (byte)pythonObject.get(i).toInt(); } - PythonObject arr = fromBuffer.call(pythonObject); - PythonObject cast = ctypes.attr("cast"); - PythonObject voidPtrType = ctypes.attr("c_void_p"); - PythonObject voidPtr = cast.call(arr, voidPtrType); - long address = voidPtr.attr("value").toLong(); - long size = pySize.toLong(); - try { - Field addressField = Buffer.class.getDeclaredField("address"); - addressField.setAccessible(true); - Field capacityField = Buffer.class.getDeclaredField("capacity"); - capacityField.setAccessible(true); - ByteBuffer buff = ByteBuffer.allocateDirect(0).order(ByteOrder.nativeOrder()); - addressField.setLong(buff, address); - capacityField.setInt(buff, (int) size); - BytePointer ret = new BytePointer(buff); - ret.limit(size); - return ret; - - } catch (Exception e) { - throw new RuntimeException(e); - } - + return ret; } } @Override - public PythonObject toPython(BytePointer javaObject) { - long address = javaObject.address(); - long size = javaObject.limit(); - try (PythonGC gc = PythonGC.watch()) { - PythonObject ctypes = Python.importModule("ctypes"); - PythonObject charType = ctypes.attr("c_char"); - PythonObject pySize = new PythonObject(size); - PythonObject charArrayType = new PythonObject(PyNumber_Multiply(charType.getNativePythonObject(), - pySize.getNativePythonObject())); - PythonObject fromAddress = charArrayType.attr("from_address"); - PythonObject arr = fromAddress.call(new PythonObject(address)); - PythonObject memoryView = Python.memoryview(arr).attr("cast").call("b"); - PythonGC.keep(memoryView); - return memoryView; + public PythonObject toPython(byte[] javaObject) { + try(PythonGC gc = PythonGC.watch()){ + PythonObject ret = Python.bytes(LIST.toPython(LIST.adapt(javaObject))); + PythonGC.keep(ret); + return ret; } - } - @Override public boolean accepts(Object javaObject) { - return javaObject instanceof Pointer || javaObject instanceof DirectBuffer; + return javaObject instanceof byte[]; + } + @Override + public byte[] adapt(Object javaObject) { + if (javaObject instanceof byte[]){ + return (byte[])javaObject; + } + throw new PythonException("Cannot cast object of type " + javaObject.getClass().getName() + " to byte[]"); } - @Override - public BytePointer adapt(Object javaObject) { - if (javaObject instanceof BytePointer) { - return (BytePointer) javaObject; - } else if (javaObject instanceof Pointer) { - return new BytePointer((Pointer) javaObject); - } else if (javaObject instanceof DirectBuffer) { - return new BytePointer((ByteBuffer) javaObject); - } else { - throw new PythonException("Cannot cast object of type " + javaObject.getClass().getName() + " to BytePointer"); - } - } }; + /** + * Crashes on Adopt OpenJDK + * Use implementation in python4j-numpy instead for zero-copy byte buffers. + */ +// public static final PythonType MEMORYVIEW = new PythonType("memoryview", BytePointer.class) { +// @Override +// public BytePointer toJava(PythonObject pythonObject) { +// try (PythonGC gc = PythonGC.watch()) { +// if (!(Python.isinstance(pythonObject, Python.memoryviewType()))) { +// throw new PythonException("Expected memoryview. Received: " + pythonObject); +// } +// PythonObject pySize = Python.len(pythonObject); +// PythonObject ctypes = Python.importModule("ctypes"); +// PythonObject charType = ctypes.attr("c_char"); +// PythonObject charArrayType = new PythonObject(PyNumber_Multiply(charType.getNativePythonObject(), +// pySize.getNativePythonObject())); +// PythonObject fromBuffer = charArrayType.attr("from_buffer"); +// if (pythonObject.attr("readonly").toBoolean()) { +// pythonObject = Python.bytearray(pythonObject); +// } +// PythonObject arr = fromBuffer.call(pythonObject); +// PythonObject cast = ctypes.attr("cast"); +// PythonObject voidPtrType = ctypes.attr("c_void_p"); +// PythonObject voidPtr = cast.call(arr, voidPtrType); +// long address = voidPtr.attr("value").toLong(); +// long size = pySize.toLong(); +// try { +// Field addressField = Buffer.class.getDeclaredField("address"); +// addressField.setAccessible(true); +// Field capacityField = Buffer.class.getDeclaredField("capacity"); +// capacityField.setAccessible(true); +// ByteBuffer buff = ByteBuffer.allocateDirect(0).order(ByteOrder.nativeOrder()); +// addressField.setLong(buff, address); +// capacityField.setInt(buff, (int) size); +// BytePointer ret = new BytePointer(buff); +// ret.limit(size); +// return ret; +// +// } catch (Exception e) { +// throw new RuntimeException(e); +// } +// +// } +// } +// +// @Override +// public PythonObject toPython(BytePointer javaObject) { +// long address = javaObject.address(); +// long size = javaObject.limit(); +// try (PythonGC gc = PythonGC.watch()) { +// PythonObject ctypes = Python.importModule("ctypes"); +// PythonObject charType = ctypes.attr("c_char"); +// PythonObject pySize = new PythonObject(size); +// PythonObject charArrayType = new PythonObject(PyNumber_Multiply(charType.getNativePythonObject(), +// pySize.getNativePythonObject())); +// PythonObject fromAddress = charArrayType.attr("from_address"); +// PythonObject arr = fromAddress.call(new PythonObject(address)); +// PythonObject memoryView = Python.memoryview(arr).attr("cast").call("b"); +// PythonGC.keep(memoryView); +// return memoryView; +// } +// +// } +// +// @Override +// public boolean accepts(Object javaObject) { +// return javaObject instanceof Pointer || javaObject instanceof DirectBuffer; +// } +// +// @Override +// public BytePointer adapt(Object javaObject) { +// if (javaObject instanceof BytePointer) { +// return (BytePointer) javaObject; +// } else if (javaObject instanceof Pointer) { +// return new BytePointer((Pointer) javaObject); +// } else if (javaObject instanceof DirectBuffer) { +// return new BytePointer((ByteBuffer) javaObject); +// } else { +// throw new PythonException("Cannot cast object of type " + javaObject.getClass().getName() + " to BytePointer"); +// } +// } +// }; + } diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonVariable.java b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonVariable.java similarity index 98% rename from python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonVariable.java rename to python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonVariable.java index 3deb4d2e7..038904ec9 100644 --- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonVariable.java +++ b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonVariable.java @@ -14,7 +14,7 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.eclipse.python4j; +package org.nd4j.python4j; @lombok.Data public class PythonVariable { diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonVariables.java b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonVariables.java similarity index 98% rename from python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonVariables.java rename to python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonVariables.java index 32ae0b2f5..ed9ccff5d 100644 --- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonVariables.java +++ b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonVariables.java @@ -14,7 +14,7 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.eclipse.python4j; +package org.nd4j.python4j; import java.util.ArrayList; import java.util.Arrays; diff --git a/python4j/python4j-core/src/main/resources/org/nd4j/python4j/pythonexec/__init__.py b/python4j/python4j-core/src/main/resources/org/nd4j/python4j/pythonexec/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python4j/python4j-core/src/main/resources/org/eclipse/python4j/pythonexec/pythonexec.py b/python4j/python4j-core/src/main/resources/org/nd4j/python4j/pythonexec/pythonexec.py similarity index 100% rename from python4j/python4j-core/src/main/resources/org/eclipse/python4j/pythonexec/pythonexec.py rename to python4j/python4j-core/src/main/resources/org/nd4j/python4j/pythonexec/pythonexec.py diff --git a/python4j/python4j-core/src/test/java/PythonBasicExecutionTest.java b/python4j/python4j-core/src/test/java/PythonBasicExecutionTest.java index 9f5b43dba..c26b5c874 100644 --- a/python4j/python4j-core/src/test/java/PythonBasicExecutionTest.java +++ b/python4j/python4j-core/src/test/java/PythonBasicExecutionTest.java @@ -15,9 +15,12 @@ ******************************************************************************/ -import org.eclipse.python4j.*; import org.junit.Assert; import org.junit.Test; +import org.nd4j.python4j.PythonContextManager; +import org.nd4j.python4j.PythonExecutioner; +import org.nd4j.python4j.PythonTypes; +import org.nd4j.python4j.PythonVariable; import javax.annotation.concurrent.NotThreadSafe; import java.util.*; diff --git a/python4j/python4j-core/src/test/java/PythonBufferTest.java b/python4j/python4j-core/src/test/java/PythonBufferTest.java deleted file mode 100644 index c59b86c15..000000000 --- a/python4j/python4j-core/src/test/java/PythonBufferTest.java +++ /dev/null @@ -1,113 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2020 Konduit K.K. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - - -import org.bytedeco.javacpp.BytePointer; -import org.bytedeco.javacpp.Loader; -import org.eclipse.python4j.*; -import org.junit.Assert; -import org.junit.Test; -import sun.nio.ch.DirectBuffer; - -import javax.annotation.concurrent.NotThreadSafe; -import java.nio.Buffer; -import java.nio.ByteBuffer; -import java.util.*; - -@NotThreadSafe -public class PythonBufferTest { - - @Test - public void testBuffer() { - ByteBuffer buff = ByteBuffer.allocateDirect(3); - buff.put((byte) 97); - buff.put((byte) 98); - buff.put((byte) 99); - buff.rewind(); - - BytePointer bp = new BytePointer(buff); - - List inputs = new ArrayList<>(); - inputs.add(new PythonVariable<>("buff", PythonTypes.MEMORYVIEW, buff)); - - List outputs = new ArrayList<>(); - outputs.add(new PythonVariable<>("s1", PythonTypes.STR)); - outputs.add(new PythonVariable<>("s2", PythonTypes.STR)); - - String code = "s1 = ''.join(chr(c) for c in buff)\nbuff[2] += 2\ns2 = ''.join(chr(c) for c in buff)"; - - PythonExecutioner.exec(code, inputs, outputs); - Assert.assertEquals("abc", outputs.get(0).getValue()); - Assert.assertEquals("abe", outputs.get(1).getValue()); - Assert.assertEquals(101, buff.get(2)); - - } - @Test - public void testBuffer2() { - ByteBuffer buff = ByteBuffer.allocateDirect(3); - buff.put((byte) 97); - buff.put((byte) 98); - buff.put((byte) 99); - buff.rewind(); - - BytePointer bp = new BytePointer(buff); - - List inputs = new ArrayList<>(); - inputs.add(new PythonVariable<>("buff", PythonTypes.MEMORYVIEW, bp)); - - List outputs = new ArrayList<>(); - outputs.add(new PythonVariable<>("s1", PythonTypes.STR)); - outputs.add(new PythonVariable<>("s2", PythonTypes.STR)); - - String code = "s1 = ''.join(chr(c) for c in buff)\nbuff[2] += 2\ns2 = ''.join(chr(c) for c in buff)"; - - PythonExecutioner.exec(code, inputs, outputs); - Assert.assertEquals("abc", outputs.get(0).getValue()); - Assert.assertEquals("abe", outputs.get(1).getValue()); - Assert.assertEquals(101, buff.get(2)); - - } - - @Test - public void testBuffer3() { - ByteBuffer buff = ByteBuffer.allocateDirect(3); - buff.put((byte) 97); - buff.put((byte) 98); - buff.put((byte) 99); - buff.rewind(); - - BytePointer bp = new BytePointer(buff); - - List inputs = new ArrayList<>(); - inputs.add(new PythonVariable<>("buff", PythonTypes.MEMORYVIEW, bp)); - - List outputs = new ArrayList<>(); - outputs.add(new PythonVariable<>("s1", PythonTypes.STR)); - outputs.add(new PythonVariable<>("s2", PythonTypes.STR)); - outputs.add(new PythonVariable<>("buff2", PythonTypes.MEMORYVIEW)); - String code = "s1 = ''.join(chr(c) for c in buff)\nbuff[2] += 2\ns2 = ''.join(chr(c) for c in buff)\nbuff2=buff[1:]"; - PythonExecutioner.exec(code, inputs, outputs); - - Assert.assertEquals("abc", outputs.get(0).getValue()); - Assert.assertEquals("abe", outputs.get(1).getValue()); - Assert.assertEquals(101, buff.get(2)); - BytePointer outBuffer = (BytePointer) outputs.get(2).getValue(); - Assert.assertEquals(2, outBuffer.capacity()); - Assert.assertEquals((byte)98, outBuffer.get(0)); - Assert.assertEquals((byte)101, outBuffer.get(1)); - - } -} \ No newline at end of file diff --git a/python4j/python4j-core/src/test/java/PythonCollectionsTest.java b/python4j/python4j-core/src/test/java/PythonCollectionsTest.java index 7e63d9d28..ba4d8e14a 100644 --- a/python4j/python4j-core/src/test/java/PythonCollectionsTest.java +++ b/python4j/python4j-core/src/test/java/PythonCollectionsTest.java @@ -15,9 +15,9 @@ ******************************************************************************/ -import org.eclipse.python4j.PythonException; -import org.eclipse.python4j.PythonObject; -import org.eclipse.python4j.PythonTypes; +import org.nd4j.python4j.PythonException; +import org.nd4j.python4j.PythonObject; +import org.nd4j.python4j.PythonTypes; import org.junit.Assert; import org.junit.Test; diff --git a/python4j/python4j-core/src/test/java/PythonContextManagerTest.java b/python4j/python4j-core/src/test/java/PythonContextManagerTest.java index a4451764c..4961f94d8 100644 --- a/python4j/python4j-core/src/test/java/PythonContextManagerTest.java +++ b/python4j/python4j-core/src/test/java/PythonContextManagerTest.java @@ -16,9 +16,9 @@ ******************************************************************************/ -import org.eclipse.python4j.Python; -import org.eclipse.python4j.PythonContextManager; -import org.eclipse.python4j.PythonExecutioner; +import org.nd4j.python4j.Python; +import org.nd4j.python4j.PythonContextManager; +import org.nd4j.python4j.PythonExecutioner; import org.junit.Assert; import org.junit.Test; import javax.annotation.concurrent.NotThreadSafe; diff --git a/python4j/python4j-core/src/test/java/PythonGCTest.java b/python4j/python4j-core/src/test/java/PythonGCTest.java index 80b2e7f3c..11dd8e93a 100644 --- a/python4j/python4j-core/src/test/java/PythonGCTest.java +++ b/python4j/python4j-core/src/test/java/PythonGCTest.java @@ -14,9 +14,9 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -import org.eclipse.python4j.Python; -import org.eclipse.python4j.PythonGC; -import org.eclipse.python4j.PythonObject; +import org.nd4j.python4j.Python; +import org.nd4j.python4j.PythonGC; +import org.nd4j.python4j.PythonObject; import org.junit.Assert; import org.junit.Test; diff --git a/python4j/python4j-core/src/test/java/PythonJobTest.java b/python4j/python4j-core/src/test/java/PythonJobTest.java index b0f4233c9..4dad7f24f 100644 --- a/python4j/python4j-core/src/test/java/PythonJobTest.java +++ b/python4j/python4j-core/src/test/java/PythonJobTest.java @@ -14,10 +14,10 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -import org.eclipse.python4j.PythonContextManager; -import org.eclipse.python4j.PythonJob; -import org.eclipse.python4j.PythonTypes; -import org.eclipse.python4j.PythonVariable; +import org.nd4j.python4j.PythonContextManager; +import org.nd4j.python4j.PythonJob; +import org.nd4j.python4j.PythonTypes; +import org.nd4j.python4j.PythonVariable; import org.junit.Test; import java.util.ArrayList; diff --git a/python4j/python4j-core/src/test/java/PythonMultiThreadTest.java b/python4j/python4j-core/src/test/java/PythonMultiThreadTest.java index ec544b65f..b2f9089fa 100644 --- a/python4j/python4j-core/src/test/java/PythonMultiThreadTest.java +++ b/python4j/python4j-core/src/test/java/PythonMultiThreadTest.java @@ -14,10 +14,9 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -import org.eclipse.python4j.*; +import org.nd4j.python4j.*; import org.junit.Assert; import org.junit.Test; - import javax.annotation.concurrent.NotThreadSafe; import java.util.ArrayList; import java.util.Arrays; diff --git a/python4j/python4j-core/src/test/java/PythonPrimitiveTypesTest.java b/python4j/python4j-core/src/test/java/PythonPrimitiveTypesTest.java index ae10ed8dc..94423f7de 100644 --- a/python4j/python4j-core/src/test/java/PythonPrimitiveTypesTest.java +++ b/python4j/python4j-core/src/test/java/PythonPrimitiveTypesTest.java @@ -15,12 +15,13 @@ ******************************************************************************/ -import org.eclipse.python4j.PythonException; -import org.eclipse.python4j.PythonObject; -import org.eclipse.python4j.PythonTypes; +import org.nd4j.python4j.*; import org.junit.Assert; import org.junit.Test; +import java.util.ArrayList; +import java.util.List; + public class PythonPrimitiveTypesTest { @Test @@ -78,5 +79,18 @@ public class PythonPrimitiveTypesTest { Assert.assertEquals(b, b3); } + @Test + public void testBytes() { + byte[] bytes = new byte[]{97, 98, 99}; + List inputs = new ArrayList<>(); + inputs.add(new PythonVariable<>("buff", PythonTypes.BYTES, bytes)); + List outputs = new ArrayList<>(); + outputs.add(new PythonVariable<>("s1", PythonTypes.STR)); + outputs.add(new PythonVariable<>("buff2", PythonTypes.BYTES)); + String code = "s1 = ''.join(chr(c) for c in buff)\nbuff2=b'def'"; + PythonExecutioner.exec(code, inputs, outputs); + Assert.assertEquals("abc", outputs.get(0).getValue()); + Assert.assertArrayEquals(new byte[]{100, 101, 102}, (byte[])outputs.get(1).getValue()); + } } diff --git a/python4j/python4j-numpy/pom.xml b/python4j/python4j-numpy/pom.xml index bcce739ce..c631f67e3 100644 --- a/python4j/python4j-numpy/pom.xml +++ b/python4j/python4j-numpy/pom.xml @@ -4,7 +4,7 @@ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> python4j-parent - org.eclipse + org.nd4j 1.0.0-SNAPSHOT 4.0.0 @@ -29,7 +29,7 @@ test - org.eclipse + org.nd4j python4j-core 1.0.0-SNAPSHOT diff --git a/python4j/python4j-numpy/src/main/java/org/eclipse/python4j/NumpyArray.java b/python4j/python4j-numpy/src/main/java/org/nd4j/python4j/NumpyArray.java similarity index 97% rename from python4j/python4j-numpy/src/main/java/org/eclipse/python4j/NumpyArray.java rename to python4j/python4j-numpy/src/main/java/org/nd4j/python4j/NumpyArray.java index 66fb76d23..b21dabd7c 100644 --- a/python4j/python4j-numpy/src/main/java/org/eclipse/python4j/NumpyArray.java +++ b/python4j/python4j-numpy/src/main/java/org/nd4j/python4j/NumpyArray.java @@ -15,26 +15,22 @@ ******************************************************************************/ -package org.eclipse.python4j; +package org.nd4j.python4j; import lombok.extern.slf4j.Slf4j; -import org.apache.commons.lang3.ArrayUtils; import org.bytedeco.cpython.PyObject; import org.bytedeco.cpython.PyTypeObject; import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.SizeTPointer; import org.bytedeco.numpy.PyArrayObject; import org.bytedeco.numpy.global.numpy; -import org.nd4j.linalg.api.buffer.BaseDataBuffer; import org.nd4j.linalg.api.buffer.DataBuffer; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.concurrency.AffinityManager; import org.nd4j.linalg.api.memory.MemoryWorkspace; -import org.nd4j.linalg.api.memory.MemoryWorkspaceManager; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.shape.Shape; import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.nativeblas.NativeOps; import org.nd4j.nativeblas.NativeOpsHolder; import java.io.File; diff --git a/python4j/python4j-numpy/src/main/resources/META-INF/services/org.eclipse.python4j.PythonType b/python4j/python4j-numpy/src/main/resources/META-INF/services/org.eclipse.python4j.PythonType deleted file mode 100644 index ae4d4640b..000000000 --- a/python4j/python4j-numpy/src/main/resources/META-INF/services/org.eclipse.python4j.PythonType +++ /dev/null @@ -1 +0,0 @@ -org.eclipse.python4j.NumpyArray \ No newline at end of file diff --git a/python4j/python4j-numpy/src/main/resources/META-INF/services/org.nd4j.python4j.PythonType b/python4j/python4j-numpy/src/main/resources/META-INF/services/org.nd4j.python4j.PythonType new file mode 100644 index 000000000..b0d2f1256 --- /dev/null +++ b/python4j/python4j-numpy/src/main/resources/META-INF/services/org.nd4j.python4j.PythonType @@ -0,0 +1 @@ +org.nd4j.python4j.NumpyArray \ No newline at end of file diff --git a/python4j/python4j-numpy/src/test/java/PythonNumpyBasicTest.java b/python4j/python4j-numpy/src/test/java/PythonNumpyBasicTest.java index b7bd838b5..d76f759a6 100644 --- a/python4j/python4j-numpy/src/test/java/PythonNumpyBasicTest.java +++ b/python4j/python4j-numpy/src/test/java/PythonNumpyBasicTest.java @@ -15,13 +15,12 @@ ******************************************************************************/ -import org.eclipse.python4j.*; +import org.nd4j.python4j.*; import org.junit.Assert; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.concurrency.AffinityManager; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.nativeblas.OpaqueDataBuffer; diff --git a/python4j/python4j-numpy/src/test/java/PythonNumpyCollectionsTest.java b/python4j/python4j-numpy/src/test/java/PythonNumpyCollectionsTest.java index 99a050f63..64c417905 100644 --- a/python4j/python4j-numpy/src/test/java/PythonNumpyCollectionsTest.java +++ b/python4j/python4j-numpy/src/test/java/PythonNumpyCollectionsTest.java @@ -15,9 +15,9 @@ ******************************************************************************/ -import org.eclipse.python4j.PythonException; -import org.eclipse.python4j.PythonObject; -import org.eclipse.python4j.PythonTypes; +import org.nd4j.python4j.PythonException; +import org.nd4j.python4j.PythonObject; +import org.nd4j.python4j.PythonTypes; import org.junit.Assert; import org.junit.Test; import org.junit.runner.RunWith; diff --git a/python4j/python4j-numpy/src/test/java/PythonNumpyGCTest.java b/python4j/python4j-numpy/src/test/java/PythonNumpyGCTest.java index d1c5ba761..96dd7274c 100644 --- a/python4j/python4j-numpy/src/test/java/PythonNumpyGCTest.java +++ b/python4j/python4j-numpy/src/test/java/PythonNumpyGCTest.java @@ -14,9 +14,9 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -import org.eclipse.python4j.Python; -import org.eclipse.python4j.PythonGC; -import org.eclipse.python4j.PythonObject; +import org.nd4j.python4j.Python; +import org.nd4j.python4j.PythonGC; +import org.nd4j.python4j.PythonObject; import org.junit.Assert; import org.junit.Test; import org.nd4j.linalg.factory.Nd4j; diff --git a/python4j/python4j-numpy/src/test/java/PythonNumpyImportTest.java b/python4j/python4j-numpy/src/test/java/PythonNumpyImportTest.java index 580f8643b..941072e45 100644 --- a/python4j/python4j-numpy/src/test/java/PythonNumpyImportTest.java +++ b/python4j/python4j-numpy/src/test/java/PythonNumpyImportTest.java @@ -1,7 +1,7 @@ -import org.eclipse.python4j.NumpyArray; -import org.eclipse.python4j.Python; -import org.eclipse.python4j.PythonGC; -import org.eclipse.python4j.PythonObject; +import org.nd4j.python4j.NumpyArray; +import org.nd4j.python4j.Python; +import org.nd4j.python4j.PythonGC; +import org.nd4j.python4j.PythonObject; import org.junit.Assert; import org.junit.Test; import org.nd4j.linalg.api.buffer.DataType; diff --git a/python4j/python4j-numpy/src/test/java/PythonNumpyJobTest.java b/python4j/python4j-numpy/src/test/java/PythonNumpyJobTest.java index 399b87fb1..dc087d0f8 100644 --- a/python4j/python4j-numpy/src/test/java/PythonNumpyJobTest.java +++ b/python4j/python4j-numpy/src/test/java/PythonNumpyJobTest.java @@ -14,7 +14,6 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -import org.eclipse.python4j.*; import org.junit.Assert; import org.junit.Test; import org.junit.runner.RunWith; @@ -22,6 +21,7 @@ import org.junit.runners.Parameterized; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.python4j.*; import java.util.ArrayList; import java.util.List; diff --git a/python4j/python4j-numpy/src/test/java/PythonNumpyMultiThreadTest.java b/python4j/python4j-numpy/src/test/java/PythonNumpyMultiThreadTest.java index 52ccd1fd0..02eb99551 100644 --- a/python4j/python4j-numpy/src/test/java/PythonNumpyMultiThreadTest.java +++ b/python4j/python4j-numpy/src/test/java/PythonNumpyMultiThreadTest.java @@ -14,7 +14,7 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -import org.eclipse.python4j.*; +import org.nd4j.python4j.*; import org.junit.Assert; import org.junit.Test; import org.junit.runner.RunWith; diff --git a/python4j/python4j-numpy/src/test/java/PythonNumpyServiceLoaderTest.java b/python4j/python4j-numpy/src/test/java/PythonNumpyServiceLoaderTest.java index d3c649c8d..bd13a99d9 100644 --- a/python4j/python4j-numpy/src/test/java/PythonNumpyServiceLoaderTest.java +++ b/python4j/python4j-numpy/src/test/java/PythonNumpyServiceLoaderTest.java @@ -15,18 +15,14 @@ ******************************************************************************/ -import org.eclipse.python4j.*; import org.junit.Assert; import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; -import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.python4j.NumpyArray; +import org.nd4j.python4j.PythonTypes; import javax.annotation.concurrent.NotThreadSafe; -import java.util.ArrayList; -import java.util.List; @NotThreadSafe public class PythonNumpyServiceLoaderTest { @@ -36,6 +32,4 @@ public class PythonNumpyServiceLoaderTest { Assert.assertEquals(NumpyArray.INSTANCE, PythonTypes.get("numpy.ndarray")); Assert.assertEquals(NumpyArray.INSTANCE, PythonTypes.getPythonTypeForJavaObject(Nd4j.zeros(1))); } - - } From 654afc810dfeadb4d38ccbf29294672e32506d00 Mon Sep 17 00:00:00 2001 From: raver119 Date: Thu, 25 Jun 2020 18:17:19 +0300 Subject: [PATCH 09/11] two more graphs for C++ tests Signed-off-by: raver119 --- libnd4j/tests_cpu/resources/simpleif_0_alt.fb | Bin 0 -> 7648 bytes libnd4j/tests_cpu/resources/simplewhile_1.fb | Bin 0 -> 12504 bytes 2 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 libnd4j/tests_cpu/resources/simpleif_0_alt.fb create mode 100644 libnd4j/tests_cpu/resources/simplewhile_1.fb diff --git a/libnd4j/tests_cpu/resources/simpleif_0_alt.fb b/libnd4j/tests_cpu/resources/simpleif_0_alt.fb new file mode 100644 index 0000000000000000000000000000000000000000..4a7e751c3f7eb16ddc5de7d693e363708a0dfd39 GIT binary patch literal 7648 zcmchce`u9e7{||a&RMILZMAC15Q$-eSKA*Y78hKwKTJkRNQ~Td_nMdPw#)8atq58% zV#Ek335oq7Au&W@NPk%P#~}X@kr>t=68<3}Vr3=$w4ZOE?|I*IcQ)Tmb8kKHz0dji zoaa2xbDs0Q?`<}YR+?25t4*z`GxesyG?|su_k(V50W^aYXa&aH&X`(Tsxix{iI!Fc z5!os~OF#pt2FpRduQ*uVxxp)^SkIWVcN#wZjbifeOyK9`KUfcUYfL-y_hlHI{x((pF04oJrqAlu328skVZ z;s_VWXN62pZcC<6FyoO-$x<9huJ#=#ljc~`x|yMTNu3{G2!cD{XTG;~e^KREz;W;) zI0*KGF)$2d`vQ>Nx5E8B?V!FHIMSYEd!aWs@LUeAHm!dP=;`mvu74uxpe~xPJ?x$R zD_J@tvUwAgIyFFqsg0X1xYnBfI(a{=`6iT^lpo5-lkIL+9>>(|9?10OwgJj-UkrlV z;P1he8#a9#o-37y!4y!;y0f*9@+cSvWg!1{!lh3$s0ZEPp7w`7I@EaNh)gZN)`HUD zuJS;pN7E|)t)p!HyE{`T<#zX}&(T7CZj9qv=T0CWp8(QX5}RssiXXJ5^+S1J(09Xp?F^qql3NdC?~@?T{wem7rmqdG11Zn~8i4jHJ2;ZO zIGUU1Mk1M#r9Dfo_8lja=2+6Y&i-485#<)Gzk{;Qi^`{d3iF=yJP59UOW-WfevgAm zFagGa{2+h%dmN>%yW5d$wT#1|d!+PYzTC4XwBH))(y0;nXJh?ScF?8`wCg^J>7zBI z>-D>FK9@c^_sZ=m7ua}r41&j6PLF8>l&dti3n=gM# zKwfbmm(BNN%F4ehiLqMveQ=q2j2(JCL0LA@Y z@NeQwk5n`NGG=DTbl#nq9{#<74LWMebx4JC^M;e8$n)@`R(%r^_ z9~cG0pbT`6cEV+^W>62hf&YC#awRYB`+)L+jT0%gGoPhep@sS?gS>*)(Ki_J52JR) zQ5UENL*(oex*O7S*?^X47KUaMP}{mTof- z^Cci{$Lgn@8I~F~U(Yp@U;^w1V_*cRy$EDv`K<#kzqNn{&9!SR}g3~ z@>fUXFP(sY;jbFaF97-K0N4vgz$?J{YCC25OZvA0`KlJQ1KCsegYIWe&|pt)XKAoE zY?nXuzWS6JB64-jL*%P!)?Tb2@rayU^_{13S`Cn7?xC!2L6cE^n!9NqW#3=T)TO7+ zu;%zAQR!U7b91T=-$hhw)FQF3pt(D*1waHt5|tH zmn*jCd+iw}e1p&@HN;J|J&)mjH_I-!=MaU`nvapWJ%|^ z`$G1AIx~>Z>?-8e59Bgg>W&ubb7Oq2J~pXPaI zwhyJ>d2`)fH<6n(cfRjBtC8*O0rJH&dLJCC0B{r_EWo?CN_ zHDzb5xzKyw8E>vza~1p6T*>lmIqMzp7UXN68l$+P%5z_xjYFL^B-xkU6 zW9dEC*4+7?_g}Yjs%FLCr_9lsm00nf{f@?r|C{l8u2FouXO1soba?>VD5l*VC$01! zUY(6}#s6}vr}MY7pZutt6PG_i{Z~|F>E8pSH%kXayZ(wQ*-Ltvm~No;rH}s840pQR+7&6Y?$I>kB)dW_ zVJS|JKIF-U+LOz*8hy&{dD+q?%8DUp3-yUUpkKR_uiXEj@`2lf?CtGgX^MIqaB?)) S*~@X7pc=-^+*331Q|2ECAxo10 literal 0 HcmV?d00001 diff --git a/libnd4j/tests_cpu/resources/simplewhile_1.fb b/libnd4j/tests_cpu/resources/simplewhile_1.fb new file mode 100644 index 0000000000000000000000000000000000000000..c4fa26e2ab10bf0fc1ba5c4eef5ebb57e3c281d0 GIT binary patch literal 12504 zcmb`Oe`sB28OPtWS(|oO7rS(eU66|wtXXWcnLlc&c-^v*l{m^;q?EEIORnb5n;#^% zOIE^mg&;x7D4IWlgb~IVWpt1~!ib|t(Ed?Mk*<9 zoSSsbdElG(yzl#*=Y8Mj`My8yIVVlV$rke&E^AGlNtt@nU>eO9@(bVy*j{5y6KRXC zfia&%tIoUHU^bGItgk8gqTD*r0BXTTkQ*2}Ryg`l=(^QxrF`$_jX8zUrq3C3Ep5zM zIx<9Kmp^SxV|5v$bb~M*lf%!$b&Xuz{M|62F5@J z)PkiQ)FJQux>O?Bx{W_$U=ExDGvFkc04g5>y+H2jAZ-Qeb3ND(zH<5Q3-|uuJNGag z!_YiW=X$dPgjp4*tbk*=@FlLy)rC{pcmOC>oF$6c5g^<7rBW$F+6His z*T*!eVwwWdaUvUWayNyT7veY88DH;cYajaM#{U5z`=CdIL$cDu)NBdT0?sZ_e3 zbkDCmzmGy0&(D`hFMvfb4^D$wFbz%s`MryDH%J4uueiwfPBbR@UfdqYjunpP3fbZA zLT+#%@W1#w+Ly~`cRh*56HWn^2XZ?)fwsL;+v1SYl>Q*l+vAgqPwJiIA(`@J2Xd~B(r(%5;EC9v&RWJ=sf(f8_%Wu0u z8YtdFU_WpoH?J0NmZ7+KI}>KL`m&;bp*F=?rPbnm3vOqqZv!nrzDxs+Aq6xaoXAd` zMCW106>C$rl9-3n*h@0b@BF+}(!6Q@y zGw&uFw?}?iDv3w?NZDS^53gB&YmUjk8rK|{1ykSzD1tGdc*x(Kq~ekD_aUzDZ{Pi+ z*SEg$hVyqFWh2L)DGYb_6u2z=UGv+np9hXTqY7WDmAiVE%D%XKn_z3PJwBl7)$`gU`46 z6(awOD=Ybb{%Z36ulsGX^I#wQl-9_nNOkW0o?UlFH8+aj444Os!3>xJlb{F`ANgPY zZUt$e`?|nk;Ka>^S_-mbIo(#C2Pv-YJa~+nisOT~##gJ8jj*TGnQ+3`Qxd8Iiv0nUL%Fb`a8 zPjao;jswN@AgTJH@yJFWxZU%)^lj{(@K5iRv0bg6+n0%AE59l|cvqhXET_TgdnZ~gY%XgVp?@6RTBy&X2S zcq3@z?6%ir{LZ$U4)d_R-m3XJ1FnF};2e-&=D@2!bti#hnJ4W8yMcV8_O*_Nq_o=lIq`?&{lDbfyi@?eTtE2XpFxUyh8~K^lmwDWH3t_-0(ToJQBzCEuoO-O2ou;fna*NH+ds3D_GN&i3Ydx(iwQ`W9=1 z;`9?z&CQSAx4aU^#GR{P39QDeu4-v*Lehq`>5N%3V4dr9);Ci|rP()?F1u3L3I4i>;1m<5XI2~Y%D z6Y@ZD?IhK>(x3qx0$08~_oc>j52&gfTg5S-9T{-=6KLXUol|LsR&(^=J&Hw6``%N~$9wk4#d9zeXks40m9sZeJGi9|iyD{zV%)UYQ z7eKYRsje3&ZtbMn6BV;Ma1h9!92>m2u&o^H$sQd!)^E$zXU$K49anK#qUEwf6E8|Z zvG5El^+omCt5tWI%sCYxYNL;9%}w!6`58DlMcM{hfVieLME%xyr4tsJQe9N$*#{Uw zO+6L5M|hD=>Wt(!zby58UoP)K*NI}*o9oXGjPUB>I8h*e8Op-_Op}l6r)!Vu3emjF zjXc&rRCxX=&;8_Qu-i_$7Q1RhG0KfRI@mvy&uaeON}#{u&9}4Jq4r$AH?Q}w@5rx@ zkgC1of4BRDVqFC1z#^Cjr@#!D1VvB)grmGqq{+*7U7!~@k?+J&fj4*AVN#W;U7ZOm zKY8x9>oLEULS&^n)ysz7+eJm?inY=){Xx`LBiE-%-8$2LYwhumo?oZ2xtBJ?`5i#~ zUy1uvSBU!Q)|qSAI!L|^xOtT3x)Fruk%p+UIFX+9ul*|>#X0;;sFRpyD2ukESccmn z>-p4osj_8}zs1#okwW)C!I=^E7ft^GLcRBN^73DQ|Y zY2bRu&y#rKxFDZeE_l53I8u!_fv(3J-%|MbS7muwA#|a3#d-0iiA=o;ko8~=907{S znGciaOFePOU^^+kbR}Rf%9+>LzqC$$@Nc_^i!=Lu&UAddbJJG&xd;})0=S*^qzS!x z&;`QJImHRti+i4_HRBuI!@2Hf__Sd-+ucjv$x3xjC!C57nk!1xXNkDtcyalo3xb3(9oduN0y>-zFKn)du%?VPZ| zub=SeYAyNbIU)a_apv@K_eri2S0rg`gQN24Sq=>~q@@+tN6zeq5TnpRM z5LFf@(p&AEP@tW5^rG#=_1Sl4TvOiJEPPhbQ_szeSh3K2(7NE>fqcw$VN^E8!6JyR z3q{JtKpu#XoutlJ8p9#b3!G>SZoi10TZ8?g^2}PTTsCE|8V}Wu>QDNPQ+!*wK27S@ zhg<(ko`)UyAt@eCVM{)Y@}rS5@k?WpXo%u^HIeRWTx{TtSoj-Fy(3c$Ga%Z&%ERrG z_4=SRD$#T7VfI1w_Xg`IY*;5d+C#*j=cv=SY2uvw_HYaHI;!_FdO9nu^!?kunEQUg zee#{&fmwf4d|PK;h`DDodx*XR(HgBU8D5Fe*@=$6RnY!y?b>fneioy97+v){(6!%3 zd=#VmU38tVt=-@qezUh8>#sep zzY?Rn6J4G6;?8fEV)VsBopXY|xO2u9)*MAsD}Xq&(R*ha9sA4NFO8sYcF!o-wrcX! z4FR14DDs}ih!a{bgSZ7ezhGlgBhE=j_-Zj&U&$Ji_WqXFxgMJ}&D^WKCg=-upZEF= zu+xo1PcS!fQV?f}ZcUND6JR@t_Ajtw{oUt{X;Y2vRlnqCUuxt7jZ0$p`5^B~T)peB zbl;`VvK6#tf6wU+bz`L0S{v|xjrPng;O4fz<+tP8=*=DJg#DL6S8bFfZLi|`Gr$ Date: Thu, 25 Jun 2020 22:23:47 -0400 Subject: [PATCH 10/11] RL4J: Use directly NeuralNet instances in DoubleDQN and StandardDQN (#499) Signed-off-by: Alexandre Boulanger --- .../agent/update/DQNNeuralNetUpdateRule.java | 14 ++-- .../sync/qlearning/TargetQNetworkSource.java | 28 -------- .../TDTargetAlgorithm/BaseDQNAlgorithm.java | 21 +++--- .../BaseTDTargetAlgorithm.java | 19 +++--- .../discrete/TDTargetAlgorithm/DoubleDQN.java | 10 +-- .../TDTargetAlgorithm/StandardDQN.java | 10 +-- .../IOutputNeuralNet.java} | 66 +++++++++++-------- .../deeplearning4j/rl4j/network/dqn/IDQN.java | 6 +- .../TDTargetAlgorithm/DoubleDQNTest.java | 40 +++++++---- .../TDTargetAlgorithm/StandardDQNTest.java | 45 ++++++++----- .../support/MockTargetQNetworkSource.java | 26 -------- 11 files changed, 126 insertions(+), 159 deletions(-) delete mode 100644 rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/TargetQNetworkSource.java rename rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/{learning/sync/qlearning/QNetworkSource.java => network/IOutputNeuralNet.java} (51%) delete mode 100644 rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockTargetQNetworkSource.java diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/update/DQNNeuralNetUpdateRule.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/update/DQNNeuralNetUpdateRule.java index 46123d645..98873b827 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/update/DQNNeuralNetUpdateRule.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/agent/update/DQNNeuralNetUpdateRule.java @@ -17,7 +17,6 @@ package org.deeplearning4j.rl4j.agent.update; import lombok.Getter; import org.deeplearning4j.rl4j.learning.sync.Transition; -import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource; import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN; import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.ITDTargetAlgorithm; import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.StandardDQN; @@ -28,13 +27,10 @@ import java.util.List; // Temporary class that will be replaced with a more generic class that delegates gradient computation // and network update to sub components. -public class DQNNeuralNetUpdateRule implements IUpdateRule>, TargetQNetworkSource { +public class DQNNeuralNetUpdateRule implements IUpdateRule> { - @Getter private final IDQN qNetwork; - - @Getter - private IDQN targetQNetwork; + private final IDQN targetQNetwork; private final int targetUpdateFrequency; private final ITDTargetAlgorithm tdTargetAlgorithm; @@ -47,8 +43,8 @@ public class DQNNeuralNetUpdateRule implements IUpdateRule>, this.targetQNetwork = qNetwork.clone(); this.targetUpdateFrequency = targetUpdateFrequency; tdTargetAlgorithm = isDoubleDQN - ? new DoubleDQN(this, gamma, errorClamp) - : new StandardDQN(this, gamma, errorClamp); + ? new DoubleDQN(qNetwork, targetQNetwork, gamma, errorClamp) + : new StandardDQN(qNetwork, targetQNetwork, gamma, errorClamp); } @Override @@ -56,7 +52,7 @@ public class DQNNeuralNetUpdateRule implements IUpdateRule>, DataSet targets = tdTargetAlgorithm.computeTDTargets(trainingBatch); qNetwork.fit(targets.getFeatures(), targets.getLabels()); if(++updateCount % targetUpdateFrequency == 0) { - targetQNetwork = qNetwork.clone(); + targetQNetwork.copy(qNetwork); } } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/TargetQNetworkSource.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/TargetQNetworkSource.java deleted file mode 100644 index 34fd9c06e..000000000 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/TargetQNetworkSource.java +++ /dev/null @@ -1,28 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.rl4j.learning.sync.qlearning; - -import org.deeplearning4j.rl4j.network.dqn.IDQN; - -/** - * An interface that is an extension of {@link QNetworkSource} for all implementations capable of supplying a target Q-Network - * - * @author Alexandre Boulanger - */ -public interface TargetQNetworkSource extends QNetworkSource { - IDQN getTargetQNetwork(); -} diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/BaseDQNAlgorithm.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/BaseDQNAlgorithm.java index 3f27f954c..6cae384d5 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/BaseDQNAlgorithm.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/BaseDQNAlgorithm.java @@ -16,8 +16,7 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm; -import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource; -import org.deeplearning4j.rl4j.network.dqn.IDQN; +import org.deeplearning4j.rl4j.network.IOutputNeuralNet; import org.nd4j.linalg.api.ndarray.INDArray; /** @@ -28,7 +27,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; */ public abstract class BaseDQNAlgorithm extends BaseTDTargetAlgorithm { - private final TargetQNetworkSource qTargetNetworkSource; + private final IOutputNeuralNet targetQNetwork; /** * In litterature, this corresponds to Q{net}(s(t+1), a) @@ -40,23 +39,21 @@ public abstract class BaseDQNAlgorithm extends BaseTDTargetAlgorithm { */ protected INDArray targetQNetworkNextObservation; - protected BaseDQNAlgorithm(TargetQNetworkSource qTargetNetworkSource, double gamma) { - super(qTargetNetworkSource, gamma); - this.qTargetNetworkSource = qTargetNetworkSource; + protected BaseDQNAlgorithm(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma) { + super(qNetwork, gamma); + this.targetQNetwork = targetQNetwork; } - protected BaseDQNAlgorithm(TargetQNetworkSource qTargetNetworkSource, double gamma, double errorClamp) { - super(qTargetNetworkSource, gamma, errorClamp); - this.qTargetNetworkSource = qTargetNetworkSource; + protected BaseDQNAlgorithm(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma, double errorClamp) { + super(qNetwork, gamma, errorClamp); + this.targetQNetwork = targetQNetwork; } @Override protected void initComputation(INDArray observations, INDArray nextObservations) { super.initComputation(observations, nextObservations); - qNetworkNextObservation = qNetworkSource.getQNetwork().output(nextObservations); - - IDQN targetQNetwork = qTargetNetworkSource.getTargetQNetwork(); + qNetworkNextObservation = qNetwork.output(nextObservations); targetQNetworkNextObservation = targetQNetwork.output(nextObservations); } } diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/BaseTDTargetAlgorithm.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/BaseTDTargetAlgorithm.java index ca4beb47e..e0ede18d7 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/BaseTDTargetAlgorithm.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/BaseTDTargetAlgorithm.java @@ -17,7 +17,7 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm; import org.deeplearning4j.rl4j.learning.sync.Transition; -import org.deeplearning4j.rl4j.learning.sync.qlearning.QNetworkSource; +import org.deeplearning4j.rl4j.network.IOutputNeuralNet; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.DataSet; @@ -30,7 +30,7 @@ import java.util.List; */ public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm { - protected final QNetworkSource qNetworkSource; + protected final IOutputNeuralNet qNetwork; protected final double gamma; private final double errorClamp; @@ -38,12 +38,12 @@ public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithmerrorClamp away from the previous value. Double.NaN will disable the clamping. */ - protected BaseTDTargetAlgorithm(QNetworkSource qNetworkSource, double gamma, double errorClamp) { - this.qNetworkSource = qNetworkSource; + protected BaseTDTargetAlgorithm(IOutputNeuralNet qNetwork, double gamma, double errorClamp) { + this.qNetwork = qNetwork; this.gamma = gamma; this.errorClamp = errorClamp; @@ -52,12 +52,12 @@ public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm transition = transitions.get(i); double yTarget = computeTarget(i, transition.getReward(), transition.isTerminal()); diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQN.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQN.java index 3203af1b8..caeb85fb6 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQN.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQN.java @@ -16,7 +16,7 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm; -import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource; +import org.deeplearning4j.rl4j.network.IOutputNeuralNet; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -32,12 +32,12 @@ public class DoubleDQN extends BaseDQNAlgorithm { // In litterature, this corresponds to: max_{a}Q(s_{t+1}, a) private INDArray maxActionsFromQNetworkNextObservation; - public DoubleDQN(TargetQNetworkSource qTargetNetworkSource, double gamma) { - super(qTargetNetworkSource, gamma); + public DoubleDQN(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma) { + super(qNetwork, targetQNetwork, gamma); } - public DoubleDQN(TargetQNetworkSource qTargetNetworkSource, double gamma, double errorClamp) { - super(qTargetNetworkSource, gamma, errorClamp); + public DoubleDQN(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma, double errorClamp) { + super(qNetwork, targetQNetwork, gamma, errorClamp); } @Override diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQN.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQN.java index 8c03c8de9..6cd047c74 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQN.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQN.java @@ -16,7 +16,7 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm; -import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource; +import org.deeplearning4j.rl4j.network.IOutputNeuralNet; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -32,12 +32,12 @@ public class StandardDQN extends BaseDQNAlgorithm { // In litterature, this corresponds to: max_{a}Q_{tar}(s_{t+1}, a) private INDArray maxActionsFromQTargetNextObservation; - public StandardDQN(TargetQNetworkSource qTargetNetworkSource, double gamma) { - super(qTargetNetworkSource, gamma); + public StandardDQN(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma) { + super(qNetwork, targetQNetwork, gamma); } - public StandardDQN(TargetQNetworkSource qTargetNetworkSource, double gamma, double errorClamp) { - super(qTargetNetworkSource, gamma, errorClamp); + public StandardDQN(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma, double errorClamp) { + super(qNetwork, targetQNetwork, gamma, errorClamp); } @Override diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QNetworkSource.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/IOutputNeuralNet.java similarity index 51% rename from rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QNetworkSource.java rename to rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/IOutputNeuralNet.java index e22d368e4..58e219ea0 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/learning/sync/qlearning/QNetworkSource.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/IOutputNeuralNet.java @@ -1,28 +1,38 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.rl4j.learning.sync.qlearning; - -import org.deeplearning4j.rl4j.network.dqn.IDQN; - -/** - * An interface for all implementations capable of supplying a Q-Network - * - * @author Alexandre Boulanger - */ -public interface QNetworkSource { - IDQN getQNetwork(); -} +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.deeplearning4j.rl4j.network; + +import org.deeplearning4j.rl4j.observation.Observation; +import org.nd4j.linalg.api.ndarray.INDArray; + +/** + * An interface defining the output aspect of a {@link NeuralNet}. + */ +public interface IOutputNeuralNet { + /** + * Compute the output for the supplied observation. + * @param observation An {@link Observation} + * @return The ouptut of the network + */ + INDArray output(Observation observation); + + /** + * Compute the output for the supplied batch. + * @param batch + * @return The ouptut of the network + */ + INDArray output(INDArray batch); +} \ No newline at end of file diff --git a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/IDQN.java b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/IDQN.java index af295d202..daed646c5 100644 --- a/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/IDQN.java +++ b/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/IDQN.java @@ -17,6 +17,7 @@ package org.deeplearning4j.rl4j.network.dqn; import org.deeplearning4j.nn.gradient.Gradient; +import org.deeplearning4j.rl4j.network.IOutputNeuralNet; import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.observation.Observation; import org.nd4j.linalg.api.ndarray.INDArray; @@ -27,7 +28,7 @@ import org.nd4j.linalg.api.ndarray.INDArray; * This neural net quantify the value of each action given a state * */ -public interface IDQN extends NeuralNet { +public interface IDQN extends NeuralNet, IOutputNeuralNet { boolean isRecurrent(); @@ -37,9 +38,6 @@ public interface IDQN extends NeuralNet { void fit(INDArray input, INDArray[] labels); - INDArray output(INDArray batch); - INDArray output(Observation observation); - INDArray[] outputAll(INDArray batch); NN clone(); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQNTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQNTest.java index 798bddf0d..0f03a5370 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQNTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/DoubleDQNTest.java @@ -1,10 +1,13 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm; import org.deeplearning4j.rl4j.learning.sync.Transition; -import org.deeplearning4j.rl4j.learning.sync.support.MockDQN; -import org.deeplearning4j.rl4j.learning.sync.support.MockTargetQNetworkSource; +import org.deeplearning4j.rl4j.network.IOutputNeuralNet; import org.deeplearning4j.rl4j.observation.Observation; +import org.junit.Before; import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.factory.Nd4j; @@ -13,16 +16,29 @@ import java.util.ArrayList; import java.util.List; import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; +@RunWith(MockitoJUnitRunner.class) public class DoubleDQNTest { + @Mock + IOutputNeuralNet qNetworkMock; + + @Mock + IOutputNeuralNet targetQNetworkMock; + + + @Before + public void setup() { + when(qNetworkMock.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]); + } + @Test public void when_isTerminal_expect_rewardValueAtIdx0() { // Assemble - MockDQN qNetwork = new MockDQN(); - MockDQN targetQNetwork = new MockDQN(); - MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork); + when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]); List> transitions = new ArrayList>() { { @@ -31,7 +47,7 @@ public class DoubleDQNTest { } }; - DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5); + DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5); // Act DataSet result = sut.computeTDTargets(transitions); @@ -46,9 +62,7 @@ public class DoubleDQNTest { public void when_isNotTerminal_expect_rewardPlusEstimatedQValue() { // Assemble - MockDQN qNetwork = new MockDQN(); - MockDQN targetQNetwork = new MockDQN(-1.0); - MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork); + when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> ((INDArray)i.getArguments()[0]).mul(-1.0)); List> transitions = new ArrayList>() { { @@ -57,7 +71,7 @@ public class DoubleDQNTest { } }; - DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5); + DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5); // Act DataSet result = sut.computeTDTargets(transitions); @@ -72,9 +86,7 @@ public class DoubleDQNTest { public void when_batchHasMoreThanOne_expect_everySampleEvaluated() { // Assemble - MockDQN qNetwork = new MockDQN(); - MockDQN targetQNetwork = new MockDQN(-1.0); - MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork); + when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> ((INDArray)i.getArguments()[0]).mul(-1.0)); List> transitions = new ArrayList>() { { @@ -87,7 +99,7 @@ public class DoubleDQNTest { } }; - DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5); + DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5); // Act DataSet result = sut.computeTDTargets(transitions); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQNTest.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQNTest.java index 3e3701669..6aead9e76 100644 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQNTest.java +++ b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/qlearning/discrete/TDTargetAlgorithm/StandardDQNTest.java @@ -1,10 +1,13 @@ package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm; import org.deeplearning4j.rl4j.learning.sync.Transition; -import org.deeplearning4j.rl4j.learning.sync.support.MockDQN; -import org.deeplearning4j.rl4j.learning.sync.support.MockTargetQNetworkSource; +import org.deeplearning4j.rl4j.network.IOutputNeuralNet; import org.deeplearning4j.rl4j.observation.Observation; +import org.junit.Before; import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.factory.Nd4j; @@ -12,17 +15,31 @@ import org.nd4j.linalg.factory.Nd4j; import java.util.ArrayList; import java.util.List; -import static org.junit.Assert.*; +import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; +@RunWith(MockitoJUnitRunner.class) public class StandardDQNTest { + + @Mock + IOutputNeuralNet qNetworkMock; + + @Mock + IOutputNeuralNet targetQNetworkMock; + + + @Before + public void setup() { + when(qNetworkMock.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]); + when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]); + } + + @Test public void when_isTerminal_expect_rewardValueAtIdx0() { // Assemble - MockDQN qNetwork = new MockDQN(); - MockDQN targetQNetwork = new MockDQN(); - MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork); - List> transitions = new ArrayList>() { { add(buildTransition(buildObservation(new double[]{1.1, 2.2}), @@ -30,7 +47,7 @@ public class StandardDQNTest { } }; - StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5); + StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5); // Act DataSet result = sut.computeTDTargets(transitions); @@ -45,10 +62,6 @@ public class StandardDQNTest { public void when_isNotTerminal_expect_rewardPlusEstimatedQValue() { // Assemble - MockDQN qNetwork = new MockDQN(); - MockDQN targetQNetwork = new MockDQN(); - MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork); - List> transitions = new ArrayList>() { { add(buildTransition(buildObservation(new double[]{1.1, 2.2}), @@ -56,7 +69,7 @@ public class StandardDQNTest { } }; - StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5); + StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5); // Act DataSet result = sut.computeTDTargets(transitions); @@ -71,10 +84,6 @@ public class StandardDQNTest { public void when_batchHasMoreThanOne_expect_everySampleEvaluated() { // Assemble - MockDQN qNetwork = new MockDQN(); - MockDQN targetQNetwork = new MockDQN(); - MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork); - List> transitions = new ArrayList>() { { add(buildTransition(buildObservation(new double[]{1.1, 2.2}), @@ -86,7 +95,7 @@ public class StandardDQNTest { } }; - StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5); + StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5); // Act DataSet result = sut.computeTDTargets(transitions); diff --git a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockTargetQNetworkSource.java b/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockTargetQNetworkSource.java deleted file mode 100644 index ce756aa88..000000000 --- a/rl4j/rl4j-core/src/test/java/org/deeplearning4j/rl4j/learning/sync/support/MockTargetQNetworkSource.java +++ /dev/null @@ -1,26 +0,0 @@ -package org.deeplearning4j.rl4j.learning.sync.support; - -import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource; -import org.deeplearning4j.rl4j.network.dqn.IDQN; - -public class MockTargetQNetworkSource implements TargetQNetworkSource { - - - private final IDQN qNetwork; - private final IDQN targetQNetwork; - - public MockTargetQNetworkSource(IDQN qNetwork, IDQN targetQNetwork) { - this.qNetwork = qNetwork; - this.targetQNetwork = targetQNetwork; - } - - @Override - public IDQN getTargetQNetwork() { - return targetQNetwork; - } - - @Override - public IDQN getQNetwork() { - return qNetwork; - } -} From 69ebc96068409215a020cb8360c61b5ccf451dbe Mon Sep 17 00:00:00 2001 From: Abdelrauf Date: Fri, 26 Jun 2020 11:03:46 +0400 Subject: [PATCH 11/11] Pi build and initial ArmCompute library support (#494) * - raspberry Pi build and ArmCompute library support - initial ArmCompute platform implementations (Maxpool2d AvgPool2d for float32) Signed-off-by: AbdelRauf * - Build script for pi - small changes Signed-off-by: AbdelRauf --- libnd4j/CMakeLists.txt | 17 ++ libnd4j/blas/CMakeLists.txt | 14 +- libnd4j/cmake/FindARMCOMPUTE.cmake | 74 +++++ libnd4j/include/config.h.in | 2 + .../ops/declarable/helpers/cpu/lup.cpp | 4 +- .../platform/armcompute/armcomputeUtils.cpp | 278 ++++++++++++++++++ .../platform/armcompute/armcomputeUtils.h | 133 +++++++++ .../platform/armcompute/avgpooling2d.cpp | 106 +++++++ .../platform/armcompute/maxpooling2d.cpp | 106 +++++++ libnd4j/pi_build.sh | 185 ++++++++++++ libnd4j/tests_cpu/layers_tests/CMakeLists.txt | 11 +- .../layers_tests/ConvolutionTests2.cpp | 20 +- .../layers_tests/DeclarableOpsTests19.cpp | 5 +- .../layers_tests/SessionLocalTests.cpp | 93 ------ .../tests_cpu/libnd4j_tests/CMakeLists.txt | 24 +- 15 files changed, 962 insertions(+), 110 deletions(-) create mode 100644 libnd4j/cmake/FindARMCOMPUTE.cmake create mode 100644 libnd4j/include/ops/declarable/platform/armcompute/armcomputeUtils.cpp create mode 100644 libnd4j/include/ops/declarable/platform/armcompute/armcomputeUtils.h create mode 100644 libnd4j/include/ops/declarable/platform/armcompute/avgpooling2d.cpp create mode 100644 libnd4j/include/ops/declarable/platform/armcompute/maxpooling2d.cpp create mode 100755 libnd4j/pi_build.sh delete mode 100644 libnd4j/tests_cpu/layers_tests/SessionLocalTests.cpp diff --git a/libnd4j/CMakeLists.txt b/libnd4j/CMakeLists.txt index 106401b31..0631763c2 100755 --- a/libnd4j/CMakeLists.txt +++ b/libnd4j/CMakeLists.txt @@ -131,6 +131,23 @@ if(NOT SD_CUDA) endif() endif() +#arm-compute entry +if(${HELPERS_armcompute}) + find_package(ARMCOMPUTE REQUIRED) + + if(ARMCOMPUTE_FOUND) + message("Found ARMCOMPUTE: ${ARMCOMPUTE_LIBRARIES}") + set(HAVE_ARMCOMPUTE 1) + # Add preprocessor definition for ARM Compute NEON + add_definitions(-DARMCOMPUTENEON_ENABLED) + #build our library with neon support + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mfpu=neon") + include_directories(${ARMCOMPUTE_INCLUDE}) + message("----${ARMCOMPUTE_INCLUDE}---") + endif() + +endif() + # new mkl-dnn entry if (${HELPERS_mkldnn}) diff --git a/libnd4j/blas/CMakeLists.txt b/libnd4j/blas/CMakeLists.txt index fb1dc066e..b6bd1f7c0 100755 --- a/libnd4j/blas/CMakeLists.txt +++ b/libnd4j/blas/CMakeLists.txt @@ -146,6 +146,10 @@ if (HAVE_MKLDNN) file(GLOB_RECURSE CUSTOMOPS_MKLDNN_SOURCES false ../include/ops/declarable/platform/mkldnn/*.cpp ../include/ops/declarable/platform/mkldnn/mkldnnUtils.h) endif() +if(HAVE_ARMCOMPUTE) + file(GLOB_RECURSE CUSTOMOPS_ARMCOMPUTE_SOURCES false ../include/ops/declarable/platform/armcompute/*.cpp ../include/ops/declarable/platform/armcompute/*.h) +endif() + if(SD_CUDA) message("Build cublas") find_package(CUDA) @@ -243,7 +247,7 @@ if(SD_CUDA) ${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES} ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES} ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES} ${CUSTOMOPS_CUDNN_SOURCES} ${CUSTOMOPS_MKLDNN_SOURCES} - ${CUSTOMOPS_GENERIC_SOURCES} + ${CUSTOMOPS_ARMCOMPUTE_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES} ) if (WIN32) @@ -351,8 +355,8 @@ elseif(SD_CPU) add_definitions(-D__CPUBLAS__=true) add_library(samediff_obj OBJECT ${LEGACY_SOURCES} ${LOOPS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES} - ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_MKLDNN_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES} - ${OPS_SOURCES} ${PERF_SOURCES}) + ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_MKLDNN_SOURCES} + ${CUSTOMOPS_ARMCOMPUTE_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES}) if(IOS) add_library(${SD_LIBRARY_NAME} STATIC $) else() @@ -378,12 +382,12 @@ elseif(SD_CPU) if (NOT BLAS_LIBRARIES) set(BLAS_LIBRARIES "") endif() - target_link_libraries(${SD_LIBRARY_NAME} ${MKLDNN} ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${BLAS_LIBRARIES} ${CPU_FEATURES}) + target_link_libraries(${SD_LIBRARY_NAME} ${MKLDNN} ${MKLDNN_LIBRARIES} ${ARMCOMPUTE_LIBRARIES} ${OPENBLAS_LIBRARIES} ${BLAS_LIBRARIES} ${CPU_FEATURES}) if ("${SD_ALL_OPS}" AND "${SD_BUILD_MINIFIER}") message(STATUS "Building minifier...") add_executable(minifier ../minifier/minifier.cpp ../minifier/graphopt.cpp) - target_link_libraries(minifier samediff_obj ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES} ${CPU_FEATURES}) + target_link_libraries(minifier samediff_obj ${MKLDNN_LIBRARIES} ${ARMCOMPUTE_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES} ${CPU_FEATURES}) endif() if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND "${CMAKE_CXX_COMPILER_VERSION}" VERSION_LESS 4.9) diff --git a/libnd4j/cmake/FindARMCOMPUTE.cmake b/libnd4j/cmake/FindARMCOMPUTE.cmake new file mode 100644 index 000000000..ae0e1fbba --- /dev/null +++ b/libnd4j/cmake/FindARMCOMPUTE.cmake @@ -0,0 +1,74 @@ +################################################################################ +# Copyright (c) 2020 Konduit K.K. +# +# This program and the accompanying materials are made available under the +# terms of the Apache License, Version 2.0 which is available at +# https://www.apache.org/licenses/LICENSE-2.0. +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +# +# SPDX-License-Identifier: Apache-2.0 +################################################################################ + + + +### Find ARM COMPUTE LIBRARY STATIC libraries + +SET (COMPUTE_INCLUDE_DIRS + /usr/include + ${ARMCOMPUTE_ROOT} + ${ARMCOMPUTE_ROOT}/include + ${ARMCOMPUTE_ROOT}/applications + ${ARMCOMPUTE_ROOT}/applications/arm_compute +) + + +SET (COMPUTE_LIB_DIRS + /lib + /usr/lib + ${ARMCOMPUTE_ROOT} + ${ARMCOMPUTE_ROOT}/lib + ${ARMCOMPUTE_ROOT}/build +) + +find_path(ARMCOMPUTE_INCLUDE arm_compute/core/CL/ICLKernel.h + PATHS ${COMPUTE_INCLUDE_DIRS} + NO_DEFAULT_PATH NO_CMAKE_FIND_ROOT_PATH) + +find_path(ARMCOMPUTE_INCLUDE arm_compute/core/CL/ICLKernel.h) + +find_path(HALF_INCLUDE half/half.hpp) +find_path(HALF_INCLUDE half/half.hpp + PATHS ${ARMCOMPUTE_ROOT}/include + NO_DEFAULT_PATH NO_CMAKE_FIND_ROOT_PATH) +include_directories(SYSTEM ${HALF_INCLUDE}) + +# Find the Arm Compute libraries if not already specified +if (NOT DEFINED ARMCOMPUTE_LIBRARIES) + + find_library(ARMCOMPUTE_LIBRARY NAMES arm_compute-static + PATHS ${COMPUTE_LIB_DIRS} + PATH_SUFFIXES "Release" + NO_DEFAULT_PATH NO_CMAKE_FIND_ROOT_PATH) + + find_library(ARMCOMPUTE_CORE_LIBRARY NAMES arm_compute_core-static + PATHS ${COMPUTE_LIB_DIRS} + PATH_SUFFIXES "Release" + NO_DEFAULT_PATH NO_CMAKE_FIND_ROOT_PATH) + # In case it wasn't there, try a default search (will work in cases where + # the library has been installed into a standard location) + find_library(ARMCOMPUTE_LIBRARY NAMES arm_compute-static) + find_library(ARMCOMPUTE_CORE_LIBRARY NAMES arm_compute_core-static) + + set(ARMCOMPUTE_LIBRARIES ${ARMCOMPUTE_LIBRARY} ${ARMCOMPUTE_CORE_LIBRARY} ) +endif() + + +INCLUDE(FindPackageHandleStandardArgs) + +FIND_PACKAGE_HANDLE_STANDARD_ARGS(ARMCOMPUTE REQUIRED_VARS ARMCOMPUTE_INCLUDE ARMCOMPUTE_LIBRARIES) + diff --git a/libnd4j/include/config.h.in b/libnd4j/include/config.h.in index 1e63552d0..c858dd765 100644 --- a/libnd4j/include/config.h.in +++ b/libnd4j/include/config.h.in @@ -3,6 +3,8 @@ #cmakedefine HAVE_MKLDNN +#cmakedefine HAVE_ARMCOMPUTE + #cmakedefine MKLDNN_PATH "@MKLDNN_PATH@" #cmakedefine HAVE_OPENBLAS diff --git a/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp b/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp index 8f45c696b..7e66d4b11 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp @@ -215,7 +215,9 @@ namespace helpers { auto maxValue = T(0); //sd::math::nd4j_abs(compoundBuffer[xInitialIndex]); auto result = -1; //auto loop = PRAGMA_THREADS_FOR { - auto start = column, stop = rowNum, increment = 1; + auto start = column; + auto stop = rowNum; + auto increment = 1; for (auto rowCounter = start; rowCounter < stop; rowCounter++) { Nd4jLong xPos[] = {rowCounter, column}; auto xIndex = shape::getOffset(compoundShape, xPos, 0); diff --git a/libnd4j/include/ops/declarable/platform/armcompute/armcomputeUtils.cpp b/libnd4j/include/ops/declarable/platform/armcompute/armcomputeUtils.cpp new file mode 100644 index 000000000..66b472252 --- /dev/null +++ b/libnd4j/include/ops/declarable/platform/armcompute/armcomputeUtils.cpp @@ -0,0 +1,278 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + // Created by Abdelrauf 2020 + + +#include +#include +#include +#include +#include +#include + +#include "armcomputeUtils.h" + + +namespace sd { +namespace ops { +namespace platforms { + + + +Arm_DataType getArmType ( const DataType &dType){ + Arm_DataType ret; + switch (dType){ + case HALF : + ret = Arm_DataType::F16; + break; + case FLOAT32 : + ret = Arm_DataType::F32; + break; + case DOUBLE : + ret = Arm_DataType::F64; + break; + case INT8 : + ret = Arm_DataType::S8; + break; + case INT16 : + ret = Arm_DataType::S16; + break; + case INT32 : + ret = Arm_DataType::S32; + break; + case INT64 : + ret = Arm_DataType::S64; + break; + case UINT8 : + ret = Arm_DataType::U8; + break; + case UINT16 : + ret = Arm_DataType::U16; + break; + case UINT32 : + ret = Arm_DataType::U32; + break; + case UINT64 : + ret = Arm_DataType::U64; + break; + case BFLOAT16 : + ret = Arm_DataType::BFLOAT16; + break; + default: + ret = Arm_DataType::UNKNOWN; + }; + + return ret; +} +bool isArmcomputeFriendly(const NDArray& arr) { + auto dType = getArmType(arr.dataType()); + int rank = (int)(arr.rankOf()); + return dType != Arm_DataType::UNKNOWN && + rank<=arm_compute::MAX_DIMS && + arr.ordering() == 'c' && + arr.ews()==1 && + shape::strideDescendingCAscendingF(arr.shapeInfo()) == true; +} + +Arm_TensorInfo getArmTensorInfo(int rank, Nd4jLong* bases,sd::DataType ndArrayType, arm_compute::DataLayout layout) { + constexpr int numChannels = 1; + auto dType = getArmType(ndArrayType); + + Arm_TensorShape shape; + shape.set_num_dimensions(rank); + for (int i = 0, j = rank - 1; i < rank; i++, j--) { + shape[i] = static_cast(bases[j]); + } + // fill the rest unused with 1 + for (int i = rank; i < arm_compute::MAX_DIMS; i++) { + shape[i] = 1; + } + + return Arm_TensorInfo(shape, numChannels, dType, layout); +} + +Arm_TensorInfo getArmTensorInfo(const NDArray& arr, + arm_compute::DataLayout layout) { + auto dType = getArmType(arr.dataType()); + + // + constexpr int numChannels = 1; + int rank = (int)(arr.rankOf()); + auto bases = arr.shapeOf(); + auto arrStrides = arr.stridesOf(); + + // https://arm-software.github.io/ComputeLibrary/v20.05/_dimensions_8h_source.xhtml + // note: underhood it is stored as std::array _id; + // TensorShape is derived from Dimensions + // as well as Strides : public Dimensions + Arm_TensorShape shape; + Arm_Strides strides; + shape.set_num_dimensions(rank); + strides.set_num_dimensions(rank); + size_t element_size = arm_compute::data_size_from_type(dType); + for (int i = 0, j = rank - 1; i < rank; i++, j--) { + shape[i] = static_cast(bases[j]); + strides[i] = static_cast(arrStrides[j]) * element_size; + } + // fill the rest unused with 1 + for (int i = rank; i < arm_compute::MAX_DIMS; i++) { + shape[i] = 1; + } + size_t total_size; + size_t size_ind = rank - 1; + total_size = shape[size_ind] * strides[size_ind]; + + Arm_TensorInfo info; + info.init(shape, numChannels, dType, strides, 0, total_size); + info.set_data_layout(layout); + + return info; +} + +Arm_Tensor getArmTensor(const NDArray& arr, arm_compute::DataLayout layout) { + // - Ownership of the backing memory is not transferred to the tensor itself. + // - The tensor mustn't be memory managed. + // - Padding requirements should be accounted by the client code. + // In other words, if padding is required by the tensor after the function + // configuration step, then the imported backing memory should account for it. + // Padding can be checked through the TensorInfo::padding() interface. + + // Import existing pointer as backing memory + auto info = getArmTensorInfo(arr, layout); + Arm_Tensor tensor; + tensor.allocator()->init(info); + void* buff = (void*)arr.buffer(); + tensor.allocator()->import_memory(buff); + return tensor; +} + +void copyFromTensor(const Arm_Tensor& inTensor, NDArray& output) { + //only for C order + //only for C order + if (output.ordering() != 'c') return; + auto shapeInfo = output.shapeInfo(); + auto bases = &(shapeInfo[1]); + Nd4jLong rank = shapeInfo[0]; + auto strides = output.stridesOf(); + int width = bases[rank - 1]; + uint8_t* outputBuffer = (uint8_t*)output.buffer(); + size_t offset = 0; + arm_compute::Window window; + arm_compute::Iterator tensor_it(&inTensor, window); + + int element_size = inTensor.info()->element_size(); + window.use_tensor_dimensions(inTensor.info()->tensor_shape(), /* first_dimension =*/arm_compute::Window::DimY); + +// if (output.ews() == 1) { + auto copySize = width * element_size; + auto dest = outputBuffer; + arm_compute::execute_window_loop(window, [&](const arm_compute::Coordinates& id) + { + auto src = tensor_it.ptr(); + memcpy(dest, src, copySize); + dest += copySize; + }, + tensor_it); + // } + // else { + // Nd4jLong coords[MAX_RANK] = {}; + // if(strides[rank-1]!=1){ + // throw std::runtime_error( "not implemented for subarrays whose last stride is not 1"); + // //TODO: implement to work with all subarrays properly + // } + // arm_compute::execute_window_loop(window, [&](const arm_compute::Coordinates& id) + // { + // auto src = tensor_it.ptr(); + // auto dest = outputBuffer + offset * element_size; + // memcpy(dest, src, width * element_size); + // offset = sd::inc_coords(bases, strides, coords, offset, rank, 1); + // }, + // tensor_it); + // } +} + +void copyToTensor(const NDArray& input, Arm_Tensor& outTensor) { + //only for C order + if (input.ordering() != 'c') return; + auto shapeInfo = input.shapeInfo(); + auto bases = &(shapeInfo[1]); + Nd4jLong rank = shapeInfo[0]; + auto strides = input.stridesOf(); + uint8_t *inputBuffer = (uint8_t*)input.buffer(); + int width = bases[rank - 1]; + size_t offset = 0; + arm_compute::Window window; + arm_compute::Iterator tensor_it(&outTensor, window); + int element_size = outTensor.info()->element_size(); + + window.use_tensor_dimensions(outTensor.info()->tensor_shape(), /* first_dimension =*/arm_compute::Window::DimY); + +// if (input.ews() == 1) { + + auto copySize = width * element_size; + auto src = inputBuffer; + arm_compute::execute_window_loop(window, [&](const arm_compute::Coordinates& id) + { + auto dest = tensor_it.ptr(); + memcpy(dest,src, copySize); + src += copySize; + }, + tensor_it); +// } +// else { +// Nd4jLong coords[MAX_RANK] = {}; +// if(strides[rank-1]!=1){ +// throw std::runtime_error( "not implemented for subarrays whose last stride is not 1"); +// //TODO: implement to work with all subarrays properly +// } +// arm_compute::execute_window_loop(window, [&](const arm_compute::Coordinates& id) +// { +// auto dest = tensor_it.ptr(); +// auto src = inputBuffer + offset * element_size; +// offset = sd::inc_coords(bases, strides, coords, offset, rank, 1); +// }, +// tensor_it); +// } +} + + +// armcompute should be built with debug option +void print_tensor(Arm_ITensor& tensor, const char* msg) { + auto info = tensor.info(); + auto padding = info->padding(); + std::cout << msg << "\ntotal: " << info->total_size() << "\n"; + + for (int i = 0; i < arm_compute::MAX_DIMS; i++) { + std::cout << info->dimension(i) << ","; + } + std::cout << std::endl; + for (int i = 0; i < arm_compute::MAX_DIMS; i++) { + std::cout << info->strides_in_bytes()[i] << ","; + } + std::cout << "\npadding: l " << padding.left << ", r " << padding.right + << ", t " << padding.top << ", b " << padding.bottom << std::endl; + +#ifdef ARM_COMPUTE_ASSERTS_ENABLED + //note it did not print correctly fro NHWC + std::cout << msg << ":\n"; + tensor.print(std::cout); + std::cout << std::endl; +#endif +} + +} +} +} diff --git a/libnd4j/include/ops/declarable/platform/armcompute/armcomputeUtils.h b/libnd4j/include/ops/declarable/platform/armcompute/armcomputeUtils.h new file mode 100644 index 000000000..72a4e6e89 --- /dev/null +++ b/libnd4j/include/ops/declarable/platform/armcompute/armcomputeUtils.h @@ -0,0 +1,133 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +#ifndef DEV_TESTSARMCOMPUTEUTILS_H +#define DEV_TESTSARMCOMPUTEUTILS_H + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace samediff; + + +namespace sd { + namespace ops { + namespace platforms { + + using Arm_DataType = arm_compute::DataType; + using Arm_Tensor = arm_compute::Tensor; + using Arm_ITensor = arm_compute::ITensor; + using Arm_TensorInfo = arm_compute::TensorInfo; + using Arm_TensorShape = arm_compute::TensorShape; + using Arm_Strides = arm_compute::Strides; + /** + * Here we actually declare our platform helpers + */ + + + DECLARE_PLATFORM(maxpool2d, ENGINE_CPU); + + DECLARE_PLATFORM(avgpool2d, ENGINE_CPU); + + //utils + Arm_DataType getArmType(const sd::DataType& dType); + + Arm_TensorInfo getArmTensorInfo(int rank, Nd4jLong* bases, sd::DataType ndArrayType, arm_compute::DataLayout layout = arm_compute::DataLayout::UNKNOWN); + + Arm_TensorInfo getArmTensorInfo(const NDArray& arr, arm_compute::DataLayout layout = arm_compute::DataLayout::UNKNOWN); + + Arm_Tensor getArmTensor(const NDArray& arr, arm_compute::DataLayout layout = arm_compute::DataLayout::UNKNOWN); + + void copyFromTensor(const Arm_Tensor& inTensor, NDArray& output); + void copyToTensor(const NDArray& input, Arm_Tensor& outTensor); + void print_tensor(Arm_ITensor& tensor, const char* msg); + bool isArmcomputeFriendly(const NDArray& arr); + + + template + class ArmFunction { + public: + + template + void configure(NDArray *input , NDArray *output, arm_compute::DataLayout layout, Args&& ...args) { + + auto inInfo = getArmTensorInfo(*input, layout); + auto outInfo = getArmTensorInfo(*output, layout); + in.allocator()->init(inInfo); + out.allocator()->init(outInfo); + armFunction.configure(&in,&out,std::forward(args) ...); + if (in.info()->has_padding()) { + //allocate and copy + in.allocator()->allocate(); + //copy + copyToTensor(*input, in); + + } + else { + //import buffer + void* buff = input->buffer(); + in.allocator()->import_memory(buff); + } + if (out.info()->has_padding()) { + //store pointer to our array to copy after run + out.allocator()->allocate(); + outNd = output; + } + else { + //import + void* buff = output->buffer(); + out.allocator()->import_memory(buff); + } + + } + + void run() { + armFunction.run(); + if (outNd) { + copyFromTensor(out, *outNd); + } + } + + private: + Arm_Tensor in; + Arm_Tensor out; + NDArray *outNd=nullptr; + F armFunction{}; + }; + } + } +} + + + +#endif //DEV_TESTSARMCOMPUTEUTILS_H diff --git a/libnd4j/include/ops/declarable/platform/armcompute/avgpooling2d.cpp b/libnd4j/include/ops/declarable/platform/armcompute/avgpooling2d.cpp new file mode 100644 index 000000000..d8413104d --- /dev/null +++ b/libnd4j/include/ops/declarable/platform/armcompute/avgpooling2d.cpp @@ -0,0 +1,106 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + // Created by Abdelrauf (rauf@konduit.ai) 2020 + +#include +#include +#include +#include + + +#include "armcomputeUtils.h" + + +namespace sd { +namespace ops { +namespace platforms { + + +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(avgpool2d, ENGINE_CPU) { + + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + + const auto kH = INT_ARG(0); + const auto kW = INT_ARG(1); + const auto sH = INT_ARG(2); + const auto sW = INT_ARG(3); + auto pH = INT_ARG(4); + auto pW = INT_ARG(5); + const auto dH = INT_ARG(6); + const auto dW = INT_ARG(7); + const auto paddingMode = INT_ARG(8); + const auto extraParam0 = INT_ARG(9); + const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + + REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D ARMCOMPUTE op: input should have rank of 4, but got %i instead", input->rankOf()); + REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D ARMCOMPUTE op: dilation must not be zero, but got instead {%i, %i}", dH, dW); + + bool exclude_padding= (extraParam0 == 0) ? true : false; + + auto dataLayout = isNCHW ? arm_compute::DataLayout::NCHW : arm_compute::DataLayout::NHWC; + + // Calculate individual paddings + unsigned int pad_left, pad_top, pad_right, pad_bottom; + int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + + if(paddingMode){ + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); + } + pad_left = pW; + pad_top = pH; + pad_right = (oW - 1) * sW - iW + kW - pW ; + pad_bottom = (oH - 1) * sH - iH + kH - pH ; + +#if 0 + nd4j_printf("avgpool kH = %d, kW = %d, sH = %d, sW = %d , pH = %d , pW = %d, dH = %d, dW = %d, paddingMode = %d , isNCHW %d exclude pad %d \n" , kH , kW , sH , sW , pH + , pW , dH , dW , paddingMode,isNCHW?1:0 ,exclude_padding?1:0); +#endif + auto poolPad = arm_compute::PadStrideInfo(sW, sH, pad_left,pad_right, pad_top, pad_bottom, arm_compute::DimensionRoundingType::FLOOR); + auto poolInfo = arm_compute::PoolingLayerInfo(arm_compute::PoolingType::AVG, arm_compute::Size2D(kW, kH), dataLayout, poolPad, exclude_padding); + ArmFunction pool; + pool.configure(input,output, dataLayout, poolInfo); + + pool.run(); // run function + + return Status::OK(); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_CHECK(avgpool2d, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + const int dH = INT_ARG(6); + const int dW = INT_ARG(7); + // Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32 + auto dTypeInput = getArmType(input->dataType()); + auto dTypeOutput = getArmType(output->dataType()); + bool is_supported = dH==1 && dW==1 && isArmcomputeFriendly(*input) && isArmcomputeFriendly(*output) + && (dTypeInput ==Arm_DataType::F32) + && (dTypeOutput ==Arm_DataType::F32); + return is_supported; +} + + + +} +} +} diff --git a/libnd4j/include/ops/declarable/platform/armcompute/maxpooling2d.cpp b/libnd4j/include/ops/declarable/platform/armcompute/maxpooling2d.cpp new file mode 100644 index 000000000..cd6779628 --- /dev/null +++ b/libnd4j/include/ops/declarable/platform/armcompute/maxpooling2d.cpp @@ -0,0 +1,106 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + // Created by Abdelrauf 2020 + + +#include +#include +#include +#include + + +#include "armcomputeUtils.h" + + +namespace sd { +namespace ops { +namespace platforms { + + +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(maxpool2d, ENGINE_CPU) { + + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D ARMCOMPUTE OP: input array should have rank of 4, but got %i instead", input->rankOf()); + + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + const int kH = INT_ARG(0); + const int kW = INT_ARG(1); + const int sH = INT_ARG(2); + const int sW = INT_ARG(3); + int pH = INT_ARG(4); + int pW = INT_ARG(5); + const int dH = INT_ARG(6); + const int dW = INT_ARG(7); + const int paddingMode = INT_ARG(8); + // const int extraParam0 = INT_ARG(9); + const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 1-NHWC, 0-NCHW + + REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D MKLDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW); + + auto dataLayout = isNCHW ? arm_compute::DataLayout::NCHW : arm_compute::DataLayout::NHWC; + + // Calculate individual paddings + unsigned int pad_left, pad_top, pad_right, pad_bottom; + int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + + if(paddingMode){ + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); + } + pad_left = pW; + pad_top = pH; + pad_right = (oW - 1) * sW - iW + kW - pW ; + pad_bottom = (oH - 1) * sH - iH + kH - pH ; +#if 0 + nd4j_printf("avgpool kH = %d, kW = %d, sH = %d, sW = %d , pH = %d , pW = %d, dH = %d, dW = %d, paddingMode = %d , isNCHW %d exclude pad %d \n" , kH , kW , sH , sW , pH + , pW , dH , dW , paddingMode,isNCHW?1:0 ,exclude_padding?1:0); +#endif + + auto poolPad = arm_compute::PadStrideInfo(sW, sH, pad_left,pad_right, pad_top, pad_bottom, arm_compute::DimensionRoundingType::FLOOR); + auto poolInfo = arm_compute::PoolingLayerInfo(arm_compute::PoolingType::MAX, arm_compute::Size2D(kW, kH), dataLayout, poolPad); + ArmFunction pool; + + pool.configure(input,output, dataLayout, poolInfo); + + pool.run(); // run function + + return Status::OK(); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_CHECK(maxpool2d, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + const int dH = INT_ARG(6); + const int dW = INT_ARG(7); + // Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32 + auto dTypeInput = getArmType(input->dataType()); + auto dTypeOutput = getArmType(output->dataType()); + bool is_supported = dH==1 && dW==1 && isArmcomputeFriendly(*input) && isArmcomputeFriendly(*output) + && (dTypeInput ==Arm_DataType::F32) + && (dTypeOutput ==Arm_DataType::F32); + return is_supported; +} + + + +} +} +} diff --git a/libnd4j/pi_build.sh b/libnd4j/pi_build.sh new file mode 100755 index 000000000..f96c3f1f1 --- /dev/null +++ b/libnd4j/pi_build.sh @@ -0,0 +1,185 @@ +#!/bin/bash +TARGET=armv7-a +BLAS_TARGET_NAME=ARMV7 +ARMCOMPUTE_TARGET=armv7a +#BASE_DIR=${HOME}/pi +#https://stackoverflow.com/questions/59895/how-to-get-the-source-directory-of-a-bash-script-from-within-the-script-itself +SOURCE="${BASH_SOURCE[0]}" +ARMCOMPUTE_DEBUG=1 +LIBND4J_BUILD_MODE=Release +while [ -h "$SOURCE" ]; do # resolve $SOURCE until the file is no longer a symlink + DIR="$( cd -P "$( dirname "$SOURCE" )" >/dev/null 2>&1 && pwd )" + SOURCE="$(readlink "$SOURCE")" + [[ $SOURCE != /* ]] && SOURCE="$DIR/$SOURCE" # if $SOURCE was a relative symlink, we need to resolve it relative to the path where the symlink file was located +done +BASE_DIR="$( cd -P "$( dirname "$SOURCE" )" >/dev/null 2>&1 && pwd )" +CMAKE=cmake #/snap/bin/cmake + +mkdir -p ${BASE_DIR}/helper_bin/ + +CROSS_COMPILER_URL=https://sourceforge.net/projects/raspberry-pi-cross-compilers/files/Raspberry%20Pi%20GCC%20Cross-Compiler%20Toolchains/Buster/GCC%208.3.0/Raspberry%20Pi%203A%2B%2C%203B%2B%2C%204/cross-gcc-8.3.0-pi_3%2B.tar.gz/download +CROSS_COMPILER_DIR=${BASE_DIR}/helper_bin/cross_compiler + +SCONS_LOCAL_URL=http://prdownloads.sourceforge.net/scons/scons-local-3.1.1.tar.gz +SCONS_LOCAL_DIR=${BASE_DIR}/helper_bin/scons_local + +THIRD_PARTY=${BASE_DIR}/third_party_libs + +ARMCOMPUTE_GIT_URL=https://github.com/ARM-software/ComputeLibrary.git +ARMCOMPUTE_TAG=v20.05 +ARMCOMPUTE_DIR=${THIRD_PARTY}/arm_compute_dir + +OPENBLAS_GIT_URL="https://github.com/xianyi/OpenBLAS.git" +OPENBLAS_DIR=${THIRD_PARTY}/OpenBLAS + + +LIBND4J_SRC_DIR=${BASE_DIR} + +LIBND4J_BUILD_DIR=${BASE_DIR}/build_pi + +#for some downloads +XRTACT_STRIP="--strip-components=1" + +HAS_ARMCOMPUTE=1 +mkdir -p ${BASE_DIR} +mkdir -p ${THIRD_PARTY} + +#change directory to base +cd $BASE_DIR + +function message { + echo "BUILDER:::: ${@}" +} + + +function check_requirements { + for i in "${@}" + do + if [ ! -e "$i" ]; then + message "missing: ${i}" + exit -2 + fi + done +} + +function download_extract { + #$1 is url #2 is dir $3 is extract argument + if [ ! -f ${2}_file ]; then + message "download" + wget --quiet --show-progress -O ${2}_file ${1} + fi + + message "extract" + #extract + mkdir -p ${2} + command="tar -xzf ${2}_file --directory=${2} ${3} " + message $command + $command + + check_requirements "${2}" +} + +function git_check { + #$1 is url #$2 is dir #$3 is tag or branch if optional + command="git clone --quiet ${1} ${2}" + message "$command" + $command + if [ -n "$3" ]; then + cd ${2} + command="git checkout ${3}" + message "$command" + $command + cd ${BASE_DIR} + fi + check_requirements "${2}" +} + + +if [ ! -d ${CROSS_COMPILER_DIR} ]; then + #out file + message "download CROSS_COMPILER" + download_extract ${CROSS_COMPILER_URL} ${CROSS_COMPILER_DIR} ${XRTACT_STRIP} +fi + +#useful exports +export PI_FOLDER=${CROSS_COMPILER_DIR} +export RPI_BIN=${PI_FOLDER}/bin/arm-linux-gnueabihf +export PI_SYS_ROOT=${PI_FOLDER}/arm-linux-gnueabihf/libc +export LD_LIBRARY_PATH=${PI_FOLDER}/lib:$LD_LIBRARY_PATH +export CC=${RPI_BIN}-gcc +export FC=${RPI_BIN}-gfortran +export CXX=${RPI_BIN}-g++ +export CPP=${RPI_BIN}-cpp +export RANLIB=${RPI_BIN}-gcc-ranlib +export LD="${RPI_BIN}-ld" +export AR="${RPI_BIN}-ar" + + +#lets build OpenBlas +if [ ! -d "${OPENBLAS_DIR}" ]; then + message "download OpenBLAS" + git_check "${OPENBLAS_GIT_URL}" "${OPENBLAS_DIR}" +fi + +if [ ! -f "${THIRD_PARTY}/lib/libopenblas.so" ]; then + message "build and install OpenBLAS" + cd ${OPENBLAS_DIR} + + command="make TARGET=${BLAS_TARGET_NAME} HOSTCC=gcc CC=${CC} USE_THREAD=0 NOFORTRAN=1 CFLAGS=--sysroot=${PI_SYS_ROOT} LDFLAGS=\"-L${PI_SYS_ROOT}/../lib/ -lm\" &>/dev/null" + message $command + eval $command + message "install it" + command="make PREFIX=${THIRD_PARTY} install" + message $command + $command + cd $BASE_DIR + +fi +check_requirements ${THIRD_PARTY}/lib/libopenblas.so + + + +if [ ! -d ${SCONS_LOCAL_DIR} ]; then + #out file + message "download Scons local" + download_extract ${SCONS_LOCAL_URL} ${SCONS_LOCAL_DIR} +fi +check_requirements ${SCONS_LOCAL_DIR}/scons.py + + +if [ ! -d "${ARMCOMPUTE_DIR}" ]; then + message "download ArmCompute Source" + git_check ${ARMCOMPUTE_GIT_URL} "${ARMCOMPUTE_DIR}" "tags/${ARMCOMPUTE_TAG}" +fi + +#build armcompute +if [ ! -f "${ARMCOMPUTE_DIR}/build/libarm_compute-static.a" ]; then +message "build arm compute" +cd ${ARMCOMPUTE_DIR} +command="CC=gcc CXX=g++ python3 ${SCONS_LOCAL_DIR}/scons.py Werror=1 -j$(nproc) toolchain_prefix=${RPI_BIN}- debug=${ARMCOMPUTE_DEBUG} neon=1 opencl=0 extra_cxx_flags=-fPIC os=linux build=cross_compile arch=${ARMCOMPUTE_TARGET} &>/dev/null" +message $command +eval $command +cd ${BASE_DIR} +fi +check_requirements "${ARMCOMPUTE_DIR}/build/libarm_compute-static.a" "${ARMCOMPUTE_DIR}/build/libarm_compute_core-static.a" + + + +message "build cmake for LIBND4J. output: ${LIBND4J_BUILD_DIR}" + +TOOLCHAIN=${LIBND4J_SRC_DIR}/cmake/rpi.cmake +cmake_cmd="${CMAKE} -G \"Unix Makefiles\" -B${LIBND4J_BUILD_DIR} -S${LIBND4J_SRC_DIR} -DCMAKE_BUILD_TYPE=${LIBND4J_BUILD_MODE} -DCMAKE_TOOLCHAIN_FILE=${TOOLCHAIN} -DCMAKE_VERBOSE_MAKEFILE:BOOL=ON -DSD_ALL_OPS=true -DSD_CPU=true -DSD_LIBRARY_NAME=nd4jcpu -DSD_BUILD_TESTS=ON -DSD_ARM_BUILD=true -DOPENBLAS_PATH=${THIRD_PARTY} -DSD_ARCH=${TARGET} -DARMCOMPUTE_ROOT=${ARMCOMPUTE_DIR} -DHELPERS_armcompute=${HAS_ARMCOMPUTE}" +message $cmake_cmd +eval $cmake_cmd + +#build +message "lets build" + +cd ${LIBND4J_BUILD_DIR} +make -j $(nproc) + + + + + + diff --git a/libnd4j/tests_cpu/layers_tests/CMakeLists.txt b/libnd4j/tests_cpu/layers_tests/CMakeLists.txt index 563bf58f6..9478f6fe2 100644 --- a/libnd4j/tests_cpu/layers_tests/CMakeLists.txt +++ b/libnd4j/tests_cpu/layers_tests/CMakeLists.txt @@ -52,14 +52,19 @@ elseif(WIN32) set(CMAKE_CXX_FLAGS " -fPIC") endif() else() - set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3") set(CMAKE_CXX_FLAGS " -fPIC") + set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3") + IF(${SD_ARCH} MATCHES "arm*") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=${SD_ARCH}") + else() + set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3") + if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64*") set(CMAKE_CXX_FLAGS " ${CMAKE_CXX_FLAGS} -mcpu=native") else() set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native -mtune=native") endif() - + endif() if (SD_CPU AND SD_SANITIZE) set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -fsanitize=address") else() @@ -130,7 +135,7 @@ if (SD_CPU) endif() add_executable(runtests ${TEST_SOURCES}) - target_link_libraries(runtests samediff_obj ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES} ${CPU_FEATURES} gtest gtest_main) + target_link_libraries(runtests samediff_obj ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES} ${CPU_FEATURES} ${ARMCOMPUTE_LIBRARIES} gtest gtest_main) elseif(SD_CUDA) add_executable(runtests ${TEST_SOURCES}) diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp index 169c51124..39277cd87 100644 --- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp @@ -1113,7 +1113,10 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_6) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); auto z = result.at(0); - +#if 0 + exp.printIndexedBuffer("Expected"); + z->printIndexedBuffer("Z"); +#endif ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); @@ -1132,7 +1135,10 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_7) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); auto z = result.at(0); - +#if 0 + exp.printIndexedBuffer("Expected"); + z->printIndexedBuffer("Z"); +#endif ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); @@ -1151,7 +1157,10 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_8) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); auto z = result.at(0); - +#if 0 + exp.printIndexedBuffer("Expected"); + z->printIndexedBuffer("Z"); +#endif ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); } @@ -1204,7 +1213,10 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_10) { auto* output = results.at(0); ASSERT_EQ(Status::OK(), results.status()); - +#if 0 + expOutput.printIndexedBuffer("expOutput"); + output->printIndexedBuffer("output"); +#endif ASSERT_TRUE(expOutput.isSameShape(output)); ASSERT_TRUE(expOutput.equalsTo(output)); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp index d3d1deed8..beccc1aae 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp @@ -244,7 +244,8 @@ TEST_F(DeclarableOpsTests19, test_threshold_encode_decode) { #ifdef _RELEASE TEST_F(DeclarableOpsTests19, test_threshold_encode_decode_2) { // [2,1,135079944,1,1,8192,1,99] - auto initial = NDArrayFactory::create('c', {1, 135079944}); + constexpr int sizeX= 10*1000*1000; + auto initial = NDArrayFactory::create('c', {1, sizeX}); initial = 1.0f; auto exp = initial.dup(); auto neg = initial.like(); @@ -254,7 +255,7 @@ TEST_F(DeclarableOpsTests19, test_threshold_encode_decode_2) { auto enc_result = enc.evaluate({&initial}, {0.5f}); auto encoded = enc_result.at(1); - ASSERT_EQ(135079944 + 4, encoded->lengthOf()); + ASSERT_EQ(sizeX + 4, encoded->lengthOf()); ASSERT_NE(exp, initial); /* for (int e = 0; e < initial.lengthOf(); e++) { diff --git a/libnd4j/tests_cpu/layers_tests/SessionLocalTests.cpp b/libnd4j/tests_cpu/layers_tests/SessionLocalTests.cpp deleted file mode 100644 index 8481dfde5..000000000 --- a/libnd4j/tests_cpu/layers_tests/SessionLocalTests.cpp +++ /dev/null @@ -1,93 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author raver119@gmail.com -// - -#ifndef LIBND4J_SESSIONLOCALTESTS_H -#define LIBND4J_SESSIONLOCALTESTS_H - -#include "testlayers.h" -#include -#include - -using namespace sd::graph; - -class SessionLocalTests : public testing::Test { -public: - -}; - -TEST_F(SessionLocalTests, BasicTests_1) { - VariableSpace variableSpace; - SessionLocalStorage storage(&variableSpace, nullptr); - - if (omp_get_max_threads() <= 1) - return; - - PRAGMA_OMP_PARALLEL_FOR_THREADS(4) - for (int e = 0; e < 4; e++) { - storage.startSession(); - } - - ASSERT_EQ(4, storage.numberOfSessions()); - - PRAGMA_OMP_PARALLEL_FOR_THREADS(4) - for (int e = 0; e < 4; e++) { - storage.endSession(); - } - - ASSERT_EQ(0, storage.numberOfSessions()); -} - - -TEST_F(SessionLocalTests, BasicTests_2) { - VariableSpace variableSpace; - SessionLocalStorage storage(&variableSpace, nullptr); - - if (omp_get_max_threads() <= 1) - return; - - auto alpha = sd::NDArrayFactory::create_('c',{5,5}); - alpha->assign(0.0); - - variableSpace.putVariable(-1, alpha); - - PRAGMA_OMP_PARALLEL_FOR_THREADS(4) - for (int e = 0; e < 4; e++) { - storage.startSession(); - - auto varSpace = storage.localVariableSpace(); - - auto arr = varSpace->getVariable(-1)->getNDArray(); - arr->applyScalar(sd::scalar::Add, (float) e+1, *arr); - } - - float lastValue = 0.0f; - for (int e = 1; e <= 4; e++) { - auto varSpace = storage.localVariableSpace((Nd4jLong) e); - - auto arr = varSpace->getVariable(-1)->getNDArray(); - - //nd4j_printf("Last value: %f; Current value: %f\n", lastValue, arr->e(0)); - - ASSERT_NE(lastValue, arr->e(0)); - lastValue = arr->e(0); - } -} - -#endif //LIBND4J_SESSIONLOCALTESTS_H diff --git a/libnd4j/tests_cpu/libnd4j_tests/CMakeLists.txt b/libnd4j/tests_cpu/libnd4j_tests/CMakeLists.txt index 7e01e2847..bbd632d27 100644 --- a/libnd4j/tests_cpu/libnd4j_tests/CMakeLists.txt +++ b/libnd4j/tests_cpu/libnd4j_tests/CMakeLists.txt @@ -45,6 +45,21 @@ if ("${BUILD_MKLDNN}") set(MKLDNN dnnl) endif() +if (${HELPERS_armcompute}) + find_package(ARMCOMPUTE REQUIRED) + + if(ARMCOMPUTE_FOUND) + message("Found ARMCOMPUTE: ${ARMCOMPUTE_LIBRARIES}") + set(HAVE_ARMCOMPUTE 1) + # Add preprocessor definition for ARM Compute NEON + add_definitions(-DARMCOMPUTENEON_ENABLED) + #build our library with neon support + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mfpu=neon") + include_directories(${ARMCOMPUTE_INCLUDE}) + endif() + +endif() + # Download and unpack flatbuffers at configure time configure_file(../../CMakeLists.txt.in flatbuffers-download/CMakeLists.txt) execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" . @@ -217,6 +232,10 @@ if ("${BUILD_MKLDNN}") file(GLOB_RECURSE CUSTOMOPS_PLATFORM_SOURCES false ../../include/ops/declarable/platform/mkldnn/*.cpp) endif() +if(HAVE_ARMCOMPUTE) + file(GLOB_RECURSE CUSTOMOPS_ARMCOMPUTE_SOURCES false ../include/ops/declarable/platform/armcompute/*.cpp ../include/ops/declarable/platform/armcompute/armcomputeUtils.h) +endif() + message("CPU backend") add_definitions(-D__CPUBLAS__=true) @@ -276,8 +295,9 @@ endforeach(TMP_PATH) add_executable(runtests ${LOOPS_SOURCES} ${LEGACY_SOURCES} ${EXEC_SOURCES} ${HELPERS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES} - ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_PLATFORM_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES} + ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_PLATFORM_SOURCES} + ${CUSTOMOPS_ARMCOMPUTE_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES} ${OPS_SOURCES} ${TEST_SOURCES} ${PERF_SOURCES}) -target_link_libraries(runtests gtest ${MKLDNN} gtest_main ${BLAS_LIBRARIES}) +target_link_libraries(runtests gtest ${MKLDNN} ${ARMCOMPUTE_LIBRARIES} gtest_main ${BLAS_LIBRARIES})