Merge pull request #9024 from KonduitAI/master

Development updates [WIP]
master
Alex Black 2020-06-26 19:34:28 +10:00 committed by GitHub
commit 43fd64358c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
79 changed files with 3443 additions and 693 deletions

View File

@ -57,10 +57,8 @@ import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr; import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer; import org.deeplearning4j.util.ModelSerializer;
import org.junit.AfterClass; import org.junit.*;
import org.junit.Before; import org.junit.rules.TemporaryFolder;
import org.junit.BeforeClass;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationIdentity; import org.nd4j.linalg.activations.impl.ActivationIdentity;
import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.buffer.DataType;
@ -82,6 +80,7 @@ import org.nd4j.common.resources.Resources;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream; import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.util.*; import java.util.*;
@ -91,6 +90,9 @@ import static org.junit.Assert.*;
@Slf4j @Slf4j
public class TestComputationGraphNetwork extends BaseDL4JTest { public class TestComputationGraphNetwork extends BaseDL4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
private static ComputationGraphConfiguration getIrisGraphConfiguration() { private static ComputationGraphConfiguration getIrisGraphConfiguration() {
return new NeuralNetConfiguration.Builder().seed(12345) return new NeuralNetConfiguration.Builder().seed(12345)
.optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder() .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder()
@ -2177,4 +2179,40 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
INDArray label = Nd4j.createFromArray(1, 0).reshape(1, 2); INDArray label = Nd4j.createFromArray(1, 0).reshape(1, 2);
cg.fit(new DataSet(in, label)); cg.fit(new DataSet(in, label));
} }
@Test
public void testMergeNchw() throws Exception {
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
.convolutionMode(ConvolutionMode.Same)
.graphBuilder()
.addInputs("in")
.layer("l0", new ConvolutionLayer.Builder()
.nOut(16)
.kernelSize(2,2).stride(1,1)
.build(), "in")
.layer("l1", new ConvolutionLayer.Builder()
.nOut(8)
.kernelSize(2,2).stride(1,1)
.build(), "in")
.addVertex("merge", new MergeVertex(), "l0", "l1")
.layer("out", new CnnLossLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build(), "merge")
.setOutputs("out")
.setInputTypes(InputType.convolutional(32, 32, 3, CNN2DFormat.NHWC))
.build();
ComputationGraph cg = new ComputationGraph(conf);
cg.init();
INDArray[] in = new INDArray[]{Nd4j.rand(DataType.FLOAT, 1, 32, 32, 3)};
INDArray out = cg.outputSingle(in);
File dir = testDir.newFolder();
File f = new File(dir, "net.zip");
cg.save(f);
ComputationGraph c2 = ComputationGraph.load(f, true);
INDArray out2 = c2.outputSingle(in);
assertEquals(out, out2);
}
} }

View File

@ -66,8 +66,8 @@ import org.nd4j.shade.jackson.annotation.JsonProperty;
* @author Alex Black * @author Alex Black
*/ */
@Data @Data
@JsonIgnoreProperties({"mask", "helper", "helperCountFail"}) @JsonIgnoreProperties({"mask", "helper", "helperCountFail", "initializedHelper"})
@EqualsAndHashCode(exclude = {"mask", "helper", "helperCountFail"}) @EqualsAndHashCode(exclude = {"mask", "helper", "helperCountFail", "initializedHelper"})
@Slf4j @Slf4j
public class Dropout implements IDropout { public class Dropout implements IDropout {

View File

@ -17,6 +17,7 @@
package org.deeplearning4j.nn.conf.graph; package org.deeplearning4j.nn.conf.graph;
import lombok.Data;
import lombok.val; import lombok.val;
import org.deeplearning4j.nn.conf.CNN2DFormat; import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.RNNFormat; import org.deeplearning4j.nn.conf.RNNFormat;
@ -38,6 +39,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
* -> [numExamples,depth1 + depth2,width,height]}<br> * -> [numExamples,depth1 + depth2,width,height]}<br>
* @author Alex Black * @author Alex Black
*/ */
@Data
public class MergeVertex extends GraphVertex { public class MergeVertex extends GraphVertex {
protected int mergeAxis = 1; //default value for backward compatibility (deserialization of old version JSON) - NCHW and NCW format protected int mergeAxis = 1; //default value for backward compatibility (deserialization of old version JSON) - NCHW and NCW format

View File

@ -141,7 +141,7 @@ public class GradientSharingTrainingTest extends BaseSparkTest {
SparkComputationGraph sparkNet = new SparkComputationGraph(sc, conf, tm); SparkComputationGraph sparkNet = new SparkComputationGraph(sc, conf, tm);
sparkNet.setCollectTrainingStats(tm.getIsCollectTrainingStats()); sparkNet.setCollectTrainingStats(tm.getIsCollectTrainingStats());
System.out.println(Arrays.toString(sparkNet.getNetwork().params().get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat())); // System.out.println(Arrays.toString(sparkNet.getNetwork().params().get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
File f = testDir.newFolder(); File f = testDir.newFolder();
DataSetIterator iter = new MnistDataSetIterator(16, true, 12345); DataSetIterator iter = new MnistDataSetIterator(16, true, 12345);
int count = 0; int count = 0;
@ -208,10 +208,10 @@ public class GradientSharingTrainingTest extends BaseSparkTest {
} }
INDArray paramsAfter = after.params(); INDArray paramsAfter = after.params();
System.out.println(Arrays.toString(paramsBefore.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat())); // System.out.println(Arrays.toString(paramsBefore.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
System.out.println(Arrays.toString(paramsAfter.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat())); // System.out.println(Arrays.toString(paramsAfter.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
System.out.println(Arrays.toString( // System.out.println(Arrays.toString(
Transforms.abs(paramsAfter.sub(paramsBefore)).get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat())); // Transforms.abs(paramsAfter.sub(paramsBefore)).get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
assertNotEquals(paramsBefore, paramsAfter); assertNotEquals(paramsBefore, paramsAfter);
@ -235,7 +235,7 @@ public class GradientSharingTrainingTest extends BaseSparkTest {
} }
@Test @Test @Ignore //AB https://github.com/eclipse/deeplearning4j/issues/8985
public void differentNetsTrainingTest() throws Exception { public void differentNetsTrainingTest() throws Exception {
int batch = 3; int batch = 3;

View File

@ -131,6 +131,23 @@ if(NOT SD_CUDA)
endif() endif()
endif() endif()
#arm-compute entry
if(${HELPERS_armcompute})
find_package(ARMCOMPUTE REQUIRED)
if(ARMCOMPUTE_FOUND)
message("Found ARMCOMPUTE: ${ARMCOMPUTE_LIBRARIES}")
set(HAVE_ARMCOMPUTE 1)
# Add preprocessor definition for ARM Compute NEON
add_definitions(-DARMCOMPUTENEON_ENABLED)
#build our library with neon support
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mfpu=neon")
include_directories(${ARMCOMPUTE_INCLUDE})
message("----${ARMCOMPUTE_INCLUDE}---")
endif()
endif()
# new mkl-dnn entry # new mkl-dnn entry
if (${HELPERS_mkldnn}) if (${HELPERS_mkldnn})

View File

@ -146,6 +146,10 @@ if (HAVE_MKLDNN)
file(GLOB_RECURSE CUSTOMOPS_MKLDNN_SOURCES false ../include/ops/declarable/platform/mkldnn/*.cpp ../include/ops/declarable/platform/mkldnn/mkldnnUtils.h) file(GLOB_RECURSE CUSTOMOPS_MKLDNN_SOURCES false ../include/ops/declarable/platform/mkldnn/*.cpp ../include/ops/declarable/platform/mkldnn/mkldnnUtils.h)
endif() endif()
if(HAVE_ARMCOMPUTE)
file(GLOB_RECURSE CUSTOMOPS_ARMCOMPUTE_SOURCES false ../include/ops/declarable/platform/armcompute/*.cpp ../include/ops/declarable/platform/armcompute/*.h)
endif()
if(SD_CUDA) if(SD_CUDA)
message("Build cublas") message("Build cublas")
find_package(CUDA) find_package(CUDA)
@ -243,7 +247,7 @@ if(SD_CUDA)
${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES} ${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES}
${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES} ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES} ${CUSTOMOPS_CUDNN_SOURCES} ${CUSTOMOPS_MKLDNN_SOURCES} ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES} ${CUSTOMOPS_CUDNN_SOURCES} ${CUSTOMOPS_MKLDNN_SOURCES}
${CUSTOMOPS_GENERIC_SOURCES} ${CUSTOMOPS_ARMCOMPUTE_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES}
) )
if (WIN32) if (WIN32)
@ -351,8 +355,8 @@ elseif(SD_CPU)
add_definitions(-D__CPUBLAS__=true) add_definitions(-D__CPUBLAS__=true)
add_library(samediff_obj OBJECT ${LEGACY_SOURCES} add_library(samediff_obj OBJECT ${LEGACY_SOURCES}
${LOOPS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES} ${LOOPS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_MKLDNN_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES} ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_MKLDNN_SOURCES}
${OPS_SOURCES} ${PERF_SOURCES}) ${CUSTOMOPS_ARMCOMPUTE_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES})
if(IOS) if(IOS)
add_library(${SD_LIBRARY_NAME} STATIC $<TARGET_OBJECTS:samediff_obj>) add_library(${SD_LIBRARY_NAME} STATIC $<TARGET_OBJECTS:samediff_obj>)
else() else()
@ -378,12 +382,12 @@ elseif(SD_CPU)
if (NOT BLAS_LIBRARIES) if (NOT BLAS_LIBRARIES)
set(BLAS_LIBRARIES "") set(BLAS_LIBRARIES "")
endif() endif()
target_link_libraries(${SD_LIBRARY_NAME} ${MKLDNN} ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${BLAS_LIBRARIES} ${CPU_FEATURES}) target_link_libraries(${SD_LIBRARY_NAME} ${MKLDNN} ${MKLDNN_LIBRARIES} ${ARMCOMPUTE_LIBRARIES} ${OPENBLAS_LIBRARIES} ${BLAS_LIBRARIES} ${CPU_FEATURES})
if ("${SD_ALL_OPS}" AND "${SD_BUILD_MINIFIER}") if ("${SD_ALL_OPS}" AND "${SD_BUILD_MINIFIER}")
message(STATUS "Building minifier...") message(STATUS "Building minifier...")
add_executable(minifier ../minifier/minifier.cpp ../minifier/graphopt.cpp) add_executable(minifier ../minifier/minifier.cpp ../minifier/graphopt.cpp)
target_link_libraries(minifier samediff_obj ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES} ${CPU_FEATURES}) target_link_libraries(minifier samediff_obj ${MKLDNN_LIBRARIES} ${ARMCOMPUTE_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES} ${CPU_FEATURES})
endif() endif()
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND "${CMAKE_CXX_COMPILER_VERSION}" VERSION_LESS 4.9) if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND "${CMAKE_CXX_COMPILER_VERSION}" VERSION_LESS 4.9)

View File

@ -0,0 +1,74 @@
################################################################################
# Copyright (c) 2020 Konduit K.K.
#
# This program and the accompanying materials are made available under the
# terms of the Apache License, Version 2.0 which is available at
# https://www.apache.org/licenses/LICENSE-2.0.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
# SPDX-License-Identifier: Apache-2.0
################################################################################
### Find ARM COMPUTE LIBRARY STATIC libraries
SET (COMPUTE_INCLUDE_DIRS
/usr/include
${ARMCOMPUTE_ROOT}
${ARMCOMPUTE_ROOT}/include
${ARMCOMPUTE_ROOT}/applications
${ARMCOMPUTE_ROOT}/applications/arm_compute
)
SET (COMPUTE_LIB_DIRS
/lib
/usr/lib
${ARMCOMPUTE_ROOT}
${ARMCOMPUTE_ROOT}/lib
${ARMCOMPUTE_ROOT}/build
)
find_path(ARMCOMPUTE_INCLUDE arm_compute/core/CL/ICLKernel.h
PATHS ${COMPUTE_INCLUDE_DIRS}
NO_DEFAULT_PATH NO_CMAKE_FIND_ROOT_PATH)
find_path(ARMCOMPUTE_INCLUDE arm_compute/core/CL/ICLKernel.h)
find_path(HALF_INCLUDE half/half.hpp)
find_path(HALF_INCLUDE half/half.hpp
PATHS ${ARMCOMPUTE_ROOT}/include
NO_DEFAULT_PATH NO_CMAKE_FIND_ROOT_PATH)
include_directories(SYSTEM ${HALF_INCLUDE})
# Find the Arm Compute libraries if not already specified
if (NOT DEFINED ARMCOMPUTE_LIBRARIES)
find_library(ARMCOMPUTE_LIBRARY NAMES arm_compute-static
PATHS ${COMPUTE_LIB_DIRS}
PATH_SUFFIXES "Release"
NO_DEFAULT_PATH NO_CMAKE_FIND_ROOT_PATH)
find_library(ARMCOMPUTE_CORE_LIBRARY NAMES arm_compute_core-static
PATHS ${COMPUTE_LIB_DIRS}
PATH_SUFFIXES "Release"
NO_DEFAULT_PATH NO_CMAKE_FIND_ROOT_PATH)
# In case it wasn't there, try a default search (will work in cases where
# the library has been installed into a standard location)
find_library(ARMCOMPUTE_LIBRARY NAMES arm_compute-static)
find_library(ARMCOMPUTE_CORE_LIBRARY NAMES arm_compute_core-static)
set(ARMCOMPUTE_LIBRARIES ${ARMCOMPUTE_LIBRARY} ${ARMCOMPUTE_CORE_LIBRARY} )
endif()
INCLUDE(FindPackageHandleStandardArgs)
FIND_PACKAGE_HANDLE_STANDARD_ARGS(ARMCOMPUTE REQUIRED_VARS ARMCOMPUTE_INCLUDE ARMCOMPUTE_LIBRARIES)

View File

@ -69,7 +69,7 @@ namespace cnpy {
} }
}; };
struct ND4J_EXPORT npz_t : public std::unordered_map<std::string, NpyArray> { struct ND4J_EXPORT npz_t : public std::map<std::string, NpyArray> {
void destruct() { void destruct() {
npz_t::iterator it = this->begin(); npz_t::iterator it = this->begin();
for(; it != this->end(); ++it) (*it).second.destruct(); for(; it != this->end(); ++it) (*it).second.destruct();

View File

@ -3,6 +3,8 @@
#cmakedefine HAVE_MKLDNN #cmakedefine HAVE_MKLDNN
#cmakedefine HAVE_ARMCOMPUTE
#cmakedefine MKLDNN_PATH "@MKLDNN_PATH@" #cmakedefine MKLDNN_PATH "@MKLDNN_PATH@"
#cmakedefine HAVE_OPENBLAS #cmakedefine HAVE_OPENBLAS

View File

@ -45,18 +45,18 @@ namespace sd {
DECLARE_TYPES(max_pool_with_argmax) { DECLARE_TYPES(max_pool_with_argmax) {
getOpDescriptor() getOpDescriptor()
->setAllowedInputTypes(sd::DataType::ANY) ->setAllowedInputTypes(sd::DataType::ANY)
->setAllowedOutputTypes(0, DataType::INHERIT) ->setAllowedOutputTypes(0, {ALL_FLOATS, ALL_INTS})
->setAllowedOutputTypes(1, {ALL_INTS}); ->setAllowedOutputTypes(1, {ALL_INDICES});
} }
DECLARE_SHAPE_FN(max_pool_with_argmax) { DECLARE_SHAPE_FN(max_pool_with_argmax) {
auto in = inputShape->at(0);
auto dtype = block.numD() ? D_ARG(0) : sd::DataType::INT64;
auto valuesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in));
auto indicesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in, dtype));
auto in = inputShape->at(0); return SHAPELIST(valuesShape, indicesShape);
auto valuesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in));
auto indicesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in, DataType::INT64));
return SHAPELIST(valuesShape, indicesShape);
} }
} }
} }

View File

@ -215,7 +215,9 @@ namespace helpers {
auto maxValue = T(0); //sd::math::nd4j_abs(compoundBuffer[xInitialIndex]); auto maxValue = T(0); //sd::math::nd4j_abs(compoundBuffer[xInitialIndex]);
auto result = -1; auto result = -1;
//auto loop = PRAGMA_THREADS_FOR { //auto loop = PRAGMA_THREADS_FOR {
auto start = column, stop = rowNum, increment = 1; auto start = column;
auto stop = rowNum;
auto increment = 1;
for (auto rowCounter = start; rowCounter < stop; rowCounter++) { for (auto rowCounter = start; rowCounter < stop; rowCounter++) {
Nd4jLong xPos[] = {rowCounter, column}; Nd4jLong xPos[] = {rowCounter, column};
auto xIndex = shape::getOffset(compoundShape, xPos, 0); auto xIndex = shape::getOffset(compoundShape, xPos, 0);

View File

@ -73,7 +73,7 @@ namespace helpers {
} }
void maxPoolingFunctor(sd::LaunchContext * context, sd::graph::Context& block, NDArray* input, NDArray* values, std::vector<int> const& params, NDArray* indices) { void maxPoolingFunctor(sd::LaunchContext * context, sd::graph::Context& block, NDArray* input, NDArray* values, std::vector<int> const& params, NDArray* indices) {
BUILD_SINGLE_SELECTOR(input->dataType(), maxPoolingFunctor_, (block, input, values, params, indices), FLOAT_TYPES); BUILD_SINGLE_SELECTOR(input->dataType(), maxPoolingFunctor_, (block, input, values, params, indices), LIBND4J_TYPES);
} }
} }

View File

@ -16,7 +16,8 @@
// //
// @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018 // @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018
// // implementation is based on following article:
// "MergeShuffle: A Very Fast, Parallel Random Permutation Algorithm", https://arxiv.org/abs/1508.03167
@ -31,96 +32,167 @@ namespace ops {
namespace helpers { namespace helpers {
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
// Fisher-Yates shuffle
template <typename T> template <typename T>
void randomShuffle_(NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) { static void fisherYates(sd::graph::RandomGenerator& rng, T* buff, const Nd4jLong& len, const Nd4jLong& ews, Nd4jLong ind) {
for(Nd4jLong i = len-1; i > 0; --i) {
const Nd4jLong j = rng.relativeLong(ind++) % (i + 1);
if(i != j)
math::nd4j_swap<T>(buff[i*ews], buff[j*ews]);
}
}
//////////////////////////////////////////////////////////////////////////
// mutual shuffle of two adjacent already shuffled ranges with length len1 and (totLen - len1) correspondingly
template <typename T>
static void mergeShuffle(sd::graph::RandomGenerator& rng, T* buff, const Nd4jLong& len1, const Nd4jLong& totLen, const Nd4jLong& ews, Nd4jLong ind) {
Nd4jLong beg = 0; // beginning
Nd4jLong mid = len1; // middle
while (true) {
if(rng.relativeLong(ind++) % 2) {
if(mid == totLen)
break;
math::nd4j_swap<T>(buff[ews * beg], buff[ews * mid++]);
} else {
if(beg == mid)
break;
}
++beg;
}
// fisherYates
while (beg < totLen) {
const Nd4jLong j = rng.relativeLong(ind++) % (beg + 1);
if(beg != j)
math::nd4j_swap<T>(buff[ews * beg], buff[ews * j]);
++beg;
}
}
//////////////////////////////////////////////////////////////////////////
template <typename T>
static void randomShuffle_(NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {
// check edge cases first
int temp;
const int firstDim = input.sizeAt(0); const int firstDim = input.sizeAt(0);
int temp;
if(input.lengthOf() == 1 || firstDim == 1) { if(input.lengthOf() == 1 || firstDim == 1) {
if(!isInplace) if(!isInplace)
output.assign(input); output.assign(input);
} }
else if (input.isVector() || shape::isLikeVector(input.shapeInfo(), temp)) { else if (shape::isCommonVector(input.shapeInfo(), temp)) {
// apply Fisher-Yates shuffle NDArray* arr = &input;
if(isInplace) {
//PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance().tadThreshold()) if (!isInplace) {
for(int i = firstDim-1; i > 0; --i) { output.assign(input);
int r = rng.relativeInt(i) % i; arr = &output;
if(i == r)
continue;
T t0 = input.t<T>(i);
T t1 = input.t<T>(r);
//math::nd4j_swap<T>(input(i), input(r));
input.r<T>(i) = t1;
input.r<T>(r) = t0;
}
} }
else {
std::vector<int> indices(firstDim);
std::iota(indices.begin(), indices.end(), 0);
output.p<T>(Nd4jLong(0), input.e<T>(0));
// FIXME: parallelism!! const Nd4jLong ews = arr->ews();
for(int i = firstDim-1; i > 0; --i) {
int r = rng.relativeInt(i) % i;
output.r<T>(i) = input.t<T>(indices[r]);
if(i == r)
continue;
output.r<T>(r) = input.t<T>(indices[i]); const Nd4jLong len = arr->lengthOf();
math::nd4j_swap<int>(indices[i], indices[r]); const Nd4jLong threshold = 1<<22; // this number was deduced from diagram in article
int power = 0;
while ((len >> power) > threshold)
++power;
const Nd4jLong numChunks = 1 << power;
auto funcFisherYates = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; ++i) {
Nd4jLong offset = (len * i) >> power;
Nd4jLong currLen = ((len * (i + 1)) >> power) - offset;
fisherYates<T>(rng, arr->bufferAsT<T>() + offset*ews, currLen, ews, offset);
} }
rng.rewindH(firstDim-1); };
}
auto funcMerge = PRAGMA_THREADS_FOR {
for (int64_t i = start, k = 1; i < stop; i += increment, ++k) {
Nd4jLong offset = len * i >> power;
Nd4jLong len1 = (len * (i + increment/2) >> power) - offset;
Nd4jLong totLen = (len * (i + increment) >> power) - offset;
mergeShuffle<T>(rng, arr->bufferAsT<T>() + offset*ews, len1, totLen, ews, len * k + offset);
}
};
samediff::Threads::parallel_for(funcFisherYates, 0, numChunks);
for (int j = 1; j < numChunks; j += j)
samediff::Threads::parallel_for(funcMerge, 0, numChunks, 2*j);
// #pragma omp parallel for
// for (uint i = 0; i < numChunks; ++i) {
// Nd4jLong offset = (len * i) >> power;
// Nd4jLong currLen = ((len * (i + 1)) >> power) - offset;
// fisherYates<T>(rng, arr->bufferAsT<T>() + offset*ews, currLen, ews, offset);
// }
// for (uint j = 1; j < numChunks; j += j) {
// #pragma omp parallel for
// for (auto i = 0; i < numChunks; i += 2*j) {
// Nd4jLong offset = len * i >> power;
// Nd4jLong len1 = (len * (i + j) >> power) - offset;
// Nd4jLong totLen = (len * (i + 2*j) >> power) - offset;
// mergeShuffle(rng, arr->bufferAsT<T>() + offset*ews, len1, totLen, ews, len * j + offset);
// }
// }
rng.rewindH((len + 1) * power);
} }
else { else {
// evaluate sub-arrays list of input array through all dimensions excluding first one auto dimsToExclude = ShapeUtils::evalDimsToExclude(input.rankOf(), {0});
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input.rankOf(), {0});
auto subArrsListIn = input.allTensorsAlongDimension(dimensions);
// apply Fisher-Yates shuffle
if(isInplace) { if(isInplace) {
//PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance().elementwiseThreshold())
for(int i = firstDim - 1; i > 0; --i) {
int r = rng.relativeInt(i) % i;
if(i == r) auto subArrsList = input.allTensorsAlongDimension(dimsToExclude);
continue;
subArrsListIn.at(i)->swapUnsafe(*subArrsListIn.at(r)); // Fisher-Yates shuffle
for(int i = firstDim - 1; i > 0; --i) {
const int j = rng.relativeInt(i) % (i + 1);
if(i != j)
subArrsList.at(i)->swapUnsafe(*subArrsList.at(j));
} }
} }
else { else {
// evaluate sub-arrays list of output array through all dimensions excluding first one
auto subArrsListOut = output.allTensorsAlongDimension(dimensions); auto subArrsListIn = input.allTensorsAlongDimension(dimsToExclude);
auto subArrsListOut = output.allTensorsAlongDimension(dimsToExclude);
std::vector<int> indices(firstDim); std::vector<int> indices(firstDim);
std::iota(indices.begin(), indices.end(), 0); std::iota(indices.begin(), indices.end(), 0); // 0,1,2,3, ... firstDim-1
bool isZeroShuffled = false;
//PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance().tadThreshold()) // shuffle indices
for(int i = firstDim - 1; i > 0; --i) { fisherYates<int>(rng, indices.data(), firstDim, 1, 0);
int r = rng.relativeInt(i) % i;
subArrsListOut.at(i)->assign(subArrsListIn.at(indices[r])); auto func = PRAGMA_THREADS_FOR {
if(r == 0)
isZeroShuffled = true; for (auto i = start; i < stop; ++i)
if(i == r) subArrsListOut.at(i)->assign(subArrsListIn.at(indices[i]));
continue; };
subArrsListOut.at(r)->assign(subArrsListIn.at(indices[i]));
math::nd4j_swap<int>(indices[i], indices[r]); samediff::Threads::parallel_for(func, 0, firstDim);
}
if(!isZeroShuffled)
subArrsListOut.at(0)->assign(subArrsListIn.at(0));
} }
rng.rewindH(firstDim-1); rng.rewindH(firstDim-1);
} }
} }
void randomShuffle(sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) { void randomShuffle(sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {
BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (input, output, rng, isInplace), LIBND4J_TYPES); BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (input, output, rng, isInplace), LIBND4J_TYPES);
} }
} }
} }
} }

View File

@ -53,7 +53,7 @@ __global__ static void concatCuda(void* pVx, void* pxShapeInfo, void* vz, const
int coords[MAX_RANK]; int coords[MAX_RANK];
for (uint64_t i = tid; i < zLen; i += totalThreads) { for (Nd4jLong i = tid; i < zLen; i += totalThreads) {
shape::index2coords(i, zShapeInfo, coords); shape::index2coords(i, zShapeInfo, coords);
const auto zOffset = shape::getOffset(zShapeInfo, coords); const auto zOffset = shape::getOffset(zShapeInfo, coords);
@ -162,9 +162,9 @@ void concat(sd::LaunchContext * context, const std::vector<const NDArray*>& inAr
// } // }
// else { // general (slower) case // else { // general (slower) case
const int threadsPerBlock = 256; const int threadsPerBlock = MAX_NUM_THREADS / 2;
const int blocksPerGrid = 512; const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
const int sharedMem = 512; const int sharedMem = 256;
// prepare arrays of pointers on buffers and shapes // prepare arrays of pointers on buffers and shapes
std::vector<const void*> hInBuffers(numOfInArrs); std::vector<const void*> hInBuffers(numOfInArrs);

View File

@ -88,7 +88,7 @@ namespace helpers {
void maxPoolingFunctor(sd::LaunchContext * context, sd::graph::Context& block, NDArray* input, NDArray* values, std::vector<int> const& params, NDArray* indices) { void maxPoolingFunctor(sd::LaunchContext * context, sd::graph::Context& block, NDArray* input, NDArray* values, std::vector<int> const& params, NDArray* indices) {
NDArray::prepareSpecialUse({values, indices}, {input}); NDArray::prepareSpecialUse({values, indices}, {input});
auto yType = indices == nullptr ? sd::DataType::INT64 : indices->dataType(); auto yType = indices == nullptr ? sd::DataType::INT64 : indices->dataType();
BUILD_DOUBLE_SELECTOR(input->dataType(), yType, maxPoolingFunctor_, (block, input, values, params, indices), FLOAT_TYPES, INDEXING_TYPES); BUILD_DOUBLE_SELECTOR(input->dataType(), yType, maxPoolingFunctor_, (block, input, values, params, indices), LIBND4J_TYPES, INDEXING_TYPES);
NDArray::registerSpecialUse({values, indices}, {input}); NDArray::registerSpecialUse({values, indices}, {input});
} }

View File

@ -0,0 +1,228 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com)
// implemented algorithm is GPU adaptation of algorithm described in following article:
// "MergeShuffle: A Very Fast, Parallel Random Permutation Algorithm", https://arxiv.org/abs/1508.03167
//
#include<ops/declarable/helpers/transforms.h>
#include <array/ResultSet.h>
#include <numeric>
#include <execution/Threads.h>
#include <helpers/ShapeUtils.h>
#include <helpers/PointersManager.h>
namespace sd {
namespace ops {
namespace helpers {
//////////////////////////////////////////////////////////////////////////
template <typename T>
static __global__ void fisherYatesCuda(sd::graph::RandomGenerator* rng, void* vx, const Nd4jLong ews, const Nd4jLong len, const int power) {
T* x = reinterpret_cast<T*>(vx);
__shared__ T* shmem, temp;
__shared__ Nd4jLong ind, blockOffset, lenPerBlock;
if (threadIdx.x == 0) {
extern __shared__ unsigned char sharedMemory[];
shmem = reinterpret_cast<T*>(sharedMemory);
blockOffset = (len * blockIdx.x) >> power;
lenPerBlock = ((len * (blockIdx.x + 1)) >> power) - blockOffset;
ind = blockOffset;
}
__syncthreads();
// copy from global memory to shared memory
if(threadIdx.x < lenPerBlock)
shmem[threadIdx.x] = x[(blockOffset + threadIdx.x) * ews];
__syncthreads();
// *** apply Fisher-Yates shuffle to lenPerBlock number of elements
if (threadIdx.x == 0) {
for(Nd4jLong i = lenPerBlock - 1; i > 0; --i) {
const Nd4jLong j = rng->relativeLong(ind++) % (i + 1);
if(i != j) {
temp = shmem[i];
shmem[i] = shmem[j];
shmem[j] = temp;
}
}
}
__syncthreads();
// copy from shared memory to global memory
if(threadIdx.x < lenPerBlock)
x[(blockOffset + threadIdx.x) * ews] = shmem[threadIdx.x];
}
template <typename T>
static __global__ void mergeShuffleCuda(sd::graph::RandomGenerator* rng, void* vx, const Nd4jLong ews, const Nd4jLong len, const int power, const Nd4jLong iterNum) {
T* x = reinterpret_cast<T*>(vx);
__shared__ Nd4jLong ind, blockOffset, factor, beg, mid, totLen, iterExp;
// *** apply mergeShuffle algorithm
if(threadIdx.x == 0) {
factor = blockIdx.x << iterNum;
iterExp = 1 << (iterNum - 1);
blockOffset = (len * factor) >> power;
mid = ((len * (factor + iterExp)) >> power) - blockOffset; // middle
totLen = ((len * (factor + 2*iterExp)) >> power) - blockOffset;
ind = iterNum * len + blockOffset;
beg = 0; // beginning
// printf("m %lld, blockIdx.x %lld, factor %lld, blockOffset %lld, mid %lld, totLen %lld \n", m,k,factor,blockOffset,mid,totLen);
while (true) {
if(rng->relativeLong(ind++) % 2) {
if(mid == totLen)
break;
math::nd4j_swap<T>(x[(blockOffset + beg) * ews], x[(blockOffset + mid++) * ews]);
} else {
if(beg == mid)
break;
}
++beg;
}
// Fisher-Yates
while (beg < totLen) {
const Nd4jLong e = rng->relativeLong(ind++) % (beg + 1);
if(beg != e)
math::nd4j_swap<T>(x[(blockOffset + beg) * ews], x[(blockOffset + e) * ews]);
++beg;
}
}
}
//////////////////////////////////////////////////////////////////////////
// Fisher-Yates shuffle
template <typename T>
static void fisherYates(sd::graph::RandomGenerator& rng, T* buff, const Nd4jLong& len, const Nd4jLong& ews, Nd4jLong ind) {
for(Nd4jLong i = len-1; i > 0; --i) {
const Nd4jLong j = rng.relativeLong(ind++) % (i + 1);
if(i != j)
math::nd4j_swap<T>(buff[i*ews], buff[j*ews]);
}
}
//////////////////////////////////////////////////////////////////////////
template <typename T>
static void randomShuffle_(sd::LaunchContext* context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {
const int firstDim = input.sizeAt(0);
int temp;
if(input.lengthOf() == 1 || firstDim == 1) {
if(!isInplace)
output.assign(input);
}
else if (shape::isCommonVector(input.shapeInfo(), temp)) {
NDArray* arr = &input;
if (!isInplace) {
output.assign(input);
arr = &output;
}
const Nd4jLong len = arr->lengthOf();
const int threadsPerBlock = MAX_NUM_THREADS;
int power = 0;
while ((len >> power) > threadsPerBlock)
++power;
const int blocksPerGrid = 1 << power;
const int sharedMem = threadsPerBlock * input.sizeOfT() + 256;
PointersManager manager(context, "NDArray::randomShuffle cuda");
sd::graph::RandomGenerator* pRng = reinterpret_cast<sd::graph::RandomGenerator*>(manager.replicatePointer(&rng, sizeof(sd::graph::RandomGenerator)));
NDArray::prepareSpecialUse({arr}, {arr});
fisherYatesCuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *context->getCudaStream()>>>(pRng, arr->specialBuffer(), arr->ews(), len, power);
for (Nd4jLong j = 1, i = 1; j < blocksPerGrid; j += j, ++i)
mergeShuffleCuda<T><<<blocksPerGrid/(2*j), threadsPerBlock, 256, *context->getCudaStream()>>>(pRng, arr->specialBuffer(), arr->ews(), len, power, i);
NDArray::registerSpecialUse({arr}, {arr});
manager.synchronize();
rng.rewindH((len + 1) * power);
}
else {
auto dimsToExclude = ShapeUtils::evalDimsToExclude(input.rankOf(), {0});
if(isInplace) {
auto subArrsList = input.allTensorsAlongDimension(dimsToExclude);
// Fisher-Yates shuffle
for(int i = firstDim - 1; i > 0; --i) {
const int j = rng.relativeInt(i) % (i + 1);
if(i != j)
subArrsList.at(i)->swapUnsafe(*subArrsList.at(j));
}
}
else {
auto subArrsListIn = input.allTensorsAlongDimension(dimsToExclude);
auto subArrsListOut = output.allTensorsAlongDimension(dimsToExclude);
std::vector<int> indices(firstDim);
std::iota(indices.begin(), indices.end(), 0); // 0,1,2,3, ... firstDim-1
// shuffle indices
fisherYates<int>(rng, indices.data(), firstDim, 1, 0);
auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; ++i)
subArrsListOut.at(i)->assign(subArrsListIn.at(indices[i]));
};
samediff::Threads::parallel_for(func, 0, firstDim);
}
rng.rewindH(firstDim-1);
}
}
/////////////////////////////////////////////////////////////////////////
void randomShuffle(sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {
BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (context, input, output, rng, isInplace), LIBND4J_TYPES);
}
// BUILD_SINGLE_TEMPLATE(template void randomShuffle_, (sd::LaunchContext* context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace), LIBND4J_TYPES);
}
}
}

View File

@ -300,129 +300,6 @@ void tileBP(sd::LaunchContext * context, const NDArray& gradO /*input*/, NDArray
manager.synchronize(); manager.synchronize();
} }
template <typename T>
static __global__ void swapShuffleKernel(T* input, Nd4jLong const* shape, Nd4jLong firstDim, sd::graph::RandomGenerator* rng) {
auto tid = blockIdx.x * blockDim.x;
auto step = blockDim.x * gridDim.x;
for (int i = firstDim - 1 - tid - threadIdx.x; i > 0; i -= step) {
int r = rng->relativeInt(i) % i;
if (i != r) {
const auto iOffset = shape::getIndexOffset(i, shape);
const auto rOffset = shape::getIndexOffset(r, shape);
T e0 = input[iOffset];
T e1 = input[rOffset];
//math::nd4j_swap<T>(input(i), input(r));
input[iOffset] = e1;
input[rOffset] = e0;
}
}
}
template <typename T>
static __global__ void fillShuffleKernel(T* input, Nd4jLong const* inputShape, T* output, Nd4jLong const* outputShape, Nd4jLong firstDim, int* indices, sd::graph::RandomGenerator* rng) {
// PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance().tadThreshold())
auto tid = blockIdx.x * blockDim.x;
auto step = blockDim.x * gridDim.x;
for(int i = firstDim - 1 - tid - threadIdx.x; i > 0; i -= step) {
int r = rng->relativeInt(i) % i;
output[shape::getIndexOffset(i, outputShape)] = input[shape::getIndexOffset(indices[r], inputShape)];
if(i != r) {
output[shape::getIndexOffset(r, outputShape)] = input[shape::getIndexOffset(indices[i], inputShape)];
// output.p(r, input.e<T>(indices[i]));
// math::nd4j_swap<int>(indices[i], indices[r]);
atomicExch(&indices[i], indices[r]);
}
}
}
//////////////////////////////////////////////////////////////////////////
template <typename T>
void randomShuffle_(sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {
// check edge cases first
int temp;
const int firstDim = input.sizeAt(0);
auto stream = context->getCudaStream();
NDArray::prepareSpecialUse({&output}, {&input});
if(input.lengthOf() == 1 || firstDim == 1) {
if(!isInplace)
output.assign(input);
}
else if (input.isVector() || shape::isLikeVector(input.shapeInfo(), temp)) {
// apply Fisher-Yates shuffle
sd::graph::RandomGenerator* dRandom = nullptr;
cudaMalloc(&dRandom, sizeof(sd::graph::RandomGenerator));
cudaMemcpy(dRandom, &rng, sizeof(sd::graph::RandomGenerator), cudaMemcpyHostToDevice);
T* inputBuf = reinterpret_cast<T*>(input.specialBuffer());
if(isInplace) {
swapShuffleKernel<T><<<128, 256, 1024, *stream>>>(inputBuf, input.specialShapeInfo(), firstDim, dRandom);
}
else {
std::vector<int> indices(firstDim);
std::iota(indices.begin(), indices.end(), 0);
cudaMemcpy(output.specialBuffer(), input.specialBuffer(), sizeof(T), cudaMemcpyDeviceToDevice);
//output.p<T>(Nd4jLong(0), input.e<T>(0));
PointersManager pointersManager(context, "helper::randomShuffle_");
int* indicesDev = reinterpret_cast<int*>(pointersManager.replicatePointer(indices.data(), indices.size() * sizeof(int)));
T* outputBuf = reinterpret_cast<T*>(output.specialBuffer());
fillShuffleKernel<T><<<128, 256, 1024, *stream>>>(inputBuf, input.specialShapeInfo(), outputBuf, output.specialShapeInfo(), firstDim, indicesDev, dRandom);
pointersManager.synchronize();
}
// rng.rewindH(firstDim - 1);
cudaFree(dRandom);
}
else {
// evaluate sub-arrays list of input array through all dimensions excluding first one
std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input.rankOf(), {0});
auto subArrsListIn = input.allTensorsAlongDimension(dimensions);
// apply Fisher-Yates shuffle
if(isInplace) {
for(int i = firstDim - 1; i > 0; --i) {
int r = rng.relativeInt(i) % i;
if(i != r)
subArrsListIn.at(i)->swapUnsafe(*subArrsListIn.at(r));
}
}
else {
// evaluate sub-arrays list of output array through all dimensions excluding first one
auto subArrsListOut = output.allTensorsAlongDimension(dimensions);
std::vector<int> indices(firstDim);
std::iota(indices.begin(), indices.end(), 0);
bool isZeroShuffled = false;
for(int i = firstDim - 1; i > 0; --i) {
int r = rng.relativeInt(i) % i;
subArrsListOut.at(i)->assign(subArrsListIn.at(indices[r]));
if(r == 0)
isZeroShuffled = true;
if(i != r) {
subArrsListOut.at(r)->assign(subArrsListIn.at(indices[i]));
math::nd4j_swap<int>(indices[i], indices[r]);
}
}
if(!isZeroShuffled)
subArrsListOut.at(0)->assign(subArrsListIn.at(0));
}
rng.rewindH(firstDim-1);
}
NDArray::registerSpecialUse({&output}, {&input});
}
void randomShuffle(sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {
BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (context, input, output, rng, isInplace), LIBND4J_TYPES);
}
BUILD_SINGLE_TEMPLATE(template void randomShuffle_, (sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace), LIBND4J_TYPES);
////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////
void eye(sd::LaunchContext * context, NDArray& output) { void eye(sd::LaunchContext * context, NDArray& output) {

View File

@ -0,0 +1,278 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
// Created by Abdelrauf 2020
#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <system/platform_boilerplate.h>
#include <ops/declarable/helpers/convolutions.h>
#include <cstdint>
#include <helpers/LoopsCoordsHelper.h>
#include "armcomputeUtils.h"
namespace sd {
namespace ops {
namespace platforms {
Arm_DataType getArmType ( const DataType &dType){
Arm_DataType ret;
switch (dType){
case HALF :
ret = Arm_DataType::F16;
break;
case FLOAT32 :
ret = Arm_DataType::F32;
break;
case DOUBLE :
ret = Arm_DataType::F64;
break;
case INT8 :
ret = Arm_DataType::S8;
break;
case INT16 :
ret = Arm_DataType::S16;
break;
case INT32 :
ret = Arm_DataType::S32;
break;
case INT64 :
ret = Arm_DataType::S64;
break;
case UINT8 :
ret = Arm_DataType::U8;
break;
case UINT16 :
ret = Arm_DataType::U16;
break;
case UINT32 :
ret = Arm_DataType::U32;
break;
case UINT64 :
ret = Arm_DataType::U64;
break;
case BFLOAT16 :
ret = Arm_DataType::BFLOAT16;
break;
default:
ret = Arm_DataType::UNKNOWN;
};
return ret;
}
bool isArmcomputeFriendly(const NDArray& arr) {
auto dType = getArmType(arr.dataType());
int rank = (int)(arr.rankOf());
return dType != Arm_DataType::UNKNOWN &&
rank<=arm_compute::MAX_DIMS &&
arr.ordering() == 'c' &&
arr.ews()==1 &&
shape::strideDescendingCAscendingF(arr.shapeInfo()) == true;
}
Arm_TensorInfo getArmTensorInfo(int rank, Nd4jLong* bases,sd::DataType ndArrayType, arm_compute::DataLayout layout) {
constexpr int numChannels = 1;
auto dType = getArmType(ndArrayType);
Arm_TensorShape shape;
shape.set_num_dimensions(rank);
for (int i = 0, j = rank - 1; i < rank; i++, j--) {
shape[i] = static_cast<uint32_t>(bases[j]);
}
// fill the rest unused with 1
for (int i = rank; i < arm_compute::MAX_DIMS; i++) {
shape[i] = 1;
}
return Arm_TensorInfo(shape, numChannels, dType, layout);
}
Arm_TensorInfo getArmTensorInfo(const NDArray& arr,
arm_compute::DataLayout layout) {
auto dType = getArmType(arr.dataType());
//
constexpr int numChannels = 1;
int rank = (int)(arr.rankOf());
auto bases = arr.shapeOf();
auto arrStrides = arr.stridesOf();
// https://arm-software.github.io/ComputeLibrary/v20.05/_dimensions_8h_source.xhtml
// note: underhood it is stored as std::array<T, num_max_dimensions> _id;
// TensorShape is derived from Dimensions<uint32_t>
// as well as Strides : public Dimensions<uint32_t>
Arm_TensorShape shape;
Arm_Strides strides;
shape.set_num_dimensions(rank);
strides.set_num_dimensions(rank);
size_t element_size = arm_compute::data_size_from_type(dType);
for (int i = 0, j = rank - 1; i < rank; i++, j--) {
shape[i] = static_cast<uint32_t>(bases[j]);
strides[i] = static_cast<uint32_t>(arrStrides[j]) * element_size;
}
// fill the rest unused with 1
for (int i = rank; i < arm_compute::MAX_DIMS; i++) {
shape[i] = 1;
}
size_t total_size;
size_t size_ind = rank - 1;
total_size = shape[size_ind] * strides[size_ind];
Arm_TensorInfo info;
info.init(shape, numChannels, dType, strides, 0, total_size);
info.set_data_layout(layout);
return info;
}
Arm_Tensor getArmTensor(const NDArray& arr, arm_compute::DataLayout layout) {
// - Ownership of the backing memory is not transferred to the tensor itself.
// - The tensor mustn't be memory managed.
// - Padding requirements should be accounted by the client code.
// In other words, if padding is required by the tensor after the function
// configuration step, then the imported backing memory should account for it.
// Padding can be checked through the TensorInfo::padding() interface.
// Import existing pointer as backing memory
auto info = getArmTensorInfo(arr, layout);
Arm_Tensor tensor;
tensor.allocator()->init(info);
void* buff = (void*)arr.buffer();
tensor.allocator()->import_memory(buff);
return tensor;
}
void copyFromTensor(const Arm_Tensor& inTensor, NDArray& output) {
//only for C order
//only for C order
if (output.ordering() != 'c') return;
auto shapeInfo = output.shapeInfo();
auto bases = &(shapeInfo[1]);
Nd4jLong rank = shapeInfo[0];
auto strides = output.stridesOf();
int width = bases[rank - 1];
uint8_t* outputBuffer = (uint8_t*)output.buffer();
size_t offset = 0;
arm_compute::Window window;
arm_compute::Iterator tensor_it(&inTensor, window);
int element_size = inTensor.info()->element_size();
window.use_tensor_dimensions(inTensor.info()->tensor_shape(), /* first_dimension =*/arm_compute::Window::DimY);
// if (output.ews() == 1) {
auto copySize = width * element_size;
auto dest = outputBuffer;
arm_compute::execute_window_loop(window, [&](const arm_compute::Coordinates& id)
{
auto src = tensor_it.ptr();
memcpy(dest, src, copySize);
dest += copySize;
},
tensor_it);
// }
// else {
// Nd4jLong coords[MAX_RANK] = {};
// if(strides[rank-1]!=1){
// throw std::runtime_error( "not implemented for subarrays whose last stride is not 1");
// //TODO: implement to work with all subarrays properly
// }
// arm_compute::execute_window_loop(window, [&](const arm_compute::Coordinates& id)
// {
// auto src = tensor_it.ptr();
// auto dest = outputBuffer + offset * element_size;
// memcpy(dest, src, width * element_size);
// offset = sd::inc_coords(bases, strides, coords, offset, rank, 1);
// },
// tensor_it);
// }
}
void copyToTensor(const NDArray& input, Arm_Tensor& outTensor) {
//only for C order
if (input.ordering() != 'c') return;
auto shapeInfo = input.shapeInfo();
auto bases = &(shapeInfo[1]);
Nd4jLong rank = shapeInfo[0];
auto strides = input.stridesOf();
uint8_t *inputBuffer = (uint8_t*)input.buffer();
int width = bases[rank - 1];
size_t offset = 0;
arm_compute::Window window;
arm_compute::Iterator tensor_it(&outTensor, window);
int element_size = outTensor.info()->element_size();
window.use_tensor_dimensions(outTensor.info()->tensor_shape(), /* first_dimension =*/arm_compute::Window::DimY);
// if (input.ews() == 1) {
auto copySize = width * element_size;
auto src = inputBuffer;
arm_compute::execute_window_loop(window, [&](const arm_compute::Coordinates& id)
{
auto dest = tensor_it.ptr();
memcpy(dest,src, copySize);
src += copySize;
},
tensor_it);
// }
// else {
// Nd4jLong coords[MAX_RANK] = {};
// if(strides[rank-1]!=1){
// throw std::runtime_error( "not implemented for subarrays whose last stride is not 1");
// //TODO: implement to work with all subarrays properly
// }
// arm_compute::execute_window_loop(window, [&](const arm_compute::Coordinates& id)
// {
// auto dest = tensor_it.ptr();
// auto src = inputBuffer + offset * element_size;
// offset = sd::inc_coords(bases, strides, coords, offset, rank, 1);
// },
// tensor_it);
// }
}
// armcompute should be built with debug option
void print_tensor(Arm_ITensor& tensor, const char* msg) {
auto info = tensor.info();
auto padding = info->padding();
std::cout << msg << "\ntotal: " << info->total_size() << "\n";
for (int i = 0; i < arm_compute::MAX_DIMS; i++) {
std::cout << info->dimension(i) << ",";
}
std::cout << std::endl;
for (int i = 0; i < arm_compute::MAX_DIMS; i++) {
std::cout << info->strides_in_bytes()[i] << ",";
}
std::cout << "\npadding: l " << padding.left << ", r " << padding.right
<< ", t " << padding.top << ", b " << padding.bottom << std::endl;
#ifdef ARM_COMPUTE_ASSERTS_ENABLED
//note it did not print correctly fro NHWC
std::cout << msg << ":\n";
tensor.print(std::cout);
std::cout << std::endl;
#endif
}
}
}
}

View File

@ -0,0 +1,133 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
#ifndef DEV_TESTSARMCOMPUTEUTILS_H
#define DEV_TESTSARMCOMPUTEUTILS_H
#include <legacy/NativeOps.h>
#include <array/NDArray.h>
#include <graph/Context.h>
#include <ops/declarable/PlatformHelper.h>
#include <system/platform_boilerplate.h>
#include <arm_compute/runtime/NEON/NEFunctions.h>
#include <arm_compute/core/Types.h>
#include <arm_compute/core/TensorInfo.h>
#include <arm_compute/core/TensorShape.h>
#include <arm_compute/core/Strides.h>
#include <arm_compute/core/Helpers.h>
#include <arm_compute/core/ITensor.h>
#include <arm_compute/core/Types.h>
#include <arm_compute/core/Validate.h>
#include <arm_compute/core/Window.h>
#include <arm_compute/runtime/Tensor.h>
#include <arm_compute/runtime/TensorAllocator.h>
#include <iostream>
using namespace samediff;
namespace sd {
namespace ops {
namespace platforms {
using Arm_DataType = arm_compute::DataType;
using Arm_Tensor = arm_compute::Tensor;
using Arm_ITensor = arm_compute::ITensor;
using Arm_TensorInfo = arm_compute::TensorInfo;
using Arm_TensorShape = arm_compute::TensorShape;
using Arm_Strides = arm_compute::Strides;
/**
* Here we actually declare our platform helpers
*/
DECLARE_PLATFORM(maxpool2d, ENGINE_CPU);
DECLARE_PLATFORM(avgpool2d, ENGINE_CPU);
//utils
Arm_DataType getArmType(const sd::DataType& dType);
Arm_TensorInfo getArmTensorInfo(int rank, Nd4jLong* bases, sd::DataType ndArrayType, arm_compute::DataLayout layout = arm_compute::DataLayout::UNKNOWN);
Arm_TensorInfo getArmTensorInfo(const NDArray& arr, arm_compute::DataLayout layout = arm_compute::DataLayout::UNKNOWN);
Arm_Tensor getArmTensor(const NDArray& arr, arm_compute::DataLayout layout = arm_compute::DataLayout::UNKNOWN);
void copyFromTensor(const Arm_Tensor& inTensor, NDArray& output);
void copyToTensor(const NDArray& input, Arm_Tensor& outTensor);
void print_tensor(Arm_ITensor& tensor, const char* msg);
bool isArmcomputeFriendly(const NDArray& arr);
template<typename F>
class ArmFunction {
public:
template<typename ...Args>
void configure(NDArray *input , NDArray *output, arm_compute::DataLayout layout, Args&& ...args) {
auto inInfo = getArmTensorInfo(*input, layout);
auto outInfo = getArmTensorInfo(*output, layout);
in.allocator()->init(inInfo);
out.allocator()->init(outInfo);
armFunction.configure(&in,&out,std::forward<Args>(args) ...);
if (in.info()->has_padding()) {
//allocate and copy
in.allocator()->allocate();
//copy
copyToTensor(*input, in);
}
else {
//import buffer
void* buff = input->buffer();
in.allocator()->import_memory(buff);
}
if (out.info()->has_padding()) {
//store pointer to our array to copy after run
out.allocator()->allocate();
outNd = output;
}
else {
//import
void* buff = output->buffer();
out.allocator()->import_memory(buff);
}
}
void run() {
armFunction.run();
if (outNd) {
copyFromTensor(out, *outNd);
}
}
private:
Arm_Tensor in;
Arm_Tensor out;
NDArray *outNd=nullptr;
F armFunction{};
};
}
}
}
#endif //DEV_TESTSARMCOMPUTEUTILS_H

View File

@ -0,0 +1,106 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
// Created by Abdelrauf (rauf@konduit.ai) 2020
#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <system/platform_boilerplate.h>
#include <ops/declarable/helpers/convolutions.h>
#include "armcomputeUtils.h"
namespace sd {
namespace ops {
namespace platforms {
//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(avgpool2d, ENGINE_CPU) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
const auto kH = INT_ARG(0);
const auto kW = INT_ARG(1);
const auto sH = INT_ARG(2);
const auto sW = INT_ARG(3);
auto pH = INT_ARG(4);
auto pW = INT_ARG(5);
const auto dH = INT_ARG(6);
const auto dW = INT_ARG(7);
const auto paddingMode = INT_ARG(8);
const auto extraParam0 = INT_ARG(9);
const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D ARMCOMPUTE op: input should have rank of 4, but got %i instead", input->rankOf());
REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D ARMCOMPUTE op: dilation must not be zero, but got instead {%i, %i}", dH, dW);
bool exclude_padding= (extraParam0 == 0) ? true : false;
auto dataLayout = isNCHW ? arm_compute::DataLayout::NCHW : arm_compute::DataLayout::NHWC;
// Calculate individual paddings
unsigned int pad_left, pad_top, pad_right, pad_bottom;
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
if(paddingMode){
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
}
pad_left = pW;
pad_top = pH;
pad_right = (oW - 1) * sW - iW + kW - pW ;
pad_bottom = (oH - 1) * sH - iH + kH - pH ;
#if 0
nd4j_printf("avgpool kH = %d, kW = %d, sH = %d, sW = %d , pH = %d , pW = %d, dH = %d, dW = %d, paddingMode = %d , isNCHW %d exclude pad %d \n" , kH , kW , sH , sW , pH
, pW , dH , dW , paddingMode,isNCHW?1:0 ,exclude_padding?1:0);
#endif
auto poolPad = arm_compute::PadStrideInfo(sW, sH, pad_left,pad_right, pad_top, pad_bottom, arm_compute::DimensionRoundingType::FLOOR);
auto poolInfo = arm_compute::PoolingLayerInfo(arm_compute::PoolingType::AVG, arm_compute::Size2D(kW, kH), dataLayout, poolPad, exclude_padding);
ArmFunction<arm_compute::NEPoolingLayer> pool;
pool.configure(input,output, dataLayout, poolInfo);
pool.run(); // run function
return Status::OK();
}
//////////////////////////////////////////////////////////////////////////
PLATFORM_CHECK(avgpool2d, ENGINE_CPU) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
const int dH = INT_ARG(6);
const int dW = INT_ARG(7);
// Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32
auto dTypeInput = getArmType(input->dataType());
auto dTypeOutput = getArmType(output->dataType());
bool is_supported = dH==1 && dW==1 && isArmcomputeFriendly(*input) && isArmcomputeFriendly(*output)
&& (dTypeInput ==Arm_DataType::F32)
&& (dTypeOutput ==Arm_DataType::F32);
return is_supported;
}
}
}
}

View File

@ -0,0 +1,106 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
// Created by Abdelrauf 2020
#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <system/platform_boilerplate.h>
#include <ops/declarable/helpers/convolutions.h>
#include "armcomputeUtils.h"
namespace sd {
namespace ops {
namespace platforms {
//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(maxpool2d, ENGINE_CPU) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D ARMCOMPUTE OP: input array should have rank of 4, but got %i instead", input->rankOf());
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
const int kH = INT_ARG(0);
const int kW = INT_ARG(1);
const int sH = INT_ARG(2);
const int sW = INT_ARG(3);
int pH = INT_ARG(4);
int pW = INT_ARG(5);
const int dH = INT_ARG(6);
const int dW = INT_ARG(7);
const int paddingMode = INT_ARG(8);
// const int extraParam0 = INT_ARG(9);
const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 1-NHWC, 0-NCHW
REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D MKLDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW);
auto dataLayout = isNCHW ? arm_compute::DataLayout::NCHW : arm_compute::DataLayout::NHWC;
// Calculate individual paddings
unsigned int pad_left, pad_top, pad_right, pad_bottom;
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
if(paddingMode){
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
}
pad_left = pW;
pad_top = pH;
pad_right = (oW - 1) * sW - iW + kW - pW ;
pad_bottom = (oH - 1) * sH - iH + kH - pH ;
#if 0
nd4j_printf("avgpool kH = %d, kW = %d, sH = %d, sW = %d , pH = %d , pW = %d, dH = %d, dW = %d, paddingMode = %d , isNCHW %d exclude pad %d \n" , kH , kW , sH , sW , pH
, pW , dH , dW , paddingMode,isNCHW?1:0 ,exclude_padding?1:0);
#endif
auto poolPad = arm_compute::PadStrideInfo(sW, sH, pad_left,pad_right, pad_top, pad_bottom, arm_compute::DimensionRoundingType::FLOOR);
auto poolInfo = arm_compute::PoolingLayerInfo(arm_compute::PoolingType::MAX, arm_compute::Size2D(kW, kH), dataLayout, poolPad);
ArmFunction<arm_compute::NEPoolingLayer> pool;
pool.configure(input,output, dataLayout, poolInfo);
pool.run(); // run function
return Status::OK();
}
//////////////////////////////////////////////////////////////////////////
PLATFORM_CHECK(maxpool2d, ENGINE_CPU) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
const int dH = INT_ARG(6);
const int dW = INT_ARG(7);
// Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32
auto dTypeInput = getArmType(input->dataType());
auto dTypeOutput = getArmType(output->dataType());
bool is_supported = dH==1 && dW==1 && isArmcomputeFriendly(*input) && isArmcomputeFriendly(*output)
&& (dTypeInput ==Arm_DataType::F32)
&& (dTypeOutput ==Arm_DataType::F32);
return is_supported;
}
}
}
}

View File

@ -3963,9 +3963,6 @@ namespace simdOps {
} }
#endif #endif
#ifndef __clang__
#pragma omp declare simd uniform(extraParamsRef)
#endif
op_def static Y merge(X old, X opOutput, X *extraParamsRef) { op_def static Y merge(X old, X opOutput, X *extraParamsRef) {
return update(old, opOutput, extraParamsRef); return update(old, opOutput, extraParamsRef);
} }

185
libnd4j/pi_build.sh Executable file
View File

@ -0,0 +1,185 @@
#!/bin/bash
TARGET=armv7-a
BLAS_TARGET_NAME=ARMV7
ARMCOMPUTE_TARGET=armv7a
#BASE_DIR=${HOME}/pi
#https://stackoverflow.com/questions/59895/how-to-get-the-source-directory-of-a-bash-script-from-within-the-script-itself
SOURCE="${BASH_SOURCE[0]}"
ARMCOMPUTE_DEBUG=1
LIBND4J_BUILD_MODE=Release
while [ -h "$SOURCE" ]; do # resolve $SOURCE until the file is no longer a symlink
DIR="$( cd -P "$( dirname "$SOURCE" )" >/dev/null 2>&1 && pwd )"
SOURCE="$(readlink "$SOURCE")"
[[ $SOURCE != /* ]] && SOURCE="$DIR/$SOURCE" # if $SOURCE was a relative symlink, we need to resolve it relative to the path where the symlink file was located
done
BASE_DIR="$( cd -P "$( dirname "$SOURCE" )" >/dev/null 2>&1 && pwd )"
CMAKE=cmake #/snap/bin/cmake
mkdir -p ${BASE_DIR}/helper_bin/
CROSS_COMPILER_URL=https://sourceforge.net/projects/raspberry-pi-cross-compilers/files/Raspberry%20Pi%20GCC%20Cross-Compiler%20Toolchains/Buster/GCC%208.3.0/Raspberry%20Pi%203A%2B%2C%203B%2B%2C%204/cross-gcc-8.3.0-pi_3%2B.tar.gz/download
CROSS_COMPILER_DIR=${BASE_DIR}/helper_bin/cross_compiler
SCONS_LOCAL_URL=http://prdownloads.sourceforge.net/scons/scons-local-3.1.1.tar.gz
SCONS_LOCAL_DIR=${BASE_DIR}/helper_bin/scons_local
THIRD_PARTY=${BASE_DIR}/third_party_libs
ARMCOMPUTE_GIT_URL=https://github.com/ARM-software/ComputeLibrary.git
ARMCOMPUTE_TAG=v20.05
ARMCOMPUTE_DIR=${THIRD_PARTY}/arm_compute_dir
OPENBLAS_GIT_URL="https://github.com/xianyi/OpenBLAS.git"
OPENBLAS_DIR=${THIRD_PARTY}/OpenBLAS
LIBND4J_SRC_DIR=${BASE_DIR}
LIBND4J_BUILD_DIR=${BASE_DIR}/build_pi
#for some downloads
XRTACT_STRIP="--strip-components=1"
HAS_ARMCOMPUTE=1
mkdir -p ${BASE_DIR}
mkdir -p ${THIRD_PARTY}
#change directory to base
cd $BASE_DIR
function message {
echo "BUILDER:::: ${@}"
}
function check_requirements {
for i in "${@}"
do
if [ ! -e "$i" ]; then
message "missing: ${i}"
exit -2
fi
done
}
function download_extract {
#$1 is url #2 is dir $3 is extract argument
if [ ! -f ${2}_file ]; then
message "download"
wget --quiet --show-progress -O ${2}_file ${1}
fi
message "extract"
#extract
mkdir -p ${2}
command="tar -xzf ${2}_file --directory=${2} ${3} "
message $command
$command
check_requirements "${2}"
}
function git_check {
#$1 is url #$2 is dir #$3 is tag or branch if optional
command="git clone --quiet ${1} ${2}"
message "$command"
$command
if [ -n "$3" ]; then
cd ${2}
command="git checkout ${3}"
message "$command"
$command
cd ${BASE_DIR}
fi
check_requirements "${2}"
}
if [ ! -d ${CROSS_COMPILER_DIR} ]; then
#out file
message "download CROSS_COMPILER"
download_extract ${CROSS_COMPILER_URL} ${CROSS_COMPILER_DIR} ${XRTACT_STRIP}
fi
#useful exports
export PI_FOLDER=${CROSS_COMPILER_DIR}
export RPI_BIN=${PI_FOLDER}/bin/arm-linux-gnueabihf
export PI_SYS_ROOT=${PI_FOLDER}/arm-linux-gnueabihf/libc
export LD_LIBRARY_PATH=${PI_FOLDER}/lib:$LD_LIBRARY_PATH
export CC=${RPI_BIN}-gcc
export FC=${RPI_BIN}-gfortran
export CXX=${RPI_BIN}-g++
export CPP=${RPI_BIN}-cpp
export RANLIB=${RPI_BIN}-gcc-ranlib
export LD="${RPI_BIN}-ld"
export AR="${RPI_BIN}-ar"
#lets build OpenBlas
if [ ! -d "${OPENBLAS_DIR}" ]; then
message "download OpenBLAS"
git_check "${OPENBLAS_GIT_URL}" "${OPENBLAS_DIR}"
fi
if [ ! -f "${THIRD_PARTY}/lib/libopenblas.so" ]; then
message "build and install OpenBLAS"
cd ${OPENBLAS_DIR}
command="make TARGET=${BLAS_TARGET_NAME} HOSTCC=gcc CC=${CC} USE_THREAD=0 NOFORTRAN=1 CFLAGS=--sysroot=${PI_SYS_ROOT} LDFLAGS=\"-L${PI_SYS_ROOT}/../lib/ -lm\" &>/dev/null"
message $command
eval $command
message "install it"
command="make PREFIX=${THIRD_PARTY} install"
message $command
$command
cd $BASE_DIR
fi
check_requirements ${THIRD_PARTY}/lib/libopenblas.so
if [ ! -d ${SCONS_LOCAL_DIR} ]; then
#out file
message "download Scons local"
download_extract ${SCONS_LOCAL_URL} ${SCONS_LOCAL_DIR}
fi
check_requirements ${SCONS_LOCAL_DIR}/scons.py
if [ ! -d "${ARMCOMPUTE_DIR}" ]; then
message "download ArmCompute Source"
git_check ${ARMCOMPUTE_GIT_URL} "${ARMCOMPUTE_DIR}" "tags/${ARMCOMPUTE_TAG}"
fi
#build armcompute
if [ ! -f "${ARMCOMPUTE_DIR}/build/libarm_compute-static.a" ]; then
message "build arm compute"
cd ${ARMCOMPUTE_DIR}
command="CC=gcc CXX=g++ python3 ${SCONS_LOCAL_DIR}/scons.py Werror=1 -j$(nproc) toolchain_prefix=${RPI_BIN}- debug=${ARMCOMPUTE_DEBUG} neon=1 opencl=0 extra_cxx_flags=-fPIC os=linux build=cross_compile arch=${ARMCOMPUTE_TARGET} &>/dev/null"
message $command
eval $command
cd ${BASE_DIR}
fi
check_requirements "${ARMCOMPUTE_DIR}/build/libarm_compute-static.a" "${ARMCOMPUTE_DIR}/build/libarm_compute_core-static.a"
message "build cmake for LIBND4J. output: ${LIBND4J_BUILD_DIR}"
TOOLCHAIN=${LIBND4J_SRC_DIR}/cmake/rpi.cmake
cmake_cmd="${CMAKE} -G \"Unix Makefiles\" -B${LIBND4J_BUILD_DIR} -S${LIBND4J_SRC_DIR} -DCMAKE_BUILD_TYPE=${LIBND4J_BUILD_MODE} -DCMAKE_TOOLCHAIN_FILE=${TOOLCHAIN} -DCMAKE_VERBOSE_MAKEFILE:BOOL=ON -DSD_ALL_OPS=true -DSD_CPU=true -DSD_LIBRARY_NAME=nd4jcpu -DSD_BUILD_TESTS=ON -DSD_ARM_BUILD=true -DOPENBLAS_PATH=${THIRD_PARTY} -DSD_ARCH=${TARGET} -DARMCOMPUTE_ROOT=${ARMCOMPUTE_DIR} -DHELPERS_armcompute=${HAS_ARMCOMPUTE}"
message $cmake_cmd
eval $cmake_cmd
#build
message "lets build"
cd ${LIBND4J_BUILD_DIR}
make -j $(nproc)

View File

@ -52,14 +52,19 @@ elseif(WIN32)
set(CMAKE_CXX_FLAGS " -fPIC") set(CMAKE_CXX_FLAGS " -fPIC")
endif() endif()
else() else()
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3")
set(CMAKE_CXX_FLAGS " -fPIC") set(CMAKE_CXX_FLAGS " -fPIC")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3")
IF(${SD_ARCH} MATCHES "arm*")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=${SD_ARCH}")
else()
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3")
if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64*") if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64*")
set(CMAKE_CXX_FLAGS " ${CMAKE_CXX_FLAGS} -mcpu=native") set(CMAKE_CXX_FLAGS " ${CMAKE_CXX_FLAGS} -mcpu=native")
else() else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native -mtune=native") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native -mtune=native")
endif() endif()
endif()
if (SD_CPU AND SD_SANITIZE) if (SD_CPU AND SD_SANITIZE)
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -fsanitize=address") set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -fsanitize=address")
else() else()
@ -130,7 +135,7 @@ if (SD_CPU)
endif() endif()
add_executable(runtests ${TEST_SOURCES}) add_executable(runtests ${TEST_SOURCES})
target_link_libraries(runtests samediff_obj ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES} ${CPU_FEATURES} gtest gtest_main) target_link_libraries(runtests samediff_obj ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES} ${CPU_FEATURES} ${ARMCOMPUTE_LIBRARIES} gtest gtest_main)
elseif(SD_CUDA) elseif(SD_CUDA)
add_executable(runtests ${TEST_SOURCES}) add_executable(runtests ${TEST_SOURCES})

View File

@ -1113,7 +1113,10 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_6) {
ASSERT_EQ(ND4J_STATUS_OK, result.status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result.at(0); auto z = result.at(0);
#if 0
exp.printIndexedBuffer("Expected");
z->printIndexedBuffer("Z");
#endif
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
@ -1132,7 +1135,10 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_7) {
ASSERT_EQ(ND4J_STATUS_OK, result.status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result.at(0); auto z = result.at(0);
#if 0
exp.printIndexedBuffer("Expected");
z->printIndexedBuffer("Z");
#endif
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
@ -1151,7 +1157,10 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_8) {
ASSERT_EQ(ND4J_STATUS_OK, result.status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto z = result.at(0); auto z = result.at(0);
#if 0
exp.printIndexedBuffer("Expected");
z->printIndexedBuffer("Z");
#endif
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
@ -1204,7 +1213,10 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_10) {
auto* output = results.at(0); auto* output = results.at(0);
ASSERT_EQ(Status::OK(), results.status()); ASSERT_EQ(Status::OK(), results.status());
#if 0
expOutput.printIndexedBuffer("expOutput");
output->printIndexedBuffer("output");
#endif
ASSERT_TRUE(expOutput.isSameShape(output)); ASSERT_TRUE(expOutput.isSameShape(output));
ASSERT_TRUE(expOutput.equalsTo(output)); ASSERT_TRUE(expOutput.equalsTo(output));
} }

View File

@ -244,7 +244,8 @@ TEST_F(DeclarableOpsTests19, test_threshold_encode_decode) {
#ifdef _RELEASE #ifdef _RELEASE
TEST_F(DeclarableOpsTests19, test_threshold_encode_decode_2) { TEST_F(DeclarableOpsTests19, test_threshold_encode_decode_2) {
// [2,1,135079944,1,1,8192,1,99] // [2,1,135079944,1,1,8192,1,99]
auto initial = NDArrayFactory::create<float>('c', {1, 135079944}); constexpr int sizeX= 10*1000*1000;
auto initial = NDArrayFactory::create<float>('c', {1, sizeX});
initial = 1.0f; initial = 1.0f;
auto exp = initial.dup(); auto exp = initial.dup();
auto neg = initial.like(); auto neg = initial.like();
@ -254,7 +255,7 @@ TEST_F(DeclarableOpsTests19, test_threshold_encode_decode_2) {
auto enc_result = enc.evaluate({&initial}, {0.5f}); auto enc_result = enc.evaluate({&initial}, {0.5f});
auto encoded = enc_result.at(1); auto encoded = enc_result.at(1);
ASSERT_EQ(135079944 + 4, encoded->lengthOf()); ASSERT_EQ(sizeX + 4, encoded->lengthOf());
ASSERT_NE(exp, initial); ASSERT_NE(exp, initial);
/* /*
for (int e = 0; e < initial.lengthOf(); e++) { for (int e = 0; e < initial.lengthOf(); e++) {
@ -419,3 +420,4 @@ TEST_F(DeclarableOpsTests19, test_squeeze_1) {
auto status = op.execute({&x}, {&e}, {axis}); auto status = op.execute({&x}, {&e}, {axis});
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
} }

View File

@ -1557,8 +1557,6 @@ TEST_F(DeclarableOpsTests5, trace_test1) {
// exp.printIndexedBuffer("EXP TRACE"); // exp.printIndexedBuffer("EXP TRACE");
// output->printIndexedBuffer("OUT TRACE"); // output->printIndexedBuffer("OUT TRACE");
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1575,8 +1573,6 @@ TEST_F(DeclarableOpsTests5, trace_test2) {
ASSERT_EQ(Status::OK(), results.status()); ASSERT_EQ(Status::OK(), results.status());
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1593,8 +1589,6 @@ TEST_F(DeclarableOpsTests5, trace_test3) {
ASSERT_EQ(Status::OK(), results.status()); ASSERT_EQ(Status::OK(), results.status());
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1611,8 +1605,6 @@ TEST_F(DeclarableOpsTests5, trace_test4) {
ASSERT_EQ(Status::OK(), results.status()); ASSERT_EQ(Status::OK(), results.status());
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1629,8 +1621,6 @@ TEST_F(DeclarableOpsTests5, trace_test5) {
ASSERT_EQ(Status::OK(), results.status()); ASSERT_EQ(Status::OK(), results.status());
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1638,22 +1628,15 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test1) {
auto input = NDArrayFactory::create<double>('c', {2, 2, 2}); auto input = NDArrayFactory::create<double>('c', {2, 2, 2});
input.linspace(1); input.linspace(1);
NDArray exp1 = input.dup();
NDArray exp2('c',{2,2,2}, {5,6,7,8, 1,2,3,4}, sd::DataType::DOUBLE);
sd::ops::random_shuffle op; sd::ops::random_shuffle op;
auto results = op.evaluate({&input}); auto results = op.evaluate({&input});
auto output = results.at(0); auto output = results.at(0);
bool haveZeros = false;
for(int i = 0; i < output->lengthOf(); ++i)
if(output->e<float>(i) == (float)0.)
haveZeros = true;
ASSERT_EQ(Status::OK(), results.status()); ASSERT_EQ(Status::OK(), results.status());
ASSERT_TRUE(input.isSameShape(output)); ASSERT_TRUE(output->equalsTo(exp1) || output->equalsTo(exp2));
ASSERT_TRUE(!input.equalsTo(output));
ASSERT_TRUE(!haveZeros);
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1661,16 +1644,14 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test2) {
auto input = NDArrayFactory::create<double>('c', {1, 3, 2}); auto input = NDArrayFactory::create<double>('c', {1, 3, 2});
input.linspace(1); input.linspace(1);
NDArray exp1 = input.dup();
sd::ops::random_shuffle op; sd::ops::random_shuffle op;
auto results = op.evaluate({&input}); auto results = op.evaluate({&input});
auto output = results.at(0); auto output = results.at(0);
ASSERT_EQ(Status::OK(), results.status()); ASSERT_EQ(Status::OK(), results.status());
ASSERT_TRUE(input.isSameShape(output)); ASSERT_TRUE(output->equalsTo(exp1));
ASSERT_TRUE(input.equalsTo(output));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
@ -1678,129 +1659,132 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test3) {
auto input = NDArrayFactory::create<double>('c', {3, 2, 1}); auto input = NDArrayFactory::create<double>('c', {3, 2, 1});
input.linspace(1); input.linspace(1);
NDArray exp1 = input.dup();
NDArray exp2('c',{3,2,1}, {1,2, 5,6, 3,4}, sd::DataType::DOUBLE);
NDArray exp3('c',{3,2,1}, {3,4, 1,2, 5,6}, sd::DataType::DOUBLE);
NDArray exp4('c',{3,2,1}, {3,4, 5,6, 1,2}, sd::DataType::DOUBLE);
NDArray exp5('c',{3,2,1}, {5,6, 1,2, 3,4}, sd::DataType::DOUBLE);
NDArray exp6('c',{3,2,1}, {5,6, 3,4, 1,2}, sd::DataType::DOUBLE);
sd::ops::random_shuffle op; sd::ops::random_shuffle op;
auto results = op.evaluate({&input});
auto output = results.at(0);
bool haveZeros = false;
for(int i = 0; i < output->lengthOf(); ++i)
if(output->e<float>(i) == (float)0.)
haveZeros = true;
ASSERT_EQ(Status::OK(), results.status());
ASSERT_TRUE(input.isSameShape(output));
ASSERT_TRUE(!input.equalsTo(output));
ASSERT_TRUE(!haveZeros);
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, random_shuffle_test04) {
auto input = NDArrayFactory::create<double>('c', {4});
input.linspace(1);
sd::ops::random_shuffle op;
//NDArray* output;
auto results = op.evaluate({&input}, {}, {}, {}, {}, true); auto results = op.evaluate({&input}, {}, {}, {}, {}, true);
ASSERT_EQ(Status::OK(), results.status()); ASSERT_EQ(Status::OK(), results.status());
auto output = &input; //results.at(0); ASSERT_TRUE(input.equalsTo(exp1) || input.equalsTo(exp2) || input.equalsTo(exp3)
bool haveZeros = false; || input.equalsTo(exp4) || input.equalsTo(exp5) || input.equalsTo(exp6));
for(int i = 0; i < output->lengthOf(); ++i)
if(output->e<float>(i) == (float)0.)
haveZeros = true;
ASSERT_TRUE(input.isSameShape(output));
//ASSERT_TRUE(!input.equalsTo(output));
ASSERT_TRUE(!haveZeros);
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, random_shuffle_test4) { TEST_F(DeclarableOpsTests5, random_shuffle_test4) {
auto input = NDArrayFactory::create<double>('c', {4});
auto input = NDArrayFactory::create<double>('c', {3, 2, 1});
input.linspace(1); input.linspace(1);
NDArray exp1 = input.dup();
NDArray exp2('c',{3,2,1}, {1,2, 5,6, 3,4}, sd::DataType::DOUBLE);
NDArray exp3('c',{3,2,1}, {3,4, 1,2, 5,6}, sd::DataType::DOUBLE);
NDArray exp4('c',{3,2,1}, {3,4, 5,6, 1,2}, sd::DataType::DOUBLE);
NDArray exp5('c',{3,2,1}, {5,6, 1,2, 3,4}, sd::DataType::DOUBLE);
NDArray exp6('c',{3,2,1}, {5,6, 3,4, 1,2}, sd::DataType::DOUBLE);
sd::ops::random_shuffle op; sd::ops::random_shuffle op;
//NDArray* output;
auto results = op.evaluate({&input}); auto results = op.evaluate({&input});
ASSERT_EQ(Status::OK(), results.status());
auto output = results.at(0); auto output = results.at(0);
bool haveZeros = false;
for(int i = 0; i < output->lengthOf(); ++i)
if(output->e<float>(i) == (float)0.)
haveZeros = true;
ASSERT_TRUE(input.isSameShape(output));
//ASSERT_TRUE(!input.equalsTo(output));
ASSERT_TRUE(!haveZeros);
ASSERT_EQ(Status::OK(), results.status());
ASSERT_TRUE(output->equalsTo(exp1) || output->equalsTo(exp2) || output->equalsTo(exp3)
|| output->equalsTo(exp4) || output->equalsTo(exp5) || output->equalsTo(exp6));
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, random_shuffle_test5) { TEST_F(DeclarableOpsTests5, random_shuffle_test5) {
auto input = NDArrayFactory::create<int>('c', {4});
auto input = NDArrayFactory::create<double>('c', {4,1});
input.linspace(1); input.linspace(1);
sd::ops::random_shuffle op; sd::ops::random_shuffle op;
auto results = op.evaluate({&input}); auto results = op.evaluate({&input}, {}, {}, {}, {}, false);
auto output = results.at(0); auto output = results.at(0);
// output->printBuffer();
bool haveZeros = false;
for(int i = 0; i < output->lengthOf(); ++i)
if(output->e<float>(i) == (float)0.)
haveZeros = true;
ASSERT_EQ(Status::OK(), results.status()); ASSERT_EQ(Status::OK(), results.status());
ASSERT_TRUE(input.isSameShape(output)); // ASSERT_TRUE(!output->equalsTo(input));
ASSERT_TRUE(!input.equalsTo(output));
ASSERT_TRUE(!haveZeros);
bool hasDublicates = false;
for(int i = 0; i < output->lengthOf() - 1; ++i)
for(int j = i+1; j < output->lengthOf(); ++j)
if(output->t<int>(i) == output->t<int>(j)) {
hasDublicates = true;
i = output->lengthOf();
break;
}
ASSERT_TRUE(!hasDublicates);
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, random_shuffle_test6) { TEST_F(DeclarableOpsTests5, random_shuffle_test6) {
auto input = NDArrayFactory::create<int>('c', {4,1,1});
auto input = NDArrayFactory::create<double>('c', {4,1,1});
input.linspace(1); input.linspace(1);
sd::ops::random_shuffle op; sd::ops::random_shuffle op;
auto results = op.evaluate({&input}); auto results = op.evaluate({&input}, {}, {}, {}, {}, false);
auto output = results.at(0); auto output = results.at(0);
bool haveZeros = false;
for(int i = 0; i < output->lengthOf(); ++i)
if(output->e<float>(i) == (float)0.)
haveZeros = true;
ASSERT_EQ(Status::OK(), results.status()); ASSERT_EQ(Status::OK(), results.status());
ASSERT_TRUE(input.isSameShape(output)); // ASSERT_TRUE(!output->equalsTo(input));
ASSERT_TRUE(!input.equalsTo(output));
ASSERT_TRUE(!haveZeros);
bool hasDublicates = false;
for(int i = 0; i < output->lengthOf() - 1; ++i)
for(int j = i+1; j < output->lengthOf(); ++j)
if(output->t<int>(i) == output->t<int>(j)) {
hasDublicates = true;
i = output->lengthOf();
break;
}
ASSERT_TRUE(!hasDublicates);
} }
////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, random_shuffle_test7) { TEST_F(DeclarableOpsTests5, random_shuffle_test7) {
auto input = NDArrayFactory::create<int>('c', {16010});
auto input = NDArrayFactory::create<double>('c', {1,4});
input.linspace(1); input.linspace(1);
auto exp = NDArrayFactory::create<double>('c', {1,4}, {1, 2, 3, 4});
sd::ops::random_shuffle op; sd::ops::random_shuffle op;
auto results = op.evaluate({&input}); auto results = op.evaluate({&input}, {}, {}, {}, {}, false);
auto output = results.at(0); auto output = results.at(0);
// output->printBuffer();
ASSERT_EQ(Status::OK(), results.status()); ASSERT_EQ(Status::OK(), results.status());
ASSERT_TRUE(input.isSameShape(output)); ASSERT_TRUE(!output->equalsTo(input));
ASSERT_TRUE(input.equalsTo(output));
auto vec1 = input.getBufferAsVector<int>();
auto vec2 = output->getBufferAsVector<int>();
std::sort(vec2.begin(), vec2.end());
ASSERT_TRUE(std::equal(vec1.begin(), vec1.end(), vec2.begin()));
}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, random_shuffle_test8) {
auto input = NDArrayFactory::create<int>('c', {1,4,1});
input.linspace(1);
NDArray inCopy = input.dup();
sd::ops::random_shuffle op;
auto results = op.evaluate({&input}, {}, {}, {}, {}, false);
ASSERT_EQ(Status::OK(), results.status());
ASSERT_TRUE(input.equalsTo(inCopy));
}
TEST_F(DeclarableOpsTests5, random_shuffle_test9) {
auto x = NDArrayFactory::create<int>('c', {4}, {1, 2, 3, 4});
auto z = x.ulike();
sd::ops::random_shuffle op;
auto status = op.execute({&x}, {&z});
ASSERT_EQ(Status::OK(), status);
auto vec = z.getBufferAsVector<int>();
std::sort(vec.begin(), vec.end());
ASSERT_EQ(std::vector<int>({1, 2, 3, 4}), vec);
} }
//////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////

View File

@ -251,11 +251,10 @@ TEST_F(DeclarableOpsTests9, concat_test1) {
auto result = op.evaluate({&x0, &x1, &x2}, {}, {1}); auto result = op.evaluate({&x0, &x1, &x2}, {}, {1});
ASSERT_EQ(ND4J_STATUS_OK, result.status()); ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto output = result.at(0); auto output = result.at(0);
// output->printCurrentBuffer<float>(false);
ASSERT_TRUE(exp.isSameShape(output)); ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output)); ASSERT_TRUE(exp.equalsTo(output));
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////

View File

@ -317,7 +317,7 @@ void fill_random(sd::NDArray& arr) {
} }
} }
void testLegacy(bool random) { void testLegacy(bool random) {
#if 0 #if 0
int bases[] = { 3, 2, 4, 5, 7 }; int bases[] = { 3, 2, 4, 5, 7 };
@ -364,7 +364,7 @@ int k = 4;
#endif #endif
auto dim = NDArrayFactory::create<int>(dimension); auto dim = NDArrayFactory::create<int>(dimension);
#if 1 #if 1
nd4j_printf("C(N:%d K:%d) \n", N, k); nd4j_printf("C(N:%d K:%d) \n", N, k);
dim.printIndexedBuffer("Dimension"); dim.printIndexedBuffer("Dimension");
for (int xind : dimension) { for (int xind : dimension) {
@ -385,7 +385,7 @@ for (int e = 0; e < Loop; e++) {
auto outerTime = std::chrono::duration_cast<std::chrono::microseconds>(timeEnd - timeStart).count(); auto outerTime = std::chrono::duration_cast<std::chrono::microseconds>(timeEnd - timeStart).count();
values.emplace_back(outerTime); values.emplace_back(outerTime);
} }
std::sort(values.begin(), values.end()); std::sort(values.begin(), values.end());
nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); nd4j_printf("Time: %lld us;\n", values[values.size() / 2]);
@ -411,7 +411,7 @@ void testNewReduction(bool random, bool checkCorrectness = false , char order ='
constexpr int N = 5; constexpr int N = 5;
#endif #endif
for (int i = 0; i < N; i++) { for (int i = 0; i < N; i++) {
arr_dimensions.push_back(bases[i]); arr_dimensions.push_back(bases[i]);
} }
@ -451,7 +451,7 @@ void testNewReduction(bool random, bool checkCorrectness = false , char order ='
#endif #endif
auto dim = NDArrayFactory::create<int>(dimension); auto dim = NDArrayFactory::create<int>(dimension);
#if 1 #if 1
nd4j_printf("C(N:%d K:%d) \n", N, k); nd4j_printf("C(N:%d K:%d) \n", N, k);
dim.printIndexedBuffer("Dimension"); dim.printIndexedBuffer("Dimension");
for (int xind : dimension) { for (int xind : dimension) {
@ -477,14 +477,14 @@ void testNewReduction(bool random, bool checkCorrectness = false , char order ='
//check for the correctness //check for the correctness
NDArray exp = output_bases.size() > 0 ? NDArrayFactory::create<Nd4jLong>('c', output_bases) : NDArrayFactory::create<Nd4jLong>(0); NDArray exp = output_bases.size() > 0 ? NDArrayFactory::create<Nd4jLong>('c', output_bases) : NDArrayFactory::create<Nd4jLong>(0);
original_argmax(x, dimension, exp); original_argmax(x, dimension, exp);
#if 0// defined(DEBUG) #if 0// defined(DEBUG)
x.printIndexedBuffer("X"); x.printIndexedBuffer("X");
exp.printIndexedBuffer("Expected"); exp.printIndexedBuffer("Expected");
z->printIndexedBuffer("Z"); z->printIndexedBuffer("Z");
#endif #endif
ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z)); ASSERT_TRUE(exp.equalsTo(z));
} }
@ -505,7 +505,7 @@ TEST_F(PlaygroundTests, ArgMaxPerfLinspace) {
testNewReduction(false, test_corr); testNewReduction(false, test_corr);
} }
#endif #endif
TEST_F(PlaygroundTests, ArgMaxPerfRandom) { TEST_F(PlaygroundTests, ArgMaxPerfRandom) {
testNewReduction(true, test_corr); testNewReduction(true, test_corr);
} }
@ -513,7 +513,7 @@ TEST_F(PlaygroundTests, ArgMaxPerfRandom) {
TEST_F(PlaygroundTests, ArgMaxPerfRandomOrderF) { TEST_F(PlaygroundTests, ArgMaxPerfRandomOrderF) {
testNewReduction(true, test_corr, 'f'); testNewReduction(true, test_corr, 'f');
} }
#if !defined(DEBUG) #if !defined(DEBUG)
TEST_F(PlaygroundTests, ArgMaxPerfLegacyLinspace) { TEST_F(PlaygroundTests, ArgMaxPerfLegacyLinspace) {
testLegacy(false); testLegacy(false);
@ -1062,39 +1062,6 @@ TEST_F(PlaygroundTests, my) {
delete variableSpace; delete variableSpace;
} }
TEST_F(PlaygroundTests, my) {
int N = 100;
int bS=16, iH=128,iW=128, iC=32,oC=64, kH=4,kW=4, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
int oH=128,oW=128;
int paddingMode = 1; // 1-SAME, 0-VALID;
int dataFormat = 1; // 1-NHWC, 0-NCHW
// NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32);
// NDArray output('c', {bS, oC, oH, oW}, sd::DataType::FLOAT32);
NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32);
NDArray output('c', {bS, oH, oW, oC}, sd::DataType::FLOAT32);
// NDArray weights('c', {kH, kW, iC, oC}, sd::DataType::FLOAT32); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW]
NDArray weights('c', {oC, iC, kH, kW}, sd::DataType::FLOAT32);
NDArray bias('c', {oC}, sd::DataType::FLOAT32);
input = 5.;
weights = 3.;
bias = 1.;
sd::ops::conv2d op;
auto err = op.execute({&input, &weights, &bias}, {&output}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
auto timeStart = std::chrono::system_clock::now();
for (int i = 0; i < N; ++i)
err = op.execute({&input, &weights, &bias}, {&output}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
auto timeEnd = std::chrono::system_clock::now();
auto time = std::chrono::duration_cast<std::chrono::microseconds> ((timeEnd - timeStart) / N).count();
printf("time: %i \n", time);
}
/////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////
TEST_F(PlaygroundTests, lstmLayerCellBp_1) { TEST_F(PlaygroundTests, lstmLayerCellBp_1) {
@ -1690,6 +1657,52 @@ TEST_F(DeclarableOpsTests15, gru_bp_1) {
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
} }
#include<ops/declarable/helpers/transforms.h>
//////////////////////////////////////////////////////////////////////
TEST_F(PlaygroundTests, my) {
const int N = 10;
NDArray input('c', {8000000}, sd::DataType::INT32);
input.linspace(1);
NDArray output = input.dup();
sd::graph::RandomGenerator rng;
sd::ops::helpers::randomShuffle(input.getContext(), input, output, rng, true);
// auto timeStart = std::chrono::system_clock::now();
// for (int i = 0; i < N; ++i)
// sd::ops::helpers::randomShuffle(input.getContext(), input, output, rng, true);
// auto timeEnd = std::chrono::system_clock::now();
// auto time = std::chrono::duration_cast<std::chrono::microseconds> ((timeEnd - timeStart) / N).count();
// printf("time: %i \n", time);
// bool hasDublicates = false;
// for(int i = 0; i < output.lengthOf() - 1; ++i)
// for(int j = i+1; j < output.lengthOf(); ++j)
// if(output.t<int>(i) == output.t<int>(j)) {
// hasDublicates = true;
// i = output.lengthOf();
// break;
// }
ASSERT_TRUE(!input.equalsTo(output));
bool hasDublicates = false;
for(int i = 0; i < input.lengthOf() - 1; ++i)
for(int j = i+1; j < input.lengthOf(); ++j)
if(input.t<int>(i) == input.t<int>(j)) {
hasDublicates = true;
i = input.lengthOf();
break;
}
ASSERT_TRUE(!hasDublicates);
}
}
*/ */

View File

@ -1,93 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#ifndef LIBND4J_SESSIONLOCALTESTS_H
#define LIBND4J_SESSIONLOCALTESTS_H
#include "testlayers.h"
#include <array/NDArrayFactory.h>
#include <graph/SessionLocalStorage.h>
using namespace sd::graph;
class SessionLocalTests : public testing::Test {
public:
};
TEST_F(SessionLocalTests, BasicTests_1) {
VariableSpace variableSpace;
SessionLocalStorage storage(&variableSpace, nullptr);
if (omp_get_max_threads() <= 1)
return;
PRAGMA_OMP_PARALLEL_FOR_THREADS(4)
for (int e = 0; e < 4; e++) {
storage.startSession();
}
ASSERT_EQ(4, storage.numberOfSessions());
PRAGMA_OMP_PARALLEL_FOR_THREADS(4)
for (int e = 0; e < 4; e++) {
storage.endSession();
}
ASSERT_EQ(0, storage.numberOfSessions());
}
TEST_F(SessionLocalTests, BasicTests_2) {
VariableSpace variableSpace;
SessionLocalStorage storage(&variableSpace, nullptr);
if (omp_get_max_threads() <= 1)
return;
auto alpha = sd::NDArrayFactory::create_<float>('c',{5,5});
alpha->assign(0.0);
variableSpace.putVariable(-1, alpha);
PRAGMA_OMP_PARALLEL_FOR_THREADS(4)
for (int e = 0; e < 4; e++) {
storage.startSession();
auto varSpace = storage.localVariableSpace();
auto arr = varSpace->getVariable(-1)->getNDArray();
arr->applyScalar(sd::scalar::Add, (float) e+1, *arr);
}
float lastValue = 0.0f;
for (int e = 1; e <= 4; e++) {
auto varSpace = storage.localVariableSpace((Nd4jLong) e);
auto arr = varSpace->getVariable(-1)->getNDArray();
//nd4j_printf("Last value: %f; Current value: %f\n", lastValue, arr->e(0));
ASSERT_NE(lastValue, arr->e<float>(0));
lastValue = arr->e<float>(0);
}
}
#endif //LIBND4J_SESSIONLOCALTESTS_H

View File

@ -45,6 +45,21 @@ if ("${BUILD_MKLDNN}")
set(MKLDNN dnnl) set(MKLDNN dnnl)
endif() endif()
if (${HELPERS_armcompute})
find_package(ARMCOMPUTE REQUIRED)
if(ARMCOMPUTE_FOUND)
message("Found ARMCOMPUTE: ${ARMCOMPUTE_LIBRARIES}")
set(HAVE_ARMCOMPUTE 1)
# Add preprocessor definition for ARM Compute NEON
add_definitions(-DARMCOMPUTENEON_ENABLED)
#build our library with neon support
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mfpu=neon")
include_directories(${ARMCOMPUTE_INCLUDE})
endif()
endif()
# Download and unpack flatbuffers at configure time # Download and unpack flatbuffers at configure time
configure_file(../../CMakeLists.txt.in flatbuffers-download/CMakeLists.txt) configure_file(../../CMakeLists.txt.in flatbuffers-download/CMakeLists.txt)
execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" . execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" .
@ -217,6 +232,10 @@ if ("${BUILD_MKLDNN}")
file(GLOB_RECURSE CUSTOMOPS_PLATFORM_SOURCES false ../../include/ops/declarable/platform/mkldnn/*.cpp) file(GLOB_RECURSE CUSTOMOPS_PLATFORM_SOURCES false ../../include/ops/declarable/platform/mkldnn/*.cpp)
endif() endif()
if(HAVE_ARMCOMPUTE)
file(GLOB_RECURSE CUSTOMOPS_ARMCOMPUTE_SOURCES false ../include/ops/declarable/platform/armcompute/*.cpp ../include/ops/declarable/platform/armcompute/armcomputeUtils.h)
endif()
message("CPU backend") message("CPU backend")
add_definitions(-D__CPUBLAS__=true) add_definitions(-D__CPUBLAS__=true)
@ -276,8 +295,9 @@ endforeach(TMP_PATH)
add_executable(runtests ${LOOPS_SOURCES} ${LEGACY_SOURCES} ${EXEC_SOURCES} ${HELPERS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES} add_executable(runtests ${LOOPS_SOURCES} ${LEGACY_SOURCES} ${EXEC_SOURCES} ${HELPERS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_PLATFORM_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES} ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_PLATFORM_SOURCES}
${CUSTOMOPS_ARMCOMPUTE_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES}
${OPS_SOURCES} ${TEST_SOURCES} ${PERF_SOURCES}) ${OPS_SOURCES} ${TEST_SOURCES} ${PERF_SOURCES})
target_link_libraries(runtests gtest ${MKLDNN} gtest_main ${BLAS_LIBRARIES}) target_link_libraries(runtests gtest ${MKLDNN} ${ARMCOMPUTE_LIBRARIES} gtest_main ${BLAS_LIBRARIES})

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -25,7 +25,7 @@
</parent> </parent>
<modelVersion>4.0.0</modelVersion> <modelVersion>4.0.0</modelVersion>
<groupId>org.eclipse</groupId> <groupId>org.nd4j</groupId>
<artifactId>python4j-parent</artifactId> <artifactId>python4j-parent</artifactId>
<packaging>pom</packaging> <packaging>pom</packaging>
<modules> <modules>
@ -41,10 +41,14 @@
<scope>provided</scope> <scope>provided</scope>
</dependency> </dependency>
<dependency> <dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<version>1.6.6</version>
</dependency> <dependency>
<groupId>ch.qos.logback</groupId> <groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId> <artifactId>logback-classic</artifactId>
<version>${logback.version}</version> <version>${logback.version}</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency> <dependency>
<groupId>junit</groupId> <groupId>junit</groupId>
@ -62,5 +66,10 @@
<artifactId>jsr305</artifactId> <artifactId>jsr305</artifactId>
<version>3.0.2</version> <version>3.0.2</version>
</dependency> </dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<version>1.6.6</version>
</dependency>
</dependencies> </dependencies>
</project> </project>

View File

@ -21,7 +21,7 @@
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent> <parent>
<artifactId>python4j-parent</artifactId> <artifactId>python4j-parent</artifactId>
<groupId>org.eclipse</groupId> <groupId>org.nd4j</groupId>
<version>1.0.0-SNAPSHOT</version> <version>1.0.0-SNAPSHOT</version>
</parent> </parent>
<packaging>jar</packaging> <packaging>jar</packaging>
@ -39,6 +39,5 @@
<artifactId>cpython-platform</artifactId> <artifactId>cpython-platform</artifactId>
<version>${cpython-platform.version}</version> <version>${cpython-platform.version}</version>
</dependency> </dependency>
</dependencies> </dependencies>
</project> </project>

View File

@ -15,7 +15,7 @@
******************************************************************************/ ******************************************************************************/
package org.eclipse.python4j; package org.nd4j.python4j;
import org.bytedeco.cpython.PyObject; import org.bytedeco.cpython.PyObject;

View File

@ -14,13 +14,15 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
******************************************************************************/ ******************************************************************************/
package org.eclipse.python4j; package org.nd4j.python4j;
import javax.lang.model.SourceVersion; import javax.lang.model.SourceVersion;
import java.io.Closeable;
import java.util.HashSet; import java.util.HashSet;
import java.util.Set; import java.util.Set;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
/** /**
@ -46,6 +48,31 @@ public class PythonContextManager {
init(); init();
} }
public static class Context implements Closeable{
private final String name;
private final String previous;
private final boolean temp;
public Context(){
name = "temp_" + UUID.randomUUID().toString().replace("-", "_");
temp = true;
previous = getCurrentContext();
setContext(name);
}
public Context(String name){
this.name = name;
temp = false;
previous = getCurrentContext();
setContext(name);
}
@Override
public void close(){
setContext(previous);
if (temp) deleteContext(name);
}
}
private static void init() { private static void init() {
if (init.get()) return; if (init.get()) return;
new PythonExecutioner(); new PythonExecutioner();
@ -76,7 +103,18 @@ public class PythonContextManager {
} }
private static boolean validateContextName(String s) { private static boolean validateContextName(String s) {
return SourceVersion.isIdentifier(s) && !s.startsWith(COLLAPSED_KEY); for (int i=0; i<s.length(); i++){
char c = s.toLowerCase().charAt(i);
if (i == 0){
if (c >= '0' && c <= '9'){
return false;
}
}
if (!(c=='_' || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9'))){
return false;
}
}
return true;
} }
private static String getContextPrefix(String contextName) { private static String getContextPrefix(String contextName) {
@ -190,6 +228,7 @@ public class PythonContextManager {
setContext(tempContext); setContext(tempContext);
deleteContext(currContext); deleteContext(currContext);
setContext(currContext); setContext(currContext);
deleteContext(tempContext);
} }
/** /**

View File

@ -14,7 +14,7 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
******************************************************************************/ ******************************************************************************/
package org.eclipse.python4j; package org.nd4j.python4j;
/** /**

View File

@ -15,7 +15,7 @@
******************************************************************************/ ******************************************************************************/
package org.eclipse.python4j; package org.nd4j.python4j;
import org.bytedeco.cpython.PyObject; import org.bytedeco.cpython.PyObject;
@ -42,7 +42,6 @@ public class PythonExecutioner {
private final static String DEFAULT_PYTHON_PATH_PROPERTY = "org.eclipse.python4j.path"; private final static String DEFAULT_PYTHON_PATH_PROPERTY = "org.eclipse.python4j.path";
private final static String JAVACPP_PYTHON_APPEND_TYPE = "org.eclipse.python4j.path.append"; private final static String JAVACPP_PYTHON_APPEND_TYPE = "org.eclipse.python4j.path.append";
private final static String DEFAULT_APPEND_TYPE = "before"; private final static String DEFAULT_APPEND_TYPE = "before";
static { static {
init(); init();
} }
@ -55,6 +54,11 @@ public class PythonExecutioner {
initPythonPath(); initPythonPath();
PyEval_InitThreads(); PyEval_InitThreads();
Py_InitializeEx(0); Py_InitializeEx(0);
for (PythonType type: PythonTypes.get()){
type.init();
}
// Constructors of custom types may contain initialization code that should
// run on the main the thread.
} }
/** /**
@ -110,6 +114,8 @@ public class PythonExecutioner {
getVariables(Arrays.asList(pyVars)); getVariables(Arrays.asList(pyVars));
} }
/** /**
* Gets the variable with the given name from the interpreter. * Gets the variable with the given name from the interpreter.
* *
@ -205,9 +211,9 @@ public class PythonExecutioner {
* *
* @return * @return
*/ */
public static List<PythonVariable> getAllVariables() { public static PythonVariables getAllVariables() {
PythonGIL.assertThreadSafe(); PythonGIL.assertThreadSafe();
List<PythonVariable> ret = new ArrayList<>(); PythonVariables ret = new PythonVariables();
PyObject main = PyImport_ImportModule("__main__"); PyObject main = PyImport_ImportModule("__main__");
PyObject globals = PyModule_GetDict(main); PyObject globals = PyModule_GetDict(main);
PyObject keys = PyDict_Keys(globals); PyObject keys = PyDict_Keys(globals);
@ -259,7 +265,7 @@ public class PythonExecutioner {
* @param inputs * @param inputs
* @return * @return
*/ */
public static List<PythonVariable> execAndReturnAllVariables(String code, List<PythonVariable> inputs) { public static PythonVariables execAndReturnAllVariables(String code, List<PythonVariable> inputs) {
setVariables(inputs); setVariables(inputs);
simpleExec(getWrappedCode(code)); simpleExec(getWrappedCode(code));
return getAllVariables(); return getAllVariables();
@ -271,7 +277,7 @@ public class PythonExecutioner {
* @param code * @param code
* @return * @return
*/ */
public static List<PythonVariable> execAndReturnAllVariables(String code) { public static PythonVariables execAndReturnAllVariables(String code) {
simpleExec(getWrappedCode(code)); simpleExec(getWrappedCode(code));
return getAllVariables(); return getAllVariables();
} }
@ -279,25 +285,22 @@ public class PythonExecutioner {
private static synchronized void initPythonPath() { private static synchronized void initPythonPath() {
try { try {
String path = System.getProperty(DEFAULT_PYTHON_PATH_PROPERTY); String path = System.getProperty(DEFAULT_PYTHON_PATH_PROPERTY);
List<File> packagesList = new ArrayList<>();
packagesList.addAll(Arrays.asList(cachePackages()));
for (PythonType type: PythonTypes.get()){
packagesList.addAll(Arrays.asList(type.packages()));
}
//// TODO: fix in javacpp
packagesList.add(new File(python.cachePackage(), "site-packages"));
File[] packages = packagesList.toArray(new File[0]);
if (path == null) { if (path == null) {
File[] packages = cachePackages();
//// TODO: fix in javacpp
File sitePackagesWindows = new File(python.cachePackage(), "site-packages");
File[] packages2 = new File[packages.length + 1];
for (int i = 0; i < packages.length; i++) {
//System.out.println(packages[i].getAbsolutePath());
packages2[i] = packages[i];
}
packages2[packages.length] = sitePackagesWindows;
//System.out.println(sitePackagesWindows.getAbsolutePath());
packages = packages2;
//////////
Py_SetPath(packages); Py_SetPath(packages);
} else { } else {
StringBuffer sb = new StringBuffer(); StringBuffer sb = new StringBuffer();
File[] packages = cachePackages();
JavaCppPathType pathAppendValue = JavaCppPathType.valueOf(System.getProperty(JAVACPP_PYTHON_APPEND_TYPE, DEFAULT_APPEND_TYPE).toUpperCase()); JavaCppPathType pathAppendValue = JavaCppPathType.valueOf(System.getProperty(JAVACPP_PYTHON_APPEND_TYPE, DEFAULT_APPEND_TYPE).toUpperCase());
switch (pathAppendValue) { switch (pathAppendValue) {
case BEFORE: case BEFORE:

View File

@ -15,7 +15,7 @@
******************************************************************************/ ******************************************************************************/
package org.eclipse.python4j; package org.nd4j.python4j;
import org.bytedeco.cpython.PyObject; import org.bytedeco.cpython.PyObject;
import org.bytedeco.javacpp.Pointer; import org.bytedeco.javacpp.Pointer;

View File

@ -14,11 +14,10 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
******************************************************************************/ ******************************************************************************/
package org.eclipse.python4j; package org.nd4j.python4j;
import org.bytedeco.cpython.PyThreadState; import org.bytedeco.cpython.PyThreadState;
import org.omg.SendingContext.RunTime;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
@ -90,4 +89,8 @@ public class PythonGIL implements AutoCloseable {
PyEval_SaveThread(); PyEval_SaveThread();
PyEval_RestoreThread(mainThreadState); PyEval_RestoreThread(mainThreadState);
} }
public static boolean locked(){
return acquired.get();
}
} }

View File

@ -14,31 +14,34 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
******************************************************************************/ ******************************************************************************/
package org.eclipse.python4j; package org.nd4j.python4j;
import lombok.Builder; import lombok.Builder;
import lombok.Data; import lombok.Data;
import lombok.NoArgsConstructor; import lombok.extern.slf4j.Slf4j;
import javax.annotation.Nonnull; import javax.annotation.Nonnull;
import java.util.List; import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
@Data
@NoArgsConstructor
/** /**
* PythonJob is the right abstraction for executing multiple python scripts * PythonJob is the right abstraction for executing multiple python scripts
* in a multi thread stateful environment. The setup-and-run mode allows your * in a multi thread stateful environment. The setup-and-run mode allows your
* "setup" code (imports, model loading etc) to be executed only once. * "setup" code (imports, model loading etc) to be executed only once.
*/ */
@Data
@Slf4j
public class PythonJob { public class PythonJob {
private String code; private String code;
private String name; private String name;
private String context; private String context;
private boolean setupRunMode; private final boolean setupRunMode;
private PythonObject runF; private PythonObject runF;
private final AtomicBoolean setupDone = new AtomicBoolean(false);
static { static {
new PythonExecutioner(); new PythonExecutioner();
@ -63,7 +66,6 @@ public class PythonJob {
if (PythonContextManager.hasContext(context)) { if (PythonContextManager.hasContext(context)) {
throw new PythonException("Unable to create python job " + name + ". Context " + context + " already exists!"); throw new PythonException("Unable to create python job " + name + ". Context " + context + " already exists!");
} }
if (setupRunMode) setup();
} }
@ -71,17 +73,18 @@ public class PythonJob {
* Clears all variables in current context and calls setup() * Clears all variables in current context and calls setup()
*/ */
public void clearState(){ public void clearState(){
String context = this.context; PythonContextManager.setContext(this.context);
PythonContextManager.setContext("main"); PythonContextManager.reset();
PythonContextManager.deleteContext(context); setupDone.set(false);
this.context = context;
setup(); setup();
} }
public void setup(){ public void setup(){
if (setupDone.get()) return;
try (PythonGIL gil = PythonGIL.lock()) { try (PythonGIL gil = PythonGIL.lock()) {
PythonContextManager.setContext(context); PythonContextManager.setContext(context);
PythonObject runF = PythonExecutioner.getVariable("run"); PythonObject runF = PythonExecutioner.getVariable("run");
if (runF == null || runF.isNone() || !Python.callable(runF)) { if (runF == null || runF.isNone() || !Python.callable(runF)) {
PythonExecutioner.exec(code); PythonExecutioner.exec(code);
runF = PythonExecutioner.getVariable("run"); runF = PythonExecutioner.getVariable("run");
@ -98,10 +101,12 @@ public class PythonJob {
if (!setupF.isNone()) { if (!setupF.isNone()) {
setupF.call(); setupF.call();
} }
setupDone.set(true);
} }
} }
public void exec(List<PythonVariable> inputs, List<PythonVariable> outputs) { public void exec(List<PythonVariable> inputs, List<PythonVariable> outputs) {
if (setupRunMode)setup();
try (PythonGIL gil = PythonGIL.lock()) { try (PythonGIL gil = PythonGIL.lock()) {
try (PythonGC _ = PythonGC.watch()) { try (PythonGC _ = PythonGC.watch()) {
PythonContextManager.setContext(context); PythonContextManager.setContext(context);
@ -139,6 +144,7 @@ public class PythonJob {
} }
public List<PythonVariable> execAndReturnAllVariables(List<PythonVariable> inputs){ public List<PythonVariable> execAndReturnAllVariables(List<PythonVariable> inputs){
if (setupRunMode)setup();
try (PythonGIL gil = PythonGIL.lock()) { try (PythonGIL gil = PythonGIL.lock()) {
try (PythonGC _ = PythonGC.watch()) { try (PythonGC _ = PythonGC.watch()) {
PythonContextManager.setContext(context); PythonContextManager.setContext(context);

View File

@ -14,7 +14,7 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
******************************************************************************/ ******************************************************************************/
package org.eclipse.python4j; package org.nd4j.python4j;
import org.bytedeco.cpython.PyObject; import org.bytedeco.cpython.PyObject;
@ -147,7 +147,8 @@ public class PythonObject {
} }
PythonObject pyArgs; PythonObject pyArgs;
PythonObject pyKwargs; PythonObject pyKwargs;
if (args == null) {
if (args == null || args.isEmpty()) {
pyArgs = new PythonObject(PyTuple_New(0)); pyArgs = new PythonObject(PyTuple_New(0));
} else { } else {
PythonObject argsList = PythonTypes.convert(args); PythonObject argsList = PythonTypes.convert(args);
@ -158,6 +159,7 @@ public class PythonObject {
} else { } else {
pyKwargs = PythonTypes.convert(kwargs); pyKwargs = PythonTypes.convert(kwargs);
} }
PythonObject ret = new PythonObject( PythonObject ret = new PythonObject(
PyObject_Call( PyObject_Call(
nativePythonObject, nativePythonObject,
@ -165,7 +167,9 @@ public class PythonObject {
pyKwargs == null ? null : pyKwargs.nativePythonObject pyKwargs == null ? null : pyKwargs.nativePythonObject
) )
); );
PythonGC.keep(ret); PythonGC.keep(ret);
return ret; return ret;
} }
@ -241,4 +245,48 @@ public class PythonObject {
PyObject_SetItem(nativePythonObject, key.nativePythonObject, value.nativePythonObject); PyObject_SetItem(nativePythonObject, key.nativePythonObject, value.nativePythonObject);
} }
public PythonObject abs(){
return new PythonObject(PyNumber_Absolute(nativePythonObject));
}
public PythonObject add(PythonObject pythonObject){
return new PythonObject(PyNumber_Add(nativePythonObject, pythonObject.nativePythonObject));
}
public PythonObject sub(PythonObject pythonObject){
return new PythonObject(PyNumber_Subtract(nativePythonObject, pythonObject.nativePythonObject));
}
public PythonObject mod(PythonObject pythonObject){
return new PythonObject(PyNumber_Divmod(nativePythonObject, pythonObject.nativePythonObject));
}
public PythonObject mul(PythonObject pythonObject){
return new PythonObject(PyNumber_Multiply(nativePythonObject, pythonObject.nativePythonObject));
}
public PythonObject trueDiv(PythonObject pythonObject){
return new PythonObject(PyNumber_TrueDivide(nativePythonObject, pythonObject.nativePythonObject));
}
public PythonObject floorDiv(PythonObject pythonObject){
return new PythonObject(PyNumber_FloorDivide(nativePythonObject, pythonObject.nativePythonObject));
}
public PythonObject matMul(PythonObject pythonObject){
return new PythonObject(PyNumber_MatrixMultiply(nativePythonObject, pythonObject.nativePythonObject));
}
public void addi(PythonObject pythonObject){
PyNumber_InPlaceAdd(nativePythonObject, pythonObject.nativePythonObject);
}
public void subi(PythonObject pythonObject){
PyNumber_InPlaceSubtract(nativePythonObject, pythonObject.nativePythonObject);
}
public void muli(PythonObject pythonObject){
PyNumber_InPlaceMultiply(nativePythonObject, pythonObject.nativePythonObject);
}
public void trueDivi(PythonObject pythonObject){
PyNumber_InPlaceTrueDivide(nativePythonObject, pythonObject.nativePythonObject);
}
public void floorDivi(PythonObject pythonObject){
PyNumber_InPlaceFloorDivide(nativePythonObject, pythonObject.nativePythonObject);
}
public void matMuli(PythonObject pythonObject){
PyNumber_InPlaceMatrixMultiply(nativePythonObject, pythonObject.nativePythonObject);
}
} }

View File

@ -0,0 +1,127 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.python4j;
import org.apache.commons.io.IOUtils;
import org.bytedeco.javacpp.Loader;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
public class PythonProcess {
private static String pythonExecutable = Loader.load(org.bytedeco.cpython.python.class);
public static String runAndReturn(String... arguments)throws IOException, InterruptedException{
String[] allArgs = new String[arguments.length + 1];
for (int i = 0; i < arguments.length; i++){
allArgs[i + 1] = arguments[i];
}
allArgs[0] = pythonExecutable;
ProcessBuilder pb = new ProcessBuilder(allArgs);
Process process = pb.start();
String out = IOUtils.toString(process.getInputStream(), StandardCharsets.UTF_8);
process.waitFor();
return out;
}
public static void run(String... arguments)throws IOException, InterruptedException{
String[] allArgs = new String[arguments.length + 1];
for (int i = 0; i < arguments.length; i++){
allArgs[i + 1] = arguments[i];
}
allArgs[0] = pythonExecutable;
ProcessBuilder pb = new ProcessBuilder(allArgs);
pb.inheritIO().start().waitFor();
}
public static void pipInstall(String packageName) throws PythonException{
try{
run("-m", "pip", "install", packageName);
}catch(Exception e){
throw new PythonException("Error installing package " + packageName, e);
}
}
public static void pipInstall(String packageName, String version){
pipInstall(packageName + "==" + version);
}
public static void pipUninstall(String packageName) throws PythonException{
try{
run("-m", "pip", "uninstall", packageName);
}catch(Exception e){
throw new PythonException("Error uninstalling package " + packageName, e);
}
}
public static void pipInstallFromGit(String gitRepoUrl){
if (!gitRepoUrl.contains("://")){
gitRepoUrl = "git://" + gitRepoUrl;
}
try{
run("-m", "pip", "install", "git+", gitRepoUrl);
}catch(Exception e){
throw new PythonException("Error installing package from " + gitRepoUrl, e);
}
}
public static String getPackageVersion(String packageName){
String out;
try{
out = runAndReturn("-m", "pip", "show", packageName);
} catch (Exception e){
throw new PythonException("Error finding version for package " + packageName, e);
}
if (!out.contains("Version: ")){
throw new PythonException("Can't find package " + packageName);
}
String pkgVersion = out.split("Version: ")[1].split(System.lineSeparator())[0];
return pkgVersion;
}
public static boolean isPackageInstalled(String packageName){
try{
String out = runAndReturn("-m", "pip", "show", packageName);
return !out.isEmpty();
}catch (Exception e){
throw new PythonException("Error checking if package is installed: " +packageName, e);
}
}
public static void pipInstallFromRequirementsTxt(String path){
try{
run("-m", "pip", "install","-r", path);
}catch (Exception e){
throw new PythonException("Error installing packages from " + path, e);
}
}
public static void pipInstallFromSetupScript(String path, boolean inplace){
try{
run(path, inplace?"develop":"install");
}catch (Exception e){
throw new PythonException("Error installing package from " + path, e);
}
}
}

View File

@ -14,9 +14,11 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
******************************************************************************/ ******************************************************************************/
package org.eclipse.python4j; package org.nd4j.python4j;
import java.io.File;
public abstract class PythonType<T> { public abstract class PythonType<T> {
private final String name; private final String name;
@ -43,5 +45,25 @@ public abstract class PythonType<T> {
return name; return name;
} }
@Override
public boolean equals(Object obj){
if (!(obj instanceof PythonType)){
return false;
}
PythonType other = (PythonType)obj;
return this.getClass().equals(other.getClass()) && this.name.equals(other.name);
}
public PythonObject pythonType(){
return null;
}
public File[] packages(){
return new File[0];
}
public void init(){ //not to be called from constructor
}
} }

View File

@ -14,11 +14,18 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
******************************************************************************/ ******************************************************************************/
package org.eclipse.python4j; package org.nd4j.python4j;
import org.bytedeco.cpython.PyObject; import org.bytedeco.cpython.PyObject;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.Pointer;
import sun.nio.ch.DirectBuffer;
import java.lang.reflect.Field;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.*; import java.util.*;
import static org.bytedeco.cpython.global.python.*; import static org.bytedeco.cpython.global.python.*;
@ -28,7 +35,7 @@ public class PythonTypes {
private static List<PythonType> getPrimitiveTypes() { private static List<PythonType> getPrimitiveTypes() {
return Arrays.<PythonType>asList(STR, INT, FLOAT, BOOL); return Arrays.<PythonType>asList(STR, INT, FLOAT, BOOL, BYTES);
} }
private static List<PythonType> getCollectionTypes() { private static List<PythonType> getCollectionTypes() {
@ -36,8 +43,13 @@ public class PythonTypes {
} }
private static List<PythonType> getExternalTypes() { private static List<PythonType> getExternalTypes() {
//TODO service loader List<PythonType> ret = new ArrayList<>();
return new ArrayList<>(); ServiceLoader<PythonType> sl = ServiceLoader.load(PythonType.class);
Iterator<PythonType> iter = sl.iterator();
while (iter.hasNext()) {
ret.add(iter.next());
}
return ret;
} }
public static List<PythonType> get() { public static List<PythonType> get() {
@ -48,15 +60,17 @@ public class PythonTypes {
return ret; return ret;
} }
public static PythonType get(String name) { public static <T> PythonType<T> get(String name) {
for (PythonType pt : get()) { for (PythonType pt : get()) {
if (pt.getName().equals(name)) { // TODO use map instead? if (pt.getName().equals(name)) { // TODO use map instead?
return pt; return pt;
} }
} }
throw new PythonException("Unknown python type: " + name); throw new PythonException("Unknown python type: " + name);
} }
public static PythonType getPythonTypeForJavaObject(Object javaObject) { public static PythonType getPythonTypeForJavaObject(Object javaObject) {
for (PythonType pt : get()) { for (PythonType pt : get()) {
if (pt.accepts(javaObject)) { if (pt.accepts(javaObject)) {
@ -66,7 +80,7 @@ public class PythonTypes {
throw new PythonException("Unable to find python type for java type: " + javaObject.getClass()); throw new PythonException("Unable to find python type for java type: " + javaObject.getClass());
} }
public static PythonType getPythonTypeForPythonObject(PythonObject pythonObject) { public static <T> PythonType<T> getPythonTypeForPythonObject(PythonObject pythonObject) {
PyObject pyType = PyObject_Type(pythonObject.getNativePythonObject()); PyObject pyType = PyObject_Type(pythonObject.getNativePythonObject());
try { try {
String pyTypeStr = PythonTypes.STR.toJava(new PythonObject(pyType, false)); String pyTypeStr = PythonTypes.STR.toJava(new PythonObject(pyType, false));
@ -75,6 +89,14 @@ public class PythonTypes {
String pyTypeStr2 = "<class '" + pt.getName() + "'>"; String pyTypeStr2 = "<class '" + pt.getName() + "'>";
if (pyTypeStr.equals(pyTypeStr2)) { if (pyTypeStr.equals(pyTypeStr2)) {
return pt; return pt;
} else {
try (PythonGC gc = PythonGC.watch()) {
PythonObject pyType2 = pt.pythonType();
if (pyType2 != null && Python.isinstance(pythonObject, pyType2)) {
return pt;
}
}
} }
} }
throw new PythonException("Unable to find converter for python object of type " + pyTypeStr); throw new PythonException("Unable to find converter for python object of type " + pyTypeStr);
@ -212,12 +234,53 @@ public class PythonTypes {
public static final PythonType<List> LIST = new PythonType<List>("list", List.class) { public static final PythonType<List> LIST = new PythonType<List>("list", List.class) {
@Override
public boolean accepts(Object javaObject) {
return (javaObject instanceof List || javaObject.getClass().isArray());
}
@Override @Override
public List adapt(Object javaObject) { public List adapt(Object javaObject) {
if (javaObject instanceof List) { if (javaObject instanceof List) {
return (List) javaObject; return (List) javaObject;
} else if (javaObject instanceof Object[]) { } else if (javaObject.getClass().isArray()) {
return Arrays.asList((Object[]) javaObject); List<Object> ret = new ArrayList<>();
if (javaObject instanceof Object[]) {
Object[] arr = (Object[]) javaObject;
return new ArrayList<>(Arrays.asList(arr));
} else if (javaObject instanceof short[]) {
short[] arr = (short[]) javaObject;
for (short x : arr) ret.add(x);
return ret;
} else if (javaObject instanceof int[]) {
int[] arr = (int[]) javaObject;
for (int x : arr) ret.add(x);
return ret;
}else if (javaObject instanceof byte[]){
byte[] arr = (byte[]) javaObject;
for (int x : arr) ret.add(x);
return ret;
} else if (javaObject instanceof long[]) {
long[] arr = (long[]) javaObject;
for (long x : arr) ret.add(x);
return ret;
} else if (javaObject instanceof float[]) {
float[] arr = (float[]) javaObject;
for (float x : arr) ret.add(x);
return ret;
} else if (javaObject instanceof double[]) {
double[] arr = (double[]) javaObject;
for (double x : arr) ret.add(x);
return ret;
} else if (javaObject instanceof boolean[]) {
boolean[] arr = (boolean[]) javaObject;
for (boolean x : arr) ret.add(x);
return ret;
} else {
throw new PythonException("Unsupported array type: " + javaObject.getClass().toString());
}
} else { } else {
throw new PythonException("Cannot cast object of type " + javaObject.getClass().getName() + " to List"); throw new PythonException("Cannot cast object of type " + javaObject.getClass().getName() + " to List");
} }
@ -327,7 +390,13 @@ public class PythonTypes {
} }
Object v = javaObject.get(k); Object v = javaObject.get(k);
PythonObject pyVal; PythonObject pyVal;
pyVal = PythonTypes.convert(v); if (v instanceof PythonObject) {
pyVal = (PythonObject) v;
} else if (v instanceof PyObject) {
pyVal = new PythonObject((PyObject) v);
} else {
pyVal = PythonTypes.convert(v);
}
int errCode = PyDict_SetItem(pyDict, pyKey.getNativePythonObject(), pyVal.getNativePythonObject()); int errCode = PyDict_SetItem(pyDict, pyKey.getNativePythonObject(), pyVal.getNativePythonObject());
if (errCode != 0) { if (errCode != 0) {
String keyStr = pyKey.toString(); String keyStr = pyKey.toString();
@ -341,4 +410,127 @@ public class PythonTypes {
return new PythonObject(pyDict); return new PythonObject(pyDict);
} }
}; };
public static final PythonType<byte[]> BYTES = new PythonType<byte[]>("bytes", byte[].class) {
@Override
public byte[] toJava(PythonObject pythonObject) {
try (PythonGC gc = PythonGC.watch()) {
if (!(Python.isinstance(pythonObject, Python.bytesType()))) {
throw new PythonException("Expected bytes. Received: " + pythonObject);
}
PythonObject pySize = Python.len(pythonObject);
byte[] ret = new byte[pySize.toInt()];
for (int i = 0; i < ret.length; i++) {
ret[i] = (byte)pythonObject.get(i).toInt();
}
return ret;
}
}
@Override
public PythonObject toPython(byte[] javaObject) {
try(PythonGC gc = PythonGC.watch()){
PythonObject ret = Python.bytes(LIST.toPython(LIST.adapt(javaObject)));
PythonGC.keep(ret);
return ret;
}
}
@Override
public boolean accepts(Object javaObject) {
return javaObject instanceof byte[];
}
@Override
public byte[] adapt(Object javaObject) {
if (javaObject instanceof byte[]){
return (byte[])javaObject;
}
throw new PythonException("Cannot cast object of type " + javaObject.getClass().getName() + " to byte[]");
}
};
/**
* Crashes on Adopt OpenJDK
* Use implementation in python4j-numpy instead for zero-copy byte buffers.
*/
// public static final PythonType<BytePointer> MEMORYVIEW = new PythonType<BytePointer>("memoryview", BytePointer.class) {
// @Override
// public BytePointer toJava(PythonObject pythonObject) {
// try (PythonGC gc = PythonGC.watch()) {
// if (!(Python.isinstance(pythonObject, Python.memoryviewType()))) {
// throw new PythonException("Expected memoryview. Received: " + pythonObject);
// }
// PythonObject pySize = Python.len(pythonObject);
// PythonObject ctypes = Python.importModule("ctypes");
// PythonObject charType = ctypes.attr("c_char");
// PythonObject charArrayType = new PythonObject(PyNumber_Multiply(charType.getNativePythonObject(),
// pySize.getNativePythonObject()));
// PythonObject fromBuffer = charArrayType.attr("from_buffer");
// if (pythonObject.attr("readonly").toBoolean()) {
// pythonObject = Python.bytearray(pythonObject);
// }
// PythonObject arr = fromBuffer.call(pythonObject);
// PythonObject cast = ctypes.attr("cast");
// PythonObject voidPtrType = ctypes.attr("c_void_p");
// PythonObject voidPtr = cast.call(arr, voidPtrType);
// long address = voidPtr.attr("value").toLong();
// long size = pySize.toLong();
// try {
// Field addressField = Buffer.class.getDeclaredField("address");
// addressField.setAccessible(true);
// Field capacityField = Buffer.class.getDeclaredField("capacity");
// capacityField.setAccessible(true);
// ByteBuffer buff = ByteBuffer.allocateDirect(0).order(ByteOrder.nativeOrder());
// addressField.setLong(buff, address);
// capacityField.setInt(buff, (int) size);
// BytePointer ret = new BytePointer(buff);
// ret.limit(size);
// return ret;
//
// } catch (Exception e) {
// throw new RuntimeException(e);
// }
//
// }
// }
//
// @Override
// public PythonObject toPython(BytePointer javaObject) {
// long address = javaObject.address();
// long size = javaObject.limit();
// try (PythonGC gc = PythonGC.watch()) {
// PythonObject ctypes = Python.importModule("ctypes");
// PythonObject charType = ctypes.attr("c_char");
// PythonObject pySize = new PythonObject(size);
// PythonObject charArrayType = new PythonObject(PyNumber_Multiply(charType.getNativePythonObject(),
// pySize.getNativePythonObject()));
// PythonObject fromAddress = charArrayType.attr("from_address");
// PythonObject arr = fromAddress.call(new PythonObject(address));
// PythonObject memoryView = Python.memoryview(arr).attr("cast").call("b");
// PythonGC.keep(memoryView);
// return memoryView;
// }
//
// }
//
// @Override
// public boolean accepts(Object javaObject) {
// return javaObject instanceof Pointer || javaObject instanceof DirectBuffer;
// }
//
// @Override
// public BytePointer adapt(Object javaObject) {
// if (javaObject instanceof BytePointer) {
// return (BytePointer) javaObject;
// } else if (javaObject instanceof Pointer) {
// return new BytePointer((Pointer) javaObject);
// } else if (javaObject instanceof DirectBuffer) {
// return new BytePointer((ByteBuffer) javaObject);
// } else {
// throw new PythonException("Cannot cast object of type " + javaObject.getClass().getName() + " to BytePointer");
// }
// }
// };
} }

View File

@ -14,7 +14,7 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
******************************************************************************/ ******************************************************************************/
package org.eclipse.python4j; package org.nd4j.python4j;
@lombok.Data @lombok.Data
public class PythonVariable<T> { public class PythonVariable<T> {

View File

@ -0,0 +1,47 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.python4j;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
* Some syntax sugar for lookup by name
*/
public class PythonVariables extends ArrayList<PythonVariable> {
public PythonVariable get(String variableName) {
for (PythonVariable pyVar: this){
if (pyVar.getName().equals(variableName)){
return pyVar;
}
}
return null;
}
public <T> boolean add(String variableName, PythonType<T> variableType, Object value){
return this.add(new PythonVariable<>(variableName, variableType, value));
}
public PythonVariables(PythonVariable... variables){
this(Arrays.asList(variables));
}
public PythonVariables(List<PythonVariable> list){
super();
addAll(list);
}
}

View File

@ -15,9 +15,12 @@
******************************************************************************/ ******************************************************************************/
import org.eclipse.python4j.*;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import org.nd4j.python4j.PythonContextManager;
import org.nd4j.python4j.PythonExecutioner;
import org.nd4j.python4j.PythonTypes;
import org.nd4j.python4j.PythonVariable;
import javax.annotation.concurrent.NotThreadSafe; import javax.annotation.concurrent.NotThreadSafe;
import java.util.*; import java.util.*;

View File

@ -15,9 +15,9 @@
******************************************************************************/ ******************************************************************************/
import org.eclipse.python4j.PythonException; import org.nd4j.python4j.PythonException;
import org.eclipse.python4j.PythonObject; import org.nd4j.python4j.PythonObject;
import org.eclipse.python4j.PythonTypes; import org.nd4j.python4j.PythonTypes;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;

View File

@ -16,9 +16,9 @@
******************************************************************************/ ******************************************************************************/
import org.eclipse.python4j.Python; import org.nd4j.python4j.Python;
import org.eclipse.python4j.PythonContextManager; import org.nd4j.python4j.PythonContextManager;
import org.eclipse.python4j.PythonExecutioner; import org.nd4j.python4j.PythonExecutioner;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import javax.annotation.concurrent.NotThreadSafe; import javax.annotation.concurrent.NotThreadSafe;

View File

@ -14,9 +14,9 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
******************************************************************************/ ******************************************************************************/
import org.eclipse.python4j.Python; import org.nd4j.python4j.Python;
import org.eclipse.python4j.PythonGC; import org.nd4j.python4j.PythonGC;
import org.eclipse.python4j.PythonObject; import org.nd4j.python4j.PythonObject;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
@ -49,6 +49,6 @@ public class PythonGCTest {
PythonObject pyObjCount3 = Python.len(getObjects.call()); PythonObject pyObjCount3 = Python.len(getObjects.call());
long objCount3 = pyObjCount3.toLong(); long objCount3 = pyObjCount3.toLong();
diff = objCount3 - objCount2; diff = objCount3 - objCount2;
Assert.assertEquals(2, diff);// 2 objects created during function call Assert.assertTrue(diff <= 2);// 2 objects created during function call
} }
} }

View File

@ -14,10 +14,10 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
******************************************************************************/ ******************************************************************************/
import org.eclipse.python4j.PythonContextManager; import org.nd4j.python4j.PythonContextManager;
import org.eclipse.python4j.PythonJob; import org.nd4j.python4j.PythonJob;
import org.eclipse.python4j.PythonTypes; import org.nd4j.python4j.PythonTypes;
import org.eclipse.python4j.PythonVariable; import org.nd4j.python4j.PythonVariable;
import org.junit.Test; import org.junit.Test;
import java.util.ArrayList; import java.util.ArrayList;
@ -30,7 +30,7 @@ import static org.junit.Assert.assertEquals;
public class PythonJobTest { public class PythonJobTest {
@Test @Test
public void testPythonJobBasic() throws Exception{ public void testPythonJobBasic(){
PythonContextManager.deleteNonMainContexts(); PythonContextManager.deleteNonMainContexts();
String code = "c = a + b"; String code = "c = a + b";
@ -65,7 +65,7 @@ public class PythonJobTest {
} }
@Test @Test
public void testPythonJobReturnAllVariables()throws Exception{ public void testPythonJobReturnAllVariables(){
PythonContextManager.deleteNonMainContexts(); PythonContextManager.deleteNonMainContexts();
String code = "c = a + b"; String code = "c = a + b";
@ -101,7 +101,7 @@ public class PythonJobTest {
@Test @Test
public void testMultiplePythonJobsParallel()throws Exception{ public void testMultiplePythonJobsParallel(){
PythonContextManager.deleteNonMainContexts(); PythonContextManager.deleteNonMainContexts();
String code1 = "c = a + b"; String code1 = "c = a + b";
PythonJob job1 = new PythonJob("job1", code1, false); PythonJob job1 = new PythonJob("job1", code1, false);
@ -150,7 +150,7 @@ public class PythonJobTest {
@Test @Test
public void testPythonJobSetupRun()throws Exception{ public void testPythonJobSetupRun(){
PythonContextManager.deleteNonMainContexts(); PythonContextManager.deleteNonMainContexts();
String code = "five=None\n" + String code = "five=None\n" +
@ -189,7 +189,7 @@ public class PythonJobTest {
} }
@Test @Test
public void testPythonJobSetupRunAndReturnAllVariables()throws Exception{ public void testPythonJobSetupRunAndReturnAllVariables(){
PythonContextManager.deleteNonMainContexts(); PythonContextManager.deleteNonMainContexts();
String code = "five=None\n" + String code = "five=None\n" +
"c=None\n"+ "c=None\n"+
@ -225,7 +225,7 @@ public class PythonJobTest {
} }
@Test @Test
public void testMultiplePythonJobsSetupRunParallel()throws Exception{ public void testMultiplePythonJobsSetupRunParallel(){
PythonContextManager.deleteNonMainContexts(); PythonContextManager.deleteNonMainContexts();
String code1 = "five=None\n" + String code1 = "five=None\n" +

View File

@ -14,10 +14,9 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
******************************************************************************/ ******************************************************************************/
import org.eclipse.python4j.*; import org.nd4j.python4j.*;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import javax.annotation.concurrent.NotThreadSafe; import javax.annotation.concurrent.NotThreadSafe;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;

View File

@ -15,12 +15,13 @@
******************************************************************************/ ******************************************************************************/
import org.eclipse.python4j.PythonException; import org.nd4j.python4j.*;
import org.eclipse.python4j.PythonObject;
import org.eclipse.python4j.PythonTypes;
import org.junit.Assert; import org.junit.Assert;
import org.junit.Test; import org.junit.Test;
import java.util.ArrayList;
import java.util.List;
public class PythonPrimitiveTypesTest { public class PythonPrimitiveTypesTest {
@Test @Test
@ -78,5 +79,18 @@ public class PythonPrimitiveTypesTest {
Assert.assertEquals(b, b3); Assert.assertEquals(b, b3);
} }
@Test
public void testBytes() {
byte[] bytes = new byte[]{97, 98, 99};
List<PythonVariable> inputs = new ArrayList<>();
inputs.add(new PythonVariable<>("buff", PythonTypes.BYTES, bytes));
List<PythonVariable> outputs = new ArrayList<>();
outputs.add(new PythonVariable<>("s1", PythonTypes.STR));
outputs.add(new PythonVariable<>("buff2", PythonTypes.BYTES));
String code = "s1 = ''.join(chr(c) for c in buff)\nbuff2=b'def'";
PythonExecutioner.exec(code, inputs, outputs);
Assert.assertEquals("abc", outputs.get(0).getValue());
Assert.assertArrayEquals(new byte[]{100, 101, 102}, (byte[])outputs.get(1).getValue());
}
} }

View File

@ -4,7 +4,7 @@
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent> <parent>
<artifactId>python4j-parent</artifactId> <artifactId>python4j-parent</artifactId>
<groupId>org.eclipse</groupId> <groupId>org.nd4j</groupId>
<version>1.0.0-SNAPSHOT</version> <version>1.0.0-SNAPSHOT</version>
</parent> </parent>
<modelVersion>4.0.0</modelVersion> <modelVersion>4.0.0</modelVersion>
@ -28,15 +28,50 @@
<version>${nd4j.version}</version> <version>${nd4j.version}</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>python4j-core</artifactId>
<version>1.0.0-SNAPSHOT</version>
</dependency>
</dependencies> </dependencies>
<profiles> <profiles>
<profile> <profile>
<id>test-nd4j-native</id> <id>test-nd4j-native</id>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>dl4j-test-resources</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</profile> </profile>
<profile> <profile>
<id>test-nd4j-cuda-10.2</id> <id>test-nd4j-cuda-10.2</id>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-10.1</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>dl4j-test-resources</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</profile> </profile>
</profiles> </profiles>
</project> </project>

View File

@ -0,0 +1,299 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.python4j;
import lombok.extern.slf4j.Slf4j;
import org.bytedeco.cpython.PyObject;
import org.bytedeco.cpython.PyTypeObject;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.SizeTPointer;
import org.bytedeco.numpy.PyArrayObject;
import org.bytedeco.numpy.global.numpy;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.nativeblas.NativeOpsHolder;
import java.io.File;
import java.util.*;
import java.util.concurrent.atomic.AtomicBoolean;
import static org.bytedeco.cpython.global.python.*;
import static org.bytedeco.cpython.global.python.Py_DecRef;
import static org.bytedeco.numpy.global.numpy.*;
import static org.bytedeco.numpy.global.numpy.NPY_ARRAY_CARRAY;
import static org.bytedeco.numpy.global.numpy.PyArray_Type;
@Slf4j
public class NumpyArray extends PythonType<INDArray> {
public static final NumpyArray INSTANCE;
private static final AtomicBoolean init = new AtomicBoolean(false);
private static final Map<String, DataBuffer> cache = new HashMap<>();
static {
new PythonExecutioner();
INSTANCE = new NumpyArray();
}
@Override
public File[] packages(){
try{
return new File[]{numpy.cachePackage()};
}catch(Exception e){
throw new PythonException(e);
}
}
public synchronized void init() {
if (init.get()) return;
init.set(true);
if (PythonGIL.locked()) {
throw new PythonException("Can not initialize numpy - GIL already acquired.");
}
int err = numpy._import_array();
if (err < 0){
System.out.println("Numpy import failed!");
throw new PythonException("Numpy import failed!");
}
}
public NumpyArray() {
super("numpy.ndarray", INDArray.class);
}
@Override
public INDArray toJava(PythonObject pythonObject) {
log.info("Converting PythonObject to INDArray...");
PyObject np = PyImport_ImportModule("numpy");
PyObject ndarray = PyObject_GetAttrString(np, "ndarray");
if (PyObject_IsInstance(pythonObject.getNativePythonObject(), ndarray) != 1) {
Py_DecRef(ndarray);
Py_DecRef(np);
throw new PythonException("Object is not a numpy array! Use Python.ndarray() to convert object to a numpy array.");
}
Py_DecRef(ndarray);
Py_DecRef(np);
PyArrayObject npArr = new PyArrayObject(pythonObject.getNativePythonObject());
long[] shape = new long[PyArray_NDIM(npArr)];
SizeTPointer shapePtr = PyArray_SHAPE(npArr);
if (shapePtr != null)
shapePtr.get(shape, 0, shape.length);
long[] strides = new long[shape.length];
SizeTPointer stridesPtr = PyArray_STRIDES(npArr);
if (stridesPtr != null)
stridesPtr.get(strides, 0, strides.length);
int npdtype = PyArray_TYPE(npArr);
DataType dtype;
switch (npdtype) {
case NPY_DOUBLE:
dtype = DataType.DOUBLE;
break;
case NPY_FLOAT:
dtype = DataType.FLOAT;
break;
case NPY_SHORT:
dtype = DataType.SHORT;
break;
case NPY_INT:
dtype = DataType.INT32;
break;
case NPY_LONG:
dtype = DataType.INT64;
break;
case NPY_UINT:
dtype = DataType.UINT32;
break;
case NPY_BYTE:
dtype = DataType.INT8;
break;
case NPY_UBYTE:
dtype = DataType.UINT8;
break;
case NPY_BOOL:
dtype = DataType.BOOL;
break;
case NPY_HALF:
dtype = DataType.FLOAT16;
break;
case NPY_LONGLONG:
dtype = DataType.INT64;
break;
case NPY_USHORT:
dtype = DataType.UINT16;
break;
case NPY_ULONG:
case NPY_ULONGLONG:
dtype = DataType.UINT64;
break;
default:
throw new PythonException("Unsupported array data type: " + npdtype);
}
long size = 1;
for (int i = 0; i < shape.length; size *= shape[i++]) ;
INDArray ret;
long address = PyArray_DATA(npArr).address();
String key = address + "_" + size + "_" + dtype;
DataBuffer buff = cache.get(key);
if (buff == null) {
try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
Pointer ptr = NativeOpsHolder.getInstance().getDeviceNativeOps().pointerForAddress(address);
ptr = ptr.limit(size);
ptr = ptr.capacity(size);
buff = Nd4j.createBuffer(ptr, size, dtype);
cache.put(key, buff);
}
}
int elemSize = buff.getElementSize();
long[] nd4jStrides = new long[strides.length];
for (int i = 0; i < strides.length; i++) {
nd4jStrides[i] = strides[i] / elemSize;
}
ret = Nd4j.create(buff, shape, nd4jStrides, 0, Shape.getOrder(shape, nd4jStrides, 1), dtype);
Nd4j.getAffinityManager().tagLocation(ret, AffinityManager.Location.HOST);
log.info("Done.");
return ret;
}
@Override
public PythonObject toPython(INDArray indArray) {
log.info("Converting INDArray to PythonObject...");
DataType dataType = indArray.dataType();
DataBuffer buff = indArray.data();
String key = buff.pointer().address() + "_" + buff.length() + "_" + dataType;
cache.put(key, buff);
int numpyType;
String ctype;
switch (dataType) {
case DOUBLE:
numpyType = NPY_DOUBLE;
ctype = "c_double";
break;
case FLOAT:
case BFLOAT16:
numpyType = NPY_FLOAT;
ctype = "c_float";
break;
case SHORT:
numpyType = NPY_SHORT;
ctype = "c_short";
break;
case INT:
numpyType = NPY_INT;
ctype = "c_int";
break;
case LONG:
numpyType = NPY_INT64;
ctype = "c_int64";
break;
case UINT16:
numpyType = NPY_USHORT;
ctype = "c_uint16";
break;
case UINT32:
numpyType = NPY_UINT;
ctype = "c_uint";
break;
case UINT64:
numpyType = NPY_UINT64;
ctype = "c_uint64";
break;
case BOOL:
numpyType = NPY_BOOL;
ctype = "c_bool";
break;
case BYTE:
numpyType = NPY_BYTE;
ctype = "c_byte";
break;
case UBYTE:
numpyType = NPY_UBYTE;
ctype = "c_ubyte";
break;
case HALF:
numpyType = NPY_HALF;
ctype = "c_short";
break;
default:
throw new RuntimeException("Unsupported dtype: " + dataType);
}
long[] shape = indArray.shape();
INDArray inputArray = indArray;
if (dataType == DataType.BFLOAT16) {
log.warn("Creating copy of array as bfloat16 is not supported by numpy.");
inputArray = indArray.castTo(DataType.FLOAT);
}
//Sync to host memory in the case of CUDA, before passing the host memory pointer to Python
Nd4j.getAffinityManager().ensureLocation(inputArray, AffinityManager.Location.HOST);
// PyArray_Type() call causes jvm crash in linux cpu if GIL is acquired by non main thread.
// Using Interpreter for now:
// try(PythonContextManager.Context context = new PythonContextManager.Context("__np_array_converter")){
// log.info("Stringing exec...");
// String code = "import ctypes\nimport numpy as np\n" +
// "cArr = (ctypes." + ctype + "*" + indArray.length() + ")"+
// ".from_address(" + indArray.data().pointer().address() + ")\n"+
// "npArr = np.frombuffer(cArr, dtype=" + ((numpyType == NPY_HALF) ? "'half'" : "ctypes." + ctype)+
// ").reshape(" + Arrays.toString(indArray.shape()) + ")";
// PythonExecutioner.exec(code);
// log.info("exec done.");
// PythonObject ret = PythonExecutioner.getVariable("npArr");
// Py_IncRef(ret.getNativePythonObject());
// return ret;
//
// }
log.info("NUMPY: PyArray_Type()");
PyTypeObject pyTypeObject = PyArray_Type();
log.info("NUMPY: PyArray_New()");
PyObject npArr = PyArray_New(pyTypeObject, shape.length, new SizeTPointer(shape),
numpyType, null,
inputArray.data().addressPointer(),
0, NPY_ARRAY_CARRAY, null);
log.info("Done.");
return new PythonObject(npArr);
}
@Override
public boolean accepts(Object javaObject) {
return javaObject instanceof INDArray;
}
@Override
public INDArray adapt(Object javaObject) {
if (javaObject instanceof INDArray) {
return (INDArray) javaObject;
}
throw new PythonException("Cannot cast object of type " + javaObject.getClass().getName() + " to INDArray");
}
}

View File

@ -0,0 +1 @@
org.nd4j.python4j.NumpyArray

View File

@ -0,0 +1,169 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
import org.nd4j.python4j.*;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.nativeblas.OpaqueDataBuffer;
import javax.annotation.concurrent.NotThreadSafe;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
@NotThreadSafe
@RunWith(Parameterized.class)
public class PythonNumpyBasicTest {
private DataType dataType;
private long[] shape;
public PythonNumpyBasicTest(DataType dataType, long[] shape, String dummyArg) {
this.dataType = dataType;
this.shape = shape;
}
@Parameterized.Parameters(name = "{index}: Testing with DataType={0}, shape={2}")
public static Collection params() {
DataType[] types = new DataType[] {
DataType.BOOL,
DataType.FLOAT16,
DataType.BFLOAT16,
DataType.FLOAT,
DataType.DOUBLE,
DataType.INT8,
DataType.INT16,
DataType.INT32,
DataType.INT64,
DataType.UINT8,
DataType.UINT16,
DataType.UINT32,
DataType.UINT64
};
long[][] shapes = new long[][]{
new long[]{2, 3},
new long[]{3},
new long[]{1},
new long[]{} // scalar
};
List<Object[]> ret = new ArrayList<>();
for (DataType type: types){
for (long[] shape: shapes){
ret.add(new Object[]{type, shape, Arrays.toString(shape)});
}
}
return ret;
}
@Test
public void testConversion(){
INDArray arr = Nd4j.zeros(dataType, shape);
PythonObject npArr = PythonTypes.convert(arr);
INDArray arr2 = PythonTypes.<INDArray>getPythonTypeForPythonObject(npArr).toJava(npArr);
if (dataType == DataType.BFLOAT16){
arr = arr.castTo(DataType.FLOAT);
}
Assert.assertEquals(arr,arr2);
}
@Test
public void testExecution(){
List<PythonVariable> inputs = new ArrayList<>();
INDArray x = Nd4j.ones(dataType, shape);
INDArray y = Nd4j.zeros(dataType, shape);
INDArray z = (dataType == DataType.BOOL)?x:x.mul(y.add(2));
z = (dataType == DataType.BFLOAT16)? z.castTo(DataType.FLOAT): z;
PythonType<INDArray> arrType = PythonTypes.get("numpy.ndarray");
inputs.add(new PythonVariable<>("x", arrType, x));
inputs.add(new PythonVariable<>("y", arrType, y));
List<PythonVariable> outputs = new ArrayList<>();
PythonVariable<INDArray> output = new PythonVariable<>("z", arrType);
outputs.add(output);
String code = (dataType == DataType.BOOL)?"z = x":"z = x * (y + 2)";
if (shape.length == 0){ // scalar special case
code += "\nimport numpy as np\nz = np.asarray(float(z), dtype=x.dtype)";
}
PythonExecutioner.exec(code, inputs, outputs);
INDArray z2 = output.getValue();
Assert.assertEquals(z.dataType(), z2.dataType());
Assert.assertEquals(z, z2);
}
@Test
public void testInplaceExecution(){
if (dataType == DataType.BOOL || dataType == DataType.BFLOAT16)return;
if (shape.length == 0) return;
List<PythonVariable> inputs = new ArrayList<>();
INDArray x = Nd4j.ones(dataType, shape);
INDArray y = Nd4j.zeros(dataType, shape);
INDArray z = x.mul(y.add(2));
// Nd4j.getAffinityManager().ensureLocation(z, AffinityManager.Location.HOST);
PythonType<INDArray> arrType = PythonTypes.get("numpy.ndarray");
inputs.add(new PythonVariable<>("x", arrType, x));
inputs.add(new PythonVariable<>("y", arrType, y));
List<PythonVariable> outputs = new ArrayList<>();
PythonVariable<INDArray> output = new PythonVariable<>("x", arrType);
outputs.add(output);
String code = "x *= y + 2";
PythonExecutioner.exec(code, inputs, outputs);
INDArray z2 = output.getValue();
Assert.assertEquals(x.dataType(), z2.dataType());
Assert.assertEquals(z.dataType(), z2.dataType());
Assert.assertEquals(x, z2);
Assert.assertEquals(z, z2);
Assert.assertEquals(x.data().pointer().address(), z2.data().pointer().address());
if("CUDA".equalsIgnoreCase(Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"))){
Assert.assertEquals(getDeviceAddress(x), getDeviceAddress(z2));
}
}
private static long getDeviceAddress(INDArray array){
if(!"CUDA".equalsIgnoreCase(Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"))){
throw new IllegalStateException("Cannot ge device pointer for non-CUDA device");
}
//Use reflection here as OpaqueDataBuffer is only available on BaseCudaDataBuffer and BaseCpuDataBuffer - not DataBuffer/BaseDataBuffer
// due to it being defined in nd4j-native-api, not nd4j-api
try {
Class<?> c = Class.forName("org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer");
Method m = c.getMethod("getOpaqueDataBuffer");
OpaqueDataBuffer db = (OpaqueDataBuffer) m.invoke(array.data());
long address = db.specialBuffer().address();
return address;
} catch (Throwable t){
throw new RuntimeException("Error getting OpaqueDataBuffer", t);
}
}
}

View File

@ -0,0 +1,96 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
import org.nd4j.python4j.PythonException;
import org.nd4j.python4j.PythonObject;
import org.nd4j.python4j.PythonTypes;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;
import javax.annotation.concurrent.NotThreadSafe;
import java.util.*;
@NotThreadSafe
@RunWith(Parameterized.class)
public class PythonNumpyCollectionsTest {
private DataType dataType;
public PythonNumpyCollectionsTest(DataType dataType){
this.dataType = dataType;
}
@Parameterized.Parameters(name = "{index}: Testing with DataType={0}")
public static DataType[] params() {
return new DataType[]{
DataType.BOOL,
DataType.FLOAT16,
//DataType.BFLOAT16,
DataType.FLOAT,
DataType.DOUBLE,
DataType.INT8,
DataType.INT16,
DataType.INT32,
DataType.INT64,
DataType.UINT8,
DataType.UINT16,
DataType.UINT32,
DataType.UINT64
};
}
@Test
public void testPythonDictFromMap() throws PythonException {
Map map = new HashMap();
map.put("a", 1);
map.put(1, "a");
map.put("arr", Nd4j.ones(dataType, 2, 3));
map.put("list1", Arrays.asList(1, 2.0, 3, 4f, Nd4j.zeros(dataType,3,2)));
Map innerMap = new HashMap();
innerMap.put("b", 2);
innerMap.put(2, "b");
innerMap.put(5, Nd4j.ones(dataType, 5));
map.put("innermap", innerMap);
map.put("list2", Arrays.asList(4, "5", innerMap, false, true));
PythonObject dict = PythonTypes.convert(map);
Map map2 = PythonTypes.DICT.toJava(dict);
Assert.assertEquals(map.toString(), map2.toString());
}
@Test
public void testPythonListFromList() throws PythonException{
List<Object> list = new ArrayList<>();
list.add(1);
list.add("2");
list.add(Nd4j.ones(dataType, 2, 3));
list.add(Arrays.asList("a",
Nd4j.ones(dataType, 1, 2),1.0, 2f, 10, true, false,
Nd4j.zeros(dataType, 3, 2)));
Map map = new HashMap();
map.put("a", 1);
map.put(1, "a");
map.put(5, Nd4j.ones(dataType,4, 5));
map.put("list1", Arrays.asList(1, 2.0, 3, 4f, Nd4j.zeros(dataType, 3, 1)));
list.add(map);
PythonObject dict = PythonTypes.convert(list);
List list2 = PythonTypes.LIST.toJava(dict);
Assert.assertEquals(list.toString(), list2.toString());
}
}

View File

@ -0,0 +1,55 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
import org.nd4j.python4j.Python;
import org.nd4j.python4j.PythonGC;
import org.nd4j.python4j.PythonObject;
import org.junit.Assert;
import org.junit.Test;
import org.nd4j.linalg.factory.Nd4j;
import javax.annotation.concurrent.NotThreadSafe;
@NotThreadSafe
public class PythonNumpyGCTest {
@Test
public void testGC(){
PythonObject gcModule = Python.importModule("gc");
PythonObject getObjects = gcModule.attr("get_objects");
PythonObject pyObjCount1 = Python.len(getObjects.call());
long objCount1 = pyObjCount1.toLong();
PythonObject pyList = Python.list();
pyList.attr("append").call(new PythonObject(Nd4j.linspace(1, 10, 10)));
pyList.attr("append").call(1.0);
pyList.attr("append").call(true);
PythonObject pyObjCount2 = Python.len(getObjects.call());
long objCount2 = pyObjCount2.toLong();
long diff = objCount2 - objCount1;
Assert.assertTrue(diff > 2);
try(PythonGC gc = PythonGC.watch()){
PythonObject pyList2 = Python.list();
pyList2.attr("append").call(new PythonObject(Nd4j.linspace(1, 10, 10)));
pyList2.attr("append").call(1.0);
pyList2.attr("append").call(true);
}
PythonObject pyObjCount3 = Python.len(getObjects.call());
long objCount3 = pyObjCount3.toLong();
diff = objCount3 - objCount2;
Assert.assertTrue(diff <= 2);// 2 objects created during function call
}
}

View File

@ -0,0 +1,22 @@
import org.nd4j.python4j.NumpyArray;
import org.nd4j.python4j.Python;
import org.nd4j.python4j.PythonGC;
import org.nd4j.python4j.PythonObject;
import org.junit.Assert;
import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
public class PythonNumpyImportTest {
@Test
public void testNumpyImport(){
try(PythonGC gc = PythonGC.watch()){
PythonObject np = Python.importModule("numpy");
PythonObject zeros = np.attr("zeros").call(5);
INDArray arr = NumpyArray.INSTANCE.toJava(zeros);
Assert.assertEquals(arr, Nd4j.zeros(DataType.DOUBLE, 5));
}
}
}

View File

@ -0,0 +1,303 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.python4j.*;
import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.assertEquals;
@javax.annotation.concurrent.NotThreadSafe
@RunWith(Parameterized.class)
public class PythonNumpyJobTest {
private DataType dataType;
public PythonNumpyJobTest(DataType dataType){
this.dataType = dataType;
}
@Parameterized.Parameters(name = "{index}: Testing with DataType={0}")
public static DataType[] params() {
return new DataType[]{
DataType.BOOL,
DataType.FLOAT16,
DataType.BFLOAT16,
DataType.FLOAT,
DataType.DOUBLE,
DataType.INT8,
DataType.INT16,
DataType.INT32,
DataType.INT64,
DataType.UINT8,
DataType.UINT16,
DataType.UINT32,
DataType.UINT64
};
}
@Test
public void testNumpyJobBasic(){
PythonContextManager.deleteNonMainContexts();
List<PythonVariable> inputs = new ArrayList<>();
INDArray x = Nd4j.ones(dataType, 2, 3);
INDArray y = Nd4j.zeros(dataType, 2, 3);
INDArray z = (dataType == DataType.BOOL)?x:x.mul(y.add(2));
z = (dataType == DataType.BFLOAT16)? z.castTo(DataType.FLOAT): z;
PythonType<INDArray> arrType = PythonTypes.get("numpy.ndarray");
inputs.add(new PythonVariable<>("x", arrType, x));
inputs.add(new PythonVariable<>("y", arrType, y));
List<PythonVariable> outputs = new ArrayList<>();
PythonVariable<INDArray> output = new PythonVariable<>("z", arrType);
outputs.add(output);
String code = (dataType == DataType.BOOL)?"z = x":"z = x * (y + 2)";
PythonJob job = new PythonJob("job1", code, false);
job.exec(inputs, outputs);
INDArray z2 = output.getValue();
if (dataType == DataType.BFLOAT16){
z2 = z2.castTo(DataType.FLOAT);
}
Assert.assertEquals(z, z2);
}
@Test
public void testNumpyJobReturnAllVariables(){
PythonContextManager.deleteNonMainContexts();
List<PythonVariable> inputs = new ArrayList<>();
INDArray x = Nd4j.ones(dataType, 2, 3);
INDArray y = Nd4j.zeros(dataType, 2, 3);
INDArray z = (dataType == DataType.BOOL)?x:x.mul(y.add(2));
PythonType<INDArray> arrType = PythonTypes.get("numpy.ndarray");
inputs.add(new PythonVariable<>("x", arrType, x));
inputs.add(new PythonVariable<>("y", arrType, y));
String code = (dataType == DataType.BOOL)?"z = x":"z = x * (y + 2)";
PythonJob job = new PythonJob("job1", code, false);
List<PythonVariable> outputs = job.execAndReturnAllVariables(inputs);
INDArray x2 = (INDArray) outputs.get(0).getValue();
INDArray y2 = (INDArray) outputs.get(1).getValue();
INDArray z2 = (INDArray) outputs.get(2).getValue();
if (dataType == DataType.BFLOAT16){
x = x.castTo(DataType.FLOAT);
y = y.castTo(DataType.FLOAT);
z = z.castTo(DataType.FLOAT);
}
Assert.assertEquals(x, x2);
Assert.assertEquals(y, y2);
Assert.assertEquals(z, z2);
}
@Test
public void testMultipleNumpyJobsParallel(){
PythonContextManager.deleteNonMainContexts();
String code1 =(dataType == DataType.BOOL)?"z = x":"z = x + y";
PythonJob job1 = new PythonJob("job1", code1, false);
String code2 =(dataType == DataType.BOOL)?"z = y":"z = x - y";
PythonJob job2 = new PythonJob("job2", code2, false);
List<PythonVariable> inputs = new ArrayList<>();
INDArray x = Nd4j.ones(dataType, 2, 3);
INDArray y = Nd4j.zeros(dataType, 2, 3);
INDArray z1 = (dataType == DataType.BOOL)?x:x.add(y);
z1 = (dataType == DataType.BFLOAT16)? z1.castTo(DataType.FLOAT): z1;
INDArray z2 = (dataType == DataType.BOOL)?y:x.sub(y);
z2 = (dataType == DataType.BFLOAT16)? z2.castTo(DataType.FLOAT): z2;
PythonType<INDArray> arrType = PythonTypes.get("numpy.ndarray");
inputs.add(new PythonVariable<>("x", arrType, x));
inputs.add(new PythonVariable<>("y", arrType, y));
List<PythonVariable> outputs = new ArrayList<>();
outputs.add(new PythonVariable<>("z", arrType));
job1.exec(inputs, outputs);
assertEquals(z1, outputs.get(0).getValue());
job2.exec(inputs, outputs);
assertEquals(z2, outputs.get(0).getValue());
}
@Test
public synchronized void testNumpyJobSetupRun(){
if (dataType == DataType.BOOL)return;
PythonContextManager.deleteNonMainContexts();
String code = "five=None\n" +
"def setup():\n" +
" global five\n"+
" five = 5\n\n" +
"def run(a, b):\n" +
" c = a + b + five\n"+
" return {'c':c}\n\n";
PythonJob job = new PythonJob("job1", code, true);
List<PythonVariable> inputs = new ArrayList<>();
inputs.add(new PythonVariable<>("a", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(2)));
inputs.add(new PythonVariable<>("b", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3)));
List<PythonVariable> outputs = new ArrayList<>();
outputs.add(new PythonVariable<>("c", NumpyArray.INSTANCE));
job.exec(inputs, outputs);
assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16)? DataType.FLOAT: dataType, 2, 3).mul(10),
outputs.get(0).getValue());
inputs = new ArrayList<>();
inputs.add(new PythonVariable<>("a", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3)));
inputs.add(new PythonVariable<>("b", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(4)));
outputs = new ArrayList<>();
outputs.add(new PythonVariable<>("c", NumpyArray.INSTANCE));
job.exec(inputs, outputs);
assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16)? DataType.FLOAT: dataType, 2, 3).mul(12),
outputs.get(0).getValue());
}
@Test
public void testNumpyJobSetupRunAndReturnAllVariables(){
if (dataType == DataType.BOOL)return;
PythonContextManager.deleteNonMainContexts();
String code = "five=None\n" +
"c=None\n"+
"def setup():\n" +
" global five\n"+
" five = 5\n\n" +
"def run(a, b):\n" +
" global c\n" +
" c = a + b + five\n";
PythonJob job = new PythonJob("job1", code, true);
List<PythonVariable> inputs = new ArrayList<>();
inputs.add(new PythonVariable<>("a", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(2)));
inputs.add(new PythonVariable<>("b", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3)));
List<PythonVariable> outputs = job.execAndReturnAllVariables(inputs);
assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16)? DataType.FLOAT: dataType, 2, 3).mul(10),
outputs.get(1).getValue());
inputs = new ArrayList<>();
inputs.add(new PythonVariable<>("a", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3)));
inputs.add(new PythonVariable<>("b", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(4)));
outputs = job.execAndReturnAllVariables(inputs);
assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16)? DataType.FLOAT: dataType, 2, 3).mul(12),
outputs.get(1).getValue());
}
@Test
public void testMultipleNumpyJobsSetupRunParallel(){
if (dataType == DataType.BOOL)return;
PythonContextManager.deleteNonMainContexts();
String code1 = "five=None\n" +
"def setup():\n" +
" global five\n"+
" five = 5\n\n" +
"def run(a, b):\n" +
" c = a + b + five\n"+
" return {'c':c}\n\n";
PythonJob job1 = new PythonJob("job1", code1, true);
String code2 = "five=None\n" +
"def setup():\n" +
" global five\n"+
" five = 5\n\n" +
"def run(a, b):\n" +
" c = a + b - five\n"+
" return {'c':c}\n\n";
PythonJob job2 = new PythonJob("job2", code2, true);
List<PythonVariable> inputs = new ArrayList<>();
inputs.add(new PythonVariable<>("a", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(2)));
inputs.add(new PythonVariable<>("b", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3)));
List<PythonVariable> outputs = new ArrayList<>();
outputs.add(new PythonVariable<>("c", NumpyArray.INSTANCE));
job1.exec(inputs, outputs);
assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16)? DataType.FLOAT: dataType, 2, 3).mul(10),
outputs.get(0).getValue());
job2.exec(inputs, outputs);
assertEquals(Nd4j.zeros((dataType == DataType.BFLOAT16)? DataType.FLOAT: dataType, 2, 3),
outputs.get(0).getValue());
inputs = new ArrayList<>();
inputs.add(new PythonVariable<>("a", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3)));
inputs.add(new PythonVariable<>("b", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(4)));
outputs = new ArrayList<>();
outputs.add(new PythonVariable<>("c", NumpyArray.INSTANCE));
job1.exec(inputs, outputs);
assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16)? DataType.FLOAT: dataType, 2, 3).mul(12),
outputs.get(0).getValue());
job2.exec(inputs, outputs);
assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16)? DataType.FLOAT: dataType, 2, 3).mul(2),
outputs.get(0).getValue());
}
}

View File

@ -0,0 +1,194 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
import org.nd4j.python4j.*;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import javax.annotation.concurrent.NotThreadSafe;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
@NotThreadSafe
@RunWith(Parameterized.class)
public class PythonNumpyMultiThreadTest {
private DataType dataType;
public PythonNumpyMultiThreadTest(DataType dataType) {
this.dataType = dataType;
}
@Parameterized.Parameters(name = "{index}: Testing with DataType={0}")
public static DataType[] params() {
return new DataType[]{
// DataType.BOOL,
// DataType.FLOAT16,
// DataType.BFLOAT16,
DataType.FLOAT,
DataType.DOUBLE,
// DataType.INT8,
// DataType.INT16,
DataType.INT32,
DataType.INT64,
// DataType.UINT8,
// DataType.UINT16,
// DataType.UINT32,
// DataType.UINT64
};
}
@Test
public void testMultiThreading1() throws Throwable {
final List<Throwable> exceptions = Collections.synchronizedList(new ArrayList<Throwable>());
Runnable runnable = new Runnable() {
@Override
public void run() {
try (PythonGIL gil = PythonGIL.lock()) {
try (PythonGC gc = PythonGC.watch()) {
List<PythonVariable> inputs = new ArrayList<>();
inputs.add(new PythonVariable<>("x", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3)));
inputs.add(new PythonVariable<>("y", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(4)));
PythonVariable out = new PythonVariable<>("z", NumpyArray.INSTANCE);
String code = "z = x + y";
PythonExecutioner.exec(code, inputs, Collections.singletonList(out));
Assert.assertEquals(Nd4j.ones(dataType, 2, 3).mul(7), out.getValue());
}
} catch (Throwable e) {
exceptions.add(e);
}
}
};
int numThreads = 10;
Thread[] threads = new Thread[numThreads];
for (int i = 0; i < threads.length; i++) {
threads[i] = new Thread(runnable);
}
for (int i = 0; i < threads.length; i++) {
threads[i].start();
}
Thread.sleep(100);
for (int i = 0; i < threads.length; i++) {
threads[i].join();
}
if (!exceptions.isEmpty()) {
throw (exceptions.get(0));
}
}
@Test
public void testMultiThreading2() throws Throwable {
final List<Throwable> exceptions = Collections.synchronizedList(new ArrayList<Throwable>());
Runnable runnable = new Runnable() {
@Override
public void run() {
try (PythonGIL gil = PythonGIL.lock()) {
try (PythonGC gc = PythonGC.watch()) {
PythonContextManager.reset();
List<PythonVariable> inputs = new ArrayList<>();
inputs.add(new PythonVariable<>("x", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3)));
inputs.add(new PythonVariable<>("y", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(4)));
String code = "z = x + y";
List<PythonVariable> outputs = PythonExecutioner.execAndReturnAllVariables(code, inputs);
Assert.assertEquals(Nd4j.ones(dataType, 2, 3).mul(3), outputs.get(0).getValue());
Assert.assertEquals(Nd4j.ones(dataType, 2, 3).mul(4), outputs.get(1).getValue());
Assert.assertEquals(Nd4j.ones(dataType, 2, 3).mul(7), outputs.get(2).getValue());
}
} catch (Throwable e) {
exceptions.add(e);
}
}
};
int numThreads = 10;
Thread[] threads = new Thread[numThreads];
for (int i = 0; i < threads.length; i++) {
threads[i] = new Thread(runnable);
}
for (int i = 0; i < threads.length; i++) {
threads[i].start();
}
Thread.sleep(100);
for (int i = 0; i < threads.length; i++) {
threads[i].join();
}
if (!exceptions.isEmpty()) {
throw (exceptions.get(0));
}
}
@Test
public void testMultiThreading3() throws Throwable {
PythonContextManager.deleteNonMainContexts();
String code = "c = a + b";
final PythonJob job = new PythonJob("job1", code, false);
final List<Throwable> exceptions = Collections.synchronizedList(new ArrayList<Throwable>());
class JobThread extends Thread {
private INDArray a, b, c;
public JobThread(INDArray a, INDArray b, INDArray c) {
this.a = a;
this.b = b;
this.c = c;
}
@Override
public void run() {
try {
PythonVariable<INDArray> out = new PythonVariable<>("c", NumpyArray.INSTANCE);
job.exec(Arrays.<PythonVariable>asList(new PythonVariable<>("a", NumpyArray.INSTANCE, a),
new PythonVariable<>("b", NumpyArray.INSTANCE, b)),
Collections.<PythonVariable>singletonList(out));
Assert.assertEquals(c, out.getValue());
} catch (Exception e) {
exceptions.add(e);
}
}
}
int numThreads = 10;
JobThread[] threads = new JobThread[numThreads];
for (int i = 0; i < threads.length; i++) {
threads[i] = new JobThread(Nd4j.zeros(dataType, 2, 3).add(i), Nd4j.zeros(dataType, 2, 3).add(i + 3),
Nd4j.zeros(dataType, 2, 3).add(2 * i + 3));
}
for (int i = 0; i < threads.length; i++) {
threads[i].start();
}
Thread.sleep(100);
for (int i = 0; i < threads.length; i++) {
threads[i].join();
}
if (!exceptions.isEmpty()) {
throw (exceptions.get(0));
}
}
}

View File

@ -1,5 +1,5 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc. * Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
@ -14,15 +14,22 @@
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
******************************************************************************/ ******************************************************************************/
package org.deeplearning4j.rl4j.learning.sync.qlearning;
import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.junit.Assert;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.python4j.NumpyArray;
import org.nd4j.python4j.PythonTypes;
/** import javax.annotation.concurrent.NotThreadSafe;
* An interface that is an extension of {@link QNetworkSource} for all implementations capable of supplying a target Q-Network
* @NotThreadSafe
* @author Alexandre Boulanger public class PythonNumpyServiceLoaderTest {
*/
public interface TargetQNetworkSource extends QNetworkSource { @Test
IDQN getTargetQNetwork(); public void testServiceLoader(){
Assert.assertEquals(NumpyArray.INSTANCE, PythonTypes.<INDArray>get("numpy.ndarray"));
Assert.assertEquals(NumpyArray.INSTANCE, PythonTypes.getPythonTypeForJavaObject(Nd4j.zeros(1)));
}
} }

View File

@ -17,7 +17,6 @@ package org.deeplearning4j.rl4j.agent.update;
import lombok.Getter; import lombok.Getter;
import org.deeplearning4j.rl4j.learning.sync.Transition; import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN; import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.ITDTargetAlgorithm; import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.ITDTargetAlgorithm;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.StandardDQN; import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.StandardDQN;
@ -28,13 +27,10 @@ import java.util.List;
// Temporary class that will be replaced with a more generic class that delegates gradient computation // Temporary class that will be replaced with a more generic class that delegates gradient computation
// and network update to sub components. // and network update to sub components.
public class DQNNeuralNetUpdateRule implements IUpdateRule<Transition<Integer>>, TargetQNetworkSource { public class DQNNeuralNetUpdateRule implements IUpdateRule<Transition<Integer>> {
@Getter
private final IDQN qNetwork; private final IDQN qNetwork;
private final IDQN targetQNetwork;
@Getter
private IDQN targetQNetwork;
private final int targetUpdateFrequency; private final int targetUpdateFrequency;
private final ITDTargetAlgorithm<Integer> tdTargetAlgorithm; private final ITDTargetAlgorithm<Integer> tdTargetAlgorithm;
@ -47,8 +43,8 @@ public class DQNNeuralNetUpdateRule implements IUpdateRule<Transition<Integer>>,
this.targetQNetwork = qNetwork.clone(); this.targetQNetwork = qNetwork.clone();
this.targetUpdateFrequency = targetUpdateFrequency; this.targetUpdateFrequency = targetUpdateFrequency;
tdTargetAlgorithm = isDoubleDQN tdTargetAlgorithm = isDoubleDQN
? new DoubleDQN(this, gamma, errorClamp) ? new DoubleDQN(qNetwork, targetQNetwork, gamma, errorClamp)
: new StandardDQN(this, gamma, errorClamp); : new StandardDQN(qNetwork, targetQNetwork, gamma, errorClamp);
} }
@Override @Override
@ -56,7 +52,7 @@ public class DQNNeuralNetUpdateRule implements IUpdateRule<Transition<Integer>>,
DataSet targets = tdTargetAlgorithm.computeTDTargets(trainingBatch); DataSet targets = tdTargetAlgorithm.computeTDTargets(trainingBatch);
qNetwork.fit(targets.getFeatures(), targets.getLabels()); qNetwork.fit(targets.getFeatures(), targets.getLabels());
if(++updateCount % targetUpdateFrequency == 0) { if(++updateCount % targetUpdateFrequency == 0) {
targetQNetwork = qNetwork.clone(); targetQNetwork.copy(qNetwork);
} }
} }
} }

View File

@ -16,8 +16,7 @@
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm; package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;
import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource; import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
/** /**
@ -28,7 +27,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
*/ */
public abstract class BaseDQNAlgorithm extends BaseTDTargetAlgorithm { public abstract class BaseDQNAlgorithm extends BaseTDTargetAlgorithm {
private final TargetQNetworkSource qTargetNetworkSource; private final IOutputNeuralNet targetQNetwork;
/** /**
* In litterature, this corresponds to Q{net}(s(t+1), a) * In litterature, this corresponds to Q{net}(s(t+1), a)
@ -40,23 +39,21 @@ public abstract class BaseDQNAlgorithm extends BaseTDTargetAlgorithm {
*/ */
protected INDArray targetQNetworkNextObservation; protected INDArray targetQNetworkNextObservation;
protected BaseDQNAlgorithm(TargetQNetworkSource qTargetNetworkSource, double gamma) { protected BaseDQNAlgorithm(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma) {
super(qTargetNetworkSource, gamma); super(qNetwork, gamma);
this.qTargetNetworkSource = qTargetNetworkSource; this.targetQNetwork = targetQNetwork;
} }
protected BaseDQNAlgorithm(TargetQNetworkSource qTargetNetworkSource, double gamma, double errorClamp) { protected BaseDQNAlgorithm(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma, double errorClamp) {
super(qTargetNetworkSource, gamma, errorClamp); super(qNetwork, gamma, errorClamp);
this.qTargetNetworkSource = qTargetNetworkSource; this.targetQNetwork = targetQNetwork;
} }
@Override @Override
protected void initComputation(INDArray observations, INDArray nextObservations) { protected void initComputation(INDArray observations, INDArray nextObservations) {
super.initComputation(observations, nextObservations); super.initComputation(observations, nextObservations);
qNetworkNextObservation = qNetworkSource.getQNetwork().output(nextObservations); qNetworkNextObservation = qNetwork.output(nextObservations);
IDQN targetQNetwork = qTargetNetworkSource.getTargetQNetwork();
targetQNetworkNextObservation = targetQNetwork.output(nextObservations); targetQNetworkNextObservation = targetQNetwork.output(nextObservations);
} }
} }

View File

@ -17,7 +17,7 @@
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm; package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;
import org.deeplearning4j.rl4j.learning.sync.Transition; import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QNetworkSource; import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.dataset.api.DataSet;
@ -30,7 +30,7 @@ import java.util.List;
*/ */
public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Integer> { public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Integer> {
protected final QNetworkSource qNetworkSource; protected final IOutputNeuralNet qNetwork;
protected final double gamma; protected final double gamma;
private final double errorClamp; private final double errorClamp;
@ -38,12 +38,12 @@ public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Intege
/** /**
* *
* @param qNetworkSource The source of the Q-Network * @param qNetwork The Q-Network
* @param gamma The discount factor * @param gamma The discount factor
* @param errorClamp Will prevent the new Q-Value from being farther than <i>errorClamp</i> away from the previous value. Double.NaN will disable the clamping. * @param errorClamp Will prevent the new Q-Value from being farther than <i>errorClamp</i> away from the previous value. Double.NaN will disable the clamping.
*/ */
protected BaseTDTargetAlgorithm(QNetworkSource qNetworkSource, double gamma, double errorClamp) { protected BaseTDTargetAlgorithm(IOutputNeuralNet qNetwork, double gamma, double errorClamp) {
this.qNetworkSource = qNetworkSource; this.qNetwork = qNetwork;
this.gamma = gamma; this.gamma = gamma;
this.errorClamp = errorClamp; this.errorClamp = errorClamp;
@ -52,12 +52,12 @@ public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Intege
/** /**
* *
* @param qNetworkSource The source of the Q-Network * @param qNetwork The Q-Network
* @param gamma The discount factor * @param gamma The discount factor
* Note: Error clamping is disabled with this ctor * Note: Error clamping is disabled with this ctor
*/ */
protected BaseTDTargetAlgorithm(QNetworkSource qNetworkSource, double gamma) { protected BaseTDTargetAlgorithm(IOutputNeuralNet qNetwork, double gamma) {
this(qNetworkSource, gamma, Double.NaN); this(qNetwork, gamma, Double.NaN);
} }
/** /**
@ -89,8 +89,7 @@ public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Intege
initComputation(observations, nextObservations); initComputation(observations, nextObservations);
INDArray updatedQValues = qNetworkSource.getQNetwork().output(observations); INDArray updatedQValues = qNetwork.output(observations);
for (int i = 0; i < size; ++i) { for (int i = 0; i < size; ++i) {
Transition<Integer> transition = transitions.get(i); Transition<Integer> transition = transitions.get(i);
double yTarget = computeTarget(i, transition.getReward(), transition.isTerminal()); double yTarget = computeTarget(i, transition.getReward(), transition.isTerminal());

View File

@ -16,7 +16,7 @@
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm; package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;
import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource; import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -32,12 +32,12 @@ public class DoubleDQN extends BaseDQNAlgorithm {
// In litterature, this corresponds to: max_{a}Q(s_{t+1}, a) // In litterature, this corresponds to: max_{a}Q(s_{t+1}, a)
private INDArray maxActionsFromQNetworkNextObservation; private INDArray maxActionsFromQNetworkNextObservation;
public DoubleDQN(TargetQNetworkSource qTargetNetworkSource, double gamma) { public DoubleDQN(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma) {
super(qTargetNetworkSource, gamma); super(qNetwork, targetQNetwork, gamma);
} }
public DoubleDQN(TargetQNetworkSource qTargetNetworkSource, double gamma, double errorClamp) { public DoubleDQN(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma, double errorClamp) {
super(qTargetNetworkSource, gamma, errorClamp); super(qNetwork, targetQNetwork, gamma, errorClamp);
} }
@Override @Override

View File

@ -16,7 +16,7 @@
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm; package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;
import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource; import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -32,12 +32,12 @@ public class StandardDQN extends BaseDQNAlgorithm {
// In litterature, this corresponds to: max_{a}Q_{tar}(s_{t+1}, a) // In litterature, this corresponds to: max_{a}Q_{tar}(s_{t+1}, a)
private INDArray maxActionsFromQTargetNextObservation; private INDArray maxActionsFromQTargetNextObservation;
public StandardDQN(TargetQNetworkSource qTargetNetworkSource, double gamma) { public StandardDQN(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma) {
super(qTargetNetworkSource, gamma); super(qNetwork, targetQNetwork, gamma);
} }
public StandardDQN(TargetQNetworkSource qTargetNetworkSource, double gamma, double errorClamp) { public StandardDQN(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma, double errorClamp) {
super(qTargetNetworkSource, gamma, errorClamp); super(qNetwork, targetQNetwork, gamma, errorClamp);
} }
@Override @Override

View File

@ -1,28 +1,38 @@
/******************************************************************************* /*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc. * Copyright (c) 2020 Konduit K.K.
* *
* This program and the accompanying materials are made available under the * This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at * terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0. * https://www.apache.org/licenses/LICENSE-2.0.
* *
* Unless required by applicable law or agreed to in writing, software * Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations * License for the specific language governing permissions and limitations
* under the License. * under the License.
* *
* SPDX-License-Identifier: Apache-2.0 * SPDX-License-Identifier: Apache-2.0
******************************************************************************/ ******************************************************************************/
package org.deeplearning4j.rl4j.network;
package org.deeplearning4j.rl4j.learning.sync.qlearning;
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.network.dqn.IDQN; import org.nd4j.linalg.api.ndarray.INDArray;
/** /**
* An interface for all implementations capable of supplying a Q-Network * An interface defining the output aspect of a {@link NeuralNet}.
* */
* @author Alexandre Boulanger public interface IOutputNeuralNet {
*/ /**
public interface QNetworkSource { * Compute the output for the supplied observation.
IDQN getQNetwork(); * @param observation An {@link Observation}
} * @return The ouptut of the network
*/
INDArray output(Observation observation);
/**
* Compute the output for the supplied batch.
* @param batch
* @return The ouptut of the network
*/
INDArray output(INDArray batch);
}

View File

@ -17,6 +17,7 @@
package org.deeplearning4j.rl4j.network.dqn; package org.deeplearning4j.rl4j.network.dqn;
import org.deeplearning4j.nn.gradient.Gradient; import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
import org.deeplearning4j.rl4j.network.NeuralNet; import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.observation.Observation;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
@ -27,7 +28,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
* This neural net quantify the value of each action given a state * This neural net quantify the value of each action given a state
* *
*/ */
public interface IDQN<NN extends IDQN> extends NeuralNet<NN> { public interface IDQN<NN extends IDQN> extends NeuralNet<NN>, IOutputNeuralNet {
boolean isRecurrent(); boolean isRecurrent();
@ -37,9 +38,6 @@ public interface IDQN<NN extends IDQN> extends NeuralNet<NN> {
void fit(INDArray input, INDArray[] labels); void fit(INDArray input, INDArray[] labels);
INDArray output(INDArray batch);
INDArray output(Observation observation);
INDArray[] outputAll(INDArray batch); INDArray[] outputAll(INDArray batch);
NN clone(); NN clone();

View File

@ -1,10 +1,13 @@
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm; package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;
import org.deeplearning4j.rl4j.learning.sync.Transition; import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.support.MockDQN; import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
import org.deeplearning4j.rl4j.learning.sync.support.MockTargetQNetworkSource;
import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.observation.Observation;
import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -13,16 +16,29 @@ import java.util.ArrayList;
import java.util.List; import java.util.List;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;
@RunWith(MockitoJUnitRunner.class)
public class DoubleDQNTest { public class DoubleDQNTest {
@Mock
IOutputNeuralNet qNetworkMock;
@Mock
IOutputNeuralNet targetQNetworkMock;
@Before
public void setup() {
when(qNetworkMock.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]);
}
@Test @Test
public void when_isTerminal_expect_rewardValueAtIdx0() { public void when_isTerminal_expect_rewardValueAtIdx0() {
// Assemble // Assemble
MockDQN qNetwork = new MockDQN(); when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]);
MockDQN targetQNetwork = new MockDQN();
MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() { List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
{ {
@ -31,7 +47,7 @@ public class DoubleDQNTest {
} }
}; };
DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5); DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5);
// Act // Act
DataSet result = sut.computeTDTargets(transitions); DataSet result = sut.computeTDTargets(transitions);
@ -46,9 +62,7 @@ public class DoubleDQNTest {
public void when_isNotTerminal_expect_rewardPlusEstimatedQValue() { public void when_isNotTerminal_expect_rewardPlusEstimatedQValue() {
// Assemble // Assemble
MockDQN qNetwork = new MockDQN(); when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> ((INDArray)i.getArguments()[0]).mul(-1.0));
MockDQN targetQNetwork = new MockDQN(-1.0);
MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() { List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
{ {
@ -57,7 +71,7 @@ public class DoubleDQNTest {
} }
}; };
DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5); DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5);
// Act // Act
DataSet result = sut.computeTDTargets(transitions); DataSet result = sut.computeTDTargets(transitions);
@ -72,9 +86,7 @@ public class DoubleDQNTest {
public void when_batchHasMoreThanOne_expect_everySampleEvaluated() { public void when_batchHasMoreThanOne_expect_everySampleEvaluated() {
// Assemble // Assemble
MockDQN qNetwork = new MockDQN(); when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> ((INDArray)i.getArguments()[0]).mul(-1.0));
MockDQN targetQNetwork = new MockDQN(-1.0);
MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() { List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
{ {
@ -87,7 +99,7 @@ public class DoubleDQNTest {
} }
}; };
DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5); DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5);
// Act // Act
DataSet result = sut.computeTDTargets(transitions); DataSet result = sut.computeTDTargets(transitions);

View File

@ -1,10 +1,13 @@
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm; package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;
import org.deeplearning4j.rl4j.learning.sync.Transition; import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.support.MockDQN; import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
import org.deeplearning4j.rl4j.learning.sync.support.MockTargetQNetworkSource;
import org.deeplearning4j.rl4j.observation.Observation; import org.deeplearning4j.rl4j.observation.Observation;
import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet; import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4j;
@ -12,17 +15,31 @@ import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import static org.junit.Assert.*; import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;
@RunWith(MockitoJUnitRunner.class)
public class StandardDQNTest { public class StandardDQNTest {
@Mock
IOutputNeuralNet qNetworkMock;
@Mock
IOutputNeuralNet targetQNetworkMock;
@Before
public void setup() {
when(qNetworkMock.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]);
when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]);
}
@Test @Test
public void when_isTerminal_expect_rewardValueAtIdx0() { public void when_isTerminal_expect_rewardValueAtIdx0() {
// Assemble // Assemble
MockDQN qNetwork = new MockDQN();
MockDQN targetQNetwork = new MockDQN();
MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() { List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
{ {
add(buildTransition(buildObservation(new double[]{1.1, 2.2}), add(buildTransition(buildObservation(new double[]{1.1, 2.2}),
@ -30,7 +47,7 @@ public class StandardDQNTest {
} }
}; };
StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5); StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5);
// Act // Act
DataSet result = sut.computeTDTargets(transitions); DataSet result = sut.computeTDTargets(transitions);
@ -45,10 +62,6 @@ public class StandardDQNTest {
public void when_isNotTerminal_expect_rewardPlusEstimatedQValue() { public void when_isNotTerminal_expect_rewardPlusEstimatedQValue() {
// Assemble // Assemble
MockDQN qNetwork = new MockDQN();
MockDQN targetQNetwork = new MockDQN();
MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() { List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
{ {
add(buildTransition(buildObservation(new double[]{1.1, 2.2}), add(buildTransition(buildObservation(new double[]{1.1, 2.2}),
@ -56,7 +69,7 @@ public class StandardDQNTest {
} }
}; };
StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5); StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5);
// Act // Act
DataSet result = sut.computeTDTargets(transitions); DataSet result = sut.computeTDTargets(transitions);
@ -71,10 +84,6 @@ public class StandardDQNTest {
public void when_batchHasMoreThanOne_expect_everySampleEvaluated() { public void when_batchHasMoreThanOne_expect_everySampleEvaluated() {
// Assemble // Assemble
MockDQN qNetwork = new MockDQN();
MockDQN targetQNetwork = new MockDQN();
MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() { List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
{ {
add(buildTransition(buildObservation(new double[]{1.1, 2.2}), add(buildTransition(buildObservation(new double[]{1.1, 2.2}),
@ -86,7 +95,7 @@ public class StandardDQNTest {
} }
}; };
StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5); StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5);
// Act // Act
DataSet result = sut.computeTDTargets(transitions); DataSet result = sut.computeTDTargets(transitions);

View File

@ -1,26 +0,0 @@
package org.deeplearning4j.rl4j.learning.sync.support;
import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
public class MockTargetQNetworkSource implements TargetQNetworkSource {
private final IDQN qNetwork;
private final IDQN targetQNetwork;
public MockTargetQNetworkSource(IDQN qNetwork, IDQN targetQNetwork) {
this.qNetwork = qNetwork;
this.targetQNetwork = targetQNetwork;
}
@Override
public IDQN getTargetQNetwork() {
return targetQNetwork;
}
@Override
public IDQN getQNetwork() {
return qNetwork;
}
}