commit 43fd64358c
@@ -57,10 +57,8 @@ import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.util.ModelSerializer;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.*;
import org.junit.rules.TemporaryFolder;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.impl.ActivationIdentity;
import org.nd4j.linalg.api.buffer.DataType;

@@ -82,6 +80,7 @@ import org.nd4j.common.resources.Resources;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.util.*;

@@ -91,6 +90,9 @@ import static org.junit.Assert.*;
@Slf4j
public class TestComputationGraphNetwork extends BaseDL4JTest {

    @Rule
    public TemporaryFolder testDir = new TemporaryFolder();

    private static ComputationGraphConfiguration getIrisGraphConfiguration() {
        return new NeuralNetConfiguration.Builder().seed(12345)
                .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder()

@@ -2177,4 +2179,40 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
        INDArray label = Nd4j.createFromArray(1, 0).reshape(1, 2);
        cg.fit(new DataSet(in, label));
    }

    @Test
    public void testMergeNchw() throws Exception {
        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                .convolutionMode(ConvolutionMode.Same)
                .graphBuilder()
                .addInputs("in")
                .layer("l0", new ConvolutionLayer.Builder()
                        .nOut(16)
                        .kernelSize(2,2).stride(1,1)
                        .build(), "in")
                .layer("l1", new ConvolutionLayer.Builder()
                        .nOut(8)
                        .kernelSize(2,2).stride(1,1)
                        .build(), "in")
                .addVertex("merge", new MergeVertex(), "l0", "l1")
                .layer("out", new CnnLossLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build(), "merge")
                .setOutputs("out")
                .setInputTypes(InputType.convolutional(32, 32, 3, CNN2DFormat.NHWC))
                .build();

        ComputationGraph cg = new ComputationGraph(conf);
        cg.init();

        INDArray[] in = new INDArray[]{Nd4j.rand(DataType.FLOAT, 1, 32, 32, 3)};
        INDArray out = cg.outputSingle(in);

        File dir = testDir.newFolder();
        File f = new File(dir, "net.zip");
        cg.save(f);

        ComputationGraph c2 = ComputationGraph.load(f, true);
        INDArray out2 = c2.outputSingle(in);

        assertEquals(out, out2);
    }
}
@@ -66,8 +66,8 @@ import org.nd4j.shade.jackson.annotation.JsonProperty;
 * @author Alex Black
 */
@Data
@JsonIgnoreProperties({"mask", "helper", "helperCountFail"})
@EqualsAndHashCode(exclude = {"mask", "helper", "helperCountFail"})
@JsonIgnoreProperties({"mask", "helper", "helperCountFail", "initializedHelper"})
@EqualsAndHashCode(exclude = {"mask", "helper", "helperCountFail", "initializedHelper"})
@Slf4j
public class Dropout implements IDropout {
@@ -17,6 +17,7 @@
package org.deeplearning4j.nn.conf.graph;

import lombok.Data;
import lombok.val;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.RNNFormat;

@@ -38,6 +39,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
 * -> [numExamples,depth1 + depth2,width,height]}<br>
 * @author Alex Black
 */
@Data
public class MergeVertex extends GraphVertex {

    protected int mergeAxis = 1; //default value for backward compatibility (deserialization of old version JSON) - NCHW and NCW format
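A note on the `mergeAxis` default above: NCHW/NCW layouts put channels on axis 1, while NHWC/NWC put them on the last axis, so a format-aware MergeVertex has to concatenate along a different axis per layout (which is what the `testMergeNchw` case earlier exercises with NHWC inputs). A minimal sketch of that axis choice and the resulting merged shape, using an illustrative `Format` enum rather than the DL4J API:

```cpp
#include <array>
#include <cstdio>

enum class Format { NCHW, NHWC };

// Channel axis of a rank-4 activation tensor under each layout.
static int channelAxis(Format f) { return f == Format::NCHW ? 1 : 3; }

// Shape of concatenating two activations along the channel axis.
static std::array<long, 4> mergedShape(std::array<long, 4> a, const std::array<long, 4>& b, Format f) {
    const int axis = channelAxis(f);
    a[axis] += b[axis]; // all other dimensions must already match
    return a;
}

int main() {
    // Two conv outputs with 16 and 8 channels, as in testMergeNchw above (NHWC).
    std::array<long, 4> l0{1, 32, 32, 16}, l1{1, 32, 32, 8};
    auto m = mergedShape(l0, l1, Format::NHWC);
    std::printf("%ld x %ld x %ld x %ld\n", m[0], m[1], m[2], m[3]); // 1 x 32 x 32 x 24
}
```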
@@ -141,7 +141,7 @@ public class GradientSharingTrainingTest extends BaseSparkTest {
        SparkComputationGraph sparkNet = new SparkComputationGraph(sc, conf, tm);
        sparkNet.setCollectTrainingStats(tm.getIsCollectTrainingStats());

        System.out.println(Arrays.toString(sparkNet.getNetwork().params().get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
//        System.out.println(Arrays.toString(sparkNet.getNetwork().params().get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
        File f = testDir.newFolder();
        DataSetIterator iter = new MnistDataSetIterator(16, true, 12345);
        int count = 0;

@@ -208,10 +208,10 @@ public class GradientSharingTrainingTest extends BaseSparkTest {
        }

        INDArray paramsAfter = after.params();
        System.out.println(Arrays.toString(paramsBefore.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
        System.out.println(Arrays.toString(paramsAfter.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
        System.out.println(Arrays.toString(
                Transforms.abs(paramsAfter.sub(paramsBefore)).get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
//        System.out.println(Arrays.toString(paramsBefore.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
//        System.out.println(Arrays.toString(paramsAfter.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
//        System.out.println(Arrays.toString(
//                Transforms.abs(paramsAfter.sub(paramsBefore)).get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
        assertNotEquals(paramsBefore, paramsAfter);

@@ -235,7 +235,7 @@ public class GradientSharingTrainingTest extends BaseSparkTest {
    }

    @Test
    @Test @Ignore //AB https://github.com/eclipse/deeplearning4j/issues/8985
    public void differentNetsTrainingTest() throws Exception {
        int batch = 3;
@@ -131,6 +131,23 @@ if(NOT SD_CUDA)
    endif()
endif()

#arm-compute entry
if(${HELPERS_armcompute})
    find_package(ARMCOMPUTE REQUIRED)

    if(ARMCOMPUTE_FOUND)
        message("Found ARMCOMPUTE: ${ARMCOMPUTE_LIBRARIES}")
        set(HAVE_ARMCOMPUTE 1)
        # Add preprocessor definition for ARM Compute NEON
        add_definitions(-DARMCOMPUTENEON_ENABLED)
        #build our library with neon support
        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mfpu=neon")
        include_directories(${ARMCOMPUTE_INCLUDE})
        message("----${ARMCOMPUTE_INCLUDE}---")
    endif()

endif()


# new mkl-dnn entry
if (${HELPERS_mkldnn})
@@ -146,6 +146,10 @@ if (HAVE_MKLDNN)
    file(GLOB_RECURSE CUSTOMOPS_MKLDNN_SOURCES false ../include/ops/declarable/platform/mkldnn/*.cpp ../include/ops/declarable/platform/mkldnn/mkldnnUtils.h)
endif()

if(HAVE_ARMCOMPUTE)
    file(GLOB_RECURSE CUSTOMOPS_ARMCOMPUTE_SOURCES false ../include/ops/declarable/platform/armcompute/*.cpp ../include/ops/declarable/platform/armcompute/*.h)
endif()

if(SD_CUDA)
    message("Build cublas")
    find_package(CUDA)

@@ -243,7 +247,7 @@ if(SD_CUDA)
                ${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES}
                ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
                ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES} ${CUSTOMOPS_CUDNN_SOURCES} ${CUSTOMOPS_MKLDNN_SOURCES}
                ${CUSTOMOPS_GENERIC_SOURCES}
                ${CUSTOMOPS_ARMCOMPUTE_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES}
        )

    if (WIN32)

@@ -351,8 +355,8 @@ elseif(SD_CPU)
    add_definitions(-D__CPUBLAS__=true)
    add_library(samediff_obj OBJECT ${LEGACY_SOURCES}
            ${LOOPS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
            ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_MKLDNN_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES}
            ${OPS_SOURCES} ${PERF_SOURCES})
            ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_MKLDNN_SOURCES}
            ${CUSTOMOPS_ARMCOMPUTE_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES})
    if(IOS)
        add_library(${SD_LIBRARY_NAME} STATIC $<TARGET_OBJECTS:samediff_obj>)
    else()

@@ -378,12 +382,12 @@ elseif(SD_CPU)
    if (NOT BLAS_LIBRARIES)
        set(BLAS_LIBRARIES "")
    endif()
    target_link_libraries(${SD_LIBRARY_NAME} ${MKLDNN} ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${BLAS_LIBRARIES} ${CPU_FEATURES})
    target_link_libraries(${SD_LIBRARY_NAME} ${MKLDNN} ${MKLDNN_LIBRARIES} ${ARMCOMPUTE_LIBRARIES} ${OPENBLAS_LIBRARIES} ${BLAS_LIBRARIES} ${CPU_FEATURES})

    if ("${SD_ALL_OPS}" AND "${SD_BUILD_MINIFIER}")
        message(STATUS "Building minifier...")
        add_executable(minifier ../minifier/minifier.cpp ../minifier/graphopt.cpp)
        target_link_libraries(minifier samediff_obj ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES} ${CPU_FEATURES})
        target_link_libraries(minifier samediff_obj ${MKLDNN_LIBRARIES} ${ARMCOMPUTE_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES} ${CPU_FEATURES})
    endif()

    if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND "${CMAKE_CXX_COMPILER_VERSION}" VERSION_LESS 4.9)
@@ -0,0 +1,74 @@
################################################################################
# Copyright (c) 2020 Konduit K.K.
#
# This program and the accompanying materials are made available under the
# terms of the Apache License, Version 2.0 which is available at
# https://www.apache.org/licenses/LICENSE-2.0.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
# SPDX-License-Identifier: Apache-2.0
################################################################################


### Find ARM COMPUTE LIBRARY STATIC libraries

SET (COMPUTE_INCLUDE_DIRS
    /usr/include
    ${ARMCOMPUTE_ROOT}
    ${ARMCOMPUTE_ROOT}/include
    ${ARMCOMPUTE_ROOT}/applications
    ${ARMCOMPUTE_ROOT}/applications/arm_compute
)

SET (COMPUTE_LIB_DIRS
    /lib
    /usr/lib
    ${ARMCOMPUTE_ROOT}
    ${ARMCOMPUTE_ROOT}/lib
    ${ARMCOMPUTE_ROOT}/build
)

find_path(ARMCOMPUTE_INCLUDE arm_compute/core/CL/ICLKernel.h
        PATHS ${COMPUTE_INCLUDE_DIRS}
        NO_DEFAULT_PATH NO_CMAKE_FIND_ROOT_PATH)

find_path(ARMCOMPUTE_INCLUDE arm_compute/core/CL/ICLKernel.h)

find_path(HALF_INCLUDE half/half.hpp)
find_path(HALF_INCLUDE half/half.hpp
        PATHS ${ARMCOMPUTE_ROOT}/include
        NO_DEFAULT_PATH NO_CMAKE_FIND_ROOT_PATH)
include_directories(SYSTEM ${HALF_INCLUDE})

# Find the Arm Compute libraries if not already specified
if (NOT DEFINED ARMCOMPUTE_LIBRARIES)

    find_library(ARMCOMPUTE_LIBRARY NAMES arm_compute-static
            PATHS ${COMPUTE_LIB_DIRS}
            PATH_SUFFIXES "Release"
            NO_DEFAULT_PATH NO_CMAKE_FIND_ROOT_PATH)

    find_library(ARMCOMPUTE_CORE_LIBRARY NAMES arm_compute_core-static
            PATHS ${COMPUTE_LIB_DIRS}
            PATH_SUFFIXES "Release"
            NO_DEFAULT_PATH NO_CMAKE_FIND_ROOT_PATH)

    # In case it wasn't there, try a default search (will work in cases where
    # the library has been installed into a standard location)
    find_library(ARMCOMPUTE_LIBRARY NAMES arm_compute-static)
    find_library(ARMCOMPUTE_CORE_LIBRARY NAMES arm_compute_core-static)

    set(ARMCOMPUTE_LIBRARIES ${ARMCOMPUTE_LIBRARY} ${ARMCOMPUTE_CORE_LIBRARY})
endif()


INCLUDE(FindPackageHandleStandardArgs)

FIND_PACKAGE_HANDLE_STANDARD_ARGS(ARMCOMPUTE REQUIRED_VARS ARMCOMPUTE_INCLUDE ARMCOMPUTE_LIBRARIES)
@@ -69,7 +69,7 @@ namespace cnpy {
        }
    };

    struct ND4J_EXPORT npz_t : public std::unordered_map<std::string, NpyArray> {
    struct ND4J_EXPORT npz_t : public std::map<std::string, NpyArray> {
        void destruct() {
            npz_t::iterator it = this->begin();
            for(; it != this->end(); ++it) (*it).second.destruct();
@@ -3,6 +3,8 @@

#cmakedefine HAVE_MKLDNN

#cmakedefine HAVE_ARMCOMPUTE

#cmakedefine MKLDNN_PATH "@MKLDNN_PATH@"

#cmakedefine HAVE_OPENBLAS
@@ -45,18 +45,18 @@ namespace sd {
        DECLARE_TYPES(max_pool_with_argmax) {
            getOpDescriptor()
                    ->setAllowedInputTypes(sd::DataType::ANY)
                    ->setAllowedOutputTypes(0, DataType::INHERIT)
                    ->setAllowedOutputTypes(1, {ALL_INTS});
                    ->setAllowedOutputTypes(0, {ALL_FLOATS, ALL_INTS})
                    ->setAllowedOutputTypes(1, {ALL_INDICES});
        }

        DECLARE_SHAPE_FN(max_pool_with_argmax) {
            auto in = inputShape->at(0);
            auto dtype = block.numD() ? D_ARG(0) : sd::DataType::INT64;
            auto valuesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in));
            auto indicesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in, dtype));

            auto in = inputShape->at(0);
            auto valuesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in));
            auto indicesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in, DataType::INT64));

            return SHAPELIST(valuesShape, indicesShape);
            return SHAPELIST(valuesShape, indicesShape);
        }
    }
}
@@ -215,7 +215,9 @@ namespace helpers {
            auto maxValue = T(0); //sd::math::nd4j_abs(compoundBuffer[xInitialIndex]);
            auto result = -1;
            //auto loop = PRAGMA_THREADS_FOR {
            auto start = column, stop = rowNum, increment = 1;
            auto start = column;
            auto stop = rowNum;
            auto increment = 1;
            for (auto rowCounter = start; rowCounter < stop; rowCounter++) {
                Nd4jLong xPos[] = {rowCounter, column};
                auto xIndex = shape::getOffset(compoundShape, xPos, 0);
@@ -73,7 +73,7 @@ namespace helpers {
    }

    void maxPoolingFunctor(sd::LaunchContext * context, sd::graph::Context& block, NDArray* input, NDArray* values, std::vector<int> const& params, NDArray* indices) {
        BUILD_SINGLE_SELECTOR(input->dataType(), maxPoolingFunctor_, (block, input, values, params, indices), FLOAT_TYPES);
        BUILD_SINGLE_SELECTOR(input->dataType(), maxPoolingFunctor_, (block, input, values, params, indices), LIBND4J_TYPES);
    }

}
@@ -16,7 +16,8 @@

//
// @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018
//
// implementation is based on following article:
// "MergeShuffle: A Very Fast, Parallel Random Permutation Algorithm", https://arxiv.org/abs/1508.03167


@@ -31,96 +32,167 @@ namespace ops {
namespace helpers {

//////////////////////////////////////////////////////////////////////////
// Fisher-Yates shuffle
template <typename T>
void randomShuffle_(NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {
static void fisherYates(sd::graph::RandomGenerator& rng, T* buff, const Nd4jLong& len, const Nd4jLong& ews, Nd4jLong ind) {

    for(Nd4jLong i = len-1; i > 0; --i) {
        const Nd4jLong j = rng.relativeLong(ind++) % (i + 1);
        if(i != j)
            math::nd4j_swap<T>(buff[i*ews], buff[j*ews]);
    }
}

//////////////////////////////////////////////////////////////////////////
// mutual shuffle of two adjacent already shuffled ranges with length len1 and (totLen - len1) correspondingly
template <typename T>
static void mergeShuffle(sd::graph::RandomGenerator& rng, T* buff, const Nd4jLong& len1, const Nd4jLong& totLen, const Nd4jLong& ews, Nd4jLong ind) {

    Nd4jLong beg = 0;    // beginning
    Nd4jLong mid = len1; // middle

    while (true) {
        if(rng.relativeLong(ind++) % 2) {
            if(mid == totLen)
                break;
            math::nd4j_swap<T>(buff[ews * beg], buff[ews * mid++]);
        } else {
            if(beg == mid)
                break;
        }
        ++beg;
    }

    // fisherYates
    while (beg < totLen) {
        const Nd4jLong j = rng.relativeLong(ind++) % (beg + 1);
        if(beg != j)
            math::nd4j_swap<T>(buff[ews * beg], buff[ews * j]);
        ++beg;
    }
}

//////////////////////////////////////////////////////////////////////////
template <typename T>
static void randomShuffle_(NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {

    // check edge cases first
    int temp;
    const int firstDim = input.sizeAt(0);
    int temp;

    if(input.lengthOf() == 1 || firstDim == 1) {

        if(!isInplace)
            output.assign(input);
    }
    else if (input.isVector() || shape::isLikeVector(input.shapeInfo(), temp)) {
    else if (shape::isCommonVector(input.shapeInfo(), temp)) {

        // apply Fisher-Yates shuffle
        if(isInplace) {
            //PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance().tadThreshold())
            for(int i = firstDim-1; i > 0; --i) {
                int r = rng.relativeInt(i) % i;
                if(i == r)
                    continue;
                T t0 = input.t<T>(i);
                T t1 = input.t<T>(r);
                //math::nd4j_swap<T>(input(i), input(r));
                input.r<T>(i) = t1;
                input.r<T>(r) = t0;
            }
        NDArray* arr = &input;

        if (!isInplace) {
            output.assign(input);
            arr = &output;
        }
        else {
            std::vector<int> indices(firstDim);
            std::iota(indices.begin(), indices.end(), 0);
            output.p<T>(Nd4jLong(0), input.e<T>(0));

            // FIXME: parallelism!!
            for(int i = firstDim-1; i > 0; --i) {
                int r = rng.relativeInt(i) % i;
                output.r<T>(i) = input.t<T>(indices[r]);
                if(i == r)
                    continue;
        const Nd4jLong ews = arr->ews();

                output.r<T>(r) = input.t<T>(indices[i]);
                math::nd4j_swap<int>(indices[i], indices[r]);
        const Nd4jLong len = arr->lengthOf();
        const Nd4jLong threshold = 1<<22; // this number was deduced from diagram in article

        int power = 0;
        while ((len >> power) > threshold)
            ++power;

        const Nd4jLong numChunks = 1 << power;

        auto funcFisherYates = PRAGMA_THREADS_FOR {

            for (auto i = start; i < stop; ++i) {

                Nd4jLong offset = (len * i) >> power;
                Nd4jLong currLen = ((len * (i + 1)) >> power) - offset;
                fisherYates<T>(rng, arr->bufferAsT<T>() + offset*ews, currLen, ews, offset);
            }
            rng.rewindH(firstDim-1);
        }
        };

        auto funcMerge = PRAGMA_THREADS_FOR {

            for (int64_t i = start, k = 1; i < stop; i += increment, ++k) {
                Nd4jLong offset = len * i >> power;
                Nd4jLong len1 = (len * (i + increment/2) >> power) - offset;
                Nd4jLong totLen = (len * (i + increment) >> power) - offset;
                mergeShuffle<T>(rng, arr->bufferAsT<T>() + offset*ews, len1, totLen, ews, len * k + offset);
            }
        };

        samediff::Threads::parallel_for(funcFisherYates, 0, numChunks);

        for (int j = 1; j < numChunks; j += j)
            samediff::Threads::parallel_for(funcMerge, 0, numChunks, 2*j);

        // #pragma omp parallel for
        // for (uint i = 0; i < numChunks; ++i) {
        //
        //     Nd4jLong offset = (len * i) >> power;
        //     Nd4jLong currLen = ((len * (i + 1)) >> power) - offset;
        //     fisherYates<T>(rng, arr->bufferAsT<T>() + offset*ews, currLen, ews, offset);
        // }

        // for (uint j = 1; j < numChunks; j += j) {
        //     #pragma omp parallel for
        //     for (auto i = 0; i < numChunks; i += 2*j) {
        //         Nd4jLong offset = len * i >> power;
        //         Nd4jLong len1 = (len * (i + j) >> power) - offset;
        //         Nd4jLong totLen = (len * (i + 2*j) >> power) - offset;
        //         mergeShuffle(rng, arr->bufferAsT<T>() + offset*ews, len1, totLen, ews, len * j + offset);
        //     }
        // }

        rng.rewindH((len + 1) * power);
    }
    else {

        // evaluate sub-arrays list of input array through all dimensions excluding first one
        std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input.rankOf(), {0});
        auto subArrsListIn = input.allTensorsAlongDimension(dimensions);
        auto dimsToExclude = ShapeUtils::evalDimsToExclude(input.rankOf(), {0});

        // apply Fisher-Yates shuffle
        if(isInplace) {
            //PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance().elementwiseThreshold())
            for(int i = firstDim - 1; i > 0; --i) {
                int r = rng.relativeInt(i) % i;

                if(i == r)
                    continue;
                subArrsListIn.at(i)->swapUnsafe(*subArrsListIn.at(r));
            auto subArrsList = input.allTensorsAlongDimension(dimsToExclude);

            // Fisher-Yates shuffle
            for(int i = firstDim - 1; i > 0; --i) {
                const int j = rng.relativeInt(i) % (i + 1);
                if(i != j)
                    subArrsList.at(i)->swapUnsafe(*subArrsList.at(j));
            }
        }
        else {
            // evaluate sub-arrays list of output array through all dimensions excluding first one
            auto subArrsListOut = output.allTensorsAlongDimension(dimensions);

            auto subArrsListIn = input.allTensorsAlongDimension(dimsToExclude);
            auto subArrsListOut = output.allTensorsAlongDimension(dimsToExclude);

            std::vector<int> indices(firstDim);
            std::iota(indices.begin(), indices.end(), 0);
            bool isZeroShuffled = false;
            //PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance().tadThreshold())
            for(int i = firstDim - 1; i > 0; --i) {
                int r = rng.relativeInt(i) % i;
                subArrsListOut.at(i)->assign(subArrsListIn.at(indices[r]));
                if(r == 0)
                    isZeroShuffled = true;
                if(i == r)
                    continue;
                subArrsListOut.at(r)->assign(subArrsListIn.at(indices[i]));
                math::nd4j_swap<int>(indices[i], indices[r]);
            }
            if(!isZeroShuffled)
                subArrsListOut.at(0)->assign(subArrsListIn.at(0));
            std::iota(indices.begin(), indices.end(), 0); // 0,1,2,3, ... firstDim-1

            // shuffle indices
            fisherYates<int>(rng, indices.data(), firstDim, 1, 0);

            auto func = PRAGMA_THREADS_FOR {

                for (auto i = start; i < stop; ++i)
                    subArrsListOut.at(i)->assign(subArrsListIn.at(indices[i]));
            };

            samediff::Threads::parallel_for(func, 0, firstDim);
        }

        rng.rewindH(firstDim-1);
    }

}

void randomShuffle(sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {
    BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (input, output, rng, isInplace), LIBND4J_TYPES);
}
void randomShuffle(sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {
    BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (input, output, rng, isInplace), LIBND4J_TYPES);
}

}
}
}
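The rewritten helper above follows the cited MergeShuffle paper: partition the buffer into 2^power chunks, run an independent Fisher-Yates pass per chunk (these are the parallelizable units), then pairwise merge-shuffle adjacent already-shuffled runs for power rounds. A self-contained sketch of the same scheme, assuming std::mt19937_64 in place of sd::graph::RandomGenerator and plain loops in place of PRAGMA_THREADS_FOR:

```cpp
#include <cstdint>
#include <iostream>
#include <random>
#include <utility>
#include <vector>

// Merge-shuffle two adjacent already-shuffled runs [0, len1) and [len1, totLen).
template <typename T>
void mergeShuffle(std::mt19937_64& rng, T* buf, int64_t len1, int64_t totLen) {
    int64_t beg = 0, mid = len1;
    while (true) {
        if (rng() & 1) {                 // coin flip: take from the right run
            if (mid == totLen) break;
            std::swap(buf[beg], buf[mid++]);
        } else if (beg == mid) {         // left run exhausted
            break;
        }
        ++beg;
    }
    // Finish the tail with plain Fisher-Yates, as in the hunk above.
    for (; beg < totLen; ++beg) {
        std::uniform_int_distribution<int64_t> d(0, beg);
        std::swap(buf[beg], buf[d(rng)]);
    }
}

template <typename T>
void mergeShuffleAll(std::vector<T>& v, int power, std::mt19937_64& rng) {
    const int64_t len = (int64_t)v.size(), numChunks = int64_t(1) << power;
    // Chunk i covers [(len*i)>>power, (len*(i+1))>>power): a balanced partition.
    for (int64_t i = 0; i < numChunks; ++i) {            // parallelizable
        const int64_t off = (len * i) >> power;
        const int64_t end = (len * (i + 1)) >> power;
        for (int64_t k = end - off - 1; k > 0; --k) {    // Fisher-Yates per chunk
            std::uniform_int_distribution<int64_t> d(0, k);
            std::swap(v[off + k], v[off + d(rng)]);
        }
    }
    for (int64_t j = 1; j < numChunks; j += j)           // log2(numChunks) merge rounds
        for (int64_t i = 0; i < numChunks; i += 2 * j) { // parallelizable per round
            const int64_t off  = (len * i) >> power;
            const int64_t len1 = ((len * (i + j)) >> power) - off;
            const int64_t tot  = ((len * (i + 2 * j)) >> power) - off;
            mergeShuffle(rng, v.data() + off, len1, tot);
        }
}

int main() {
    std::mt19937_64 rng(12345);
    std::vector<int> v(32);
    for (int i = 0; i < 32; ++i) v[i] = i;
    mergeShuffleAll(v, /*power=*/2, rng);
    for (int x : v) std::cout << x << ' ';
    std::cout << '\n';
}
```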
@@ -53,7 +53,7 @@ __global__ static void concatCuda(void* pVx, void* pxShapeInfo, void* vz, const

    int coords[MAX_RANK];

    for (uint64_t i = tid; i < zLen; i += totalThreads) {
    for (Nd4jLong i = tid; i < zLen; i += totalThreads) {
        shape::index2coords(i, zShapeInfo, coords);

        const auto zOffset = shape::getOffset(zShapeInfo, coords);

@@ -162,9 +162,9 @@ void concat(sd::LaunchContext * context, const std::vector<const NDArray*>& inArrs,
    // }
    // else { // general (slower) case

    const int threadsPerBlock = 256;
    const int blocksPerGrid = 512;
    const int sharedMem = 512;
    const int threadsPerBlock = MAX_NUM_THREADS / 2;
    const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
    const int sharedMem = 256;

    // prepare arrays of pointers on buffers and shapes
    std::vector<const void*> hInBuffers(numOfInArrs);
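The new launch configuration above is the usual ceil-division sizing: roughly one thread per output element instead of a fixed 512-block grid, so the kernel's grid-stride loop seldom needs a second pass. The rounding in isolation (plain C++, illustrative numbers):

```cpp
#include <cstdio>

int main() {
    const long long len = 100000;       // stands in for output.lengthOf()
    const int threadsPerBlock = 512;    // e.g. MAX_NUM_THREADS / 2
    const long long blocksPerGrid = (len + threadsPerBlock - 1) / threadsPerBlock;
    // 100000 / 512 = 195.3..., rounded up to 196 blocks (100352 threads);
    // the kernel's loop guard (i < zLen) discards the excess threads.
    std::printf("%lld blocks\n", blocksPerGrid);
}
```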
@@ -88,7 +88,7 @@ namespace helpers {
    void maxPoolingFunctor(sd::LaunchContext * context, sd::graph::Context& block, NDArray* input, NDArray* values, std::vector<int> const& params, NDArray* indices) {
        NDArray::prepareSpecialUse({values, indices}, {input});
        auto yType = indices == nullptr ? sd::DataType::INT64 : indices->dataType();
        BUILD_DOUBLE_SELECTOR(input->dataType(), yType, maxPoolingFunctor_, (block, input, values, params, indices), FLOAT_TYPES, INDEXING_TYPES);
        BUILD_DOUBLE_SELECTOR(input->dataType(), yType, maxPoolingFunctor_, (block, input, values, params, indices), LIBND4J_TYPES, INDEXING_TYPES);
        NDArray::registerSpecialUse({values, indices}, {input});
    }
@@ -0,0 +1,228 @@
/*******************************************************************************
 * Copyright (c) 2020 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author Yurii Shyrma (iuriish@yahoo.com)
// implemented algorithm is GPU adaptation of algorithm described in following article:
// "MergeShuffle: A Very Fast, Parallel Random Permutation Algorithm", https://arxiv.org/abs/1508.03167
//

#include<ops/declarable/helpers/transforms.h>
#include <array/ResultSet.h>
#include <numeric>
#include <execution/Threads.h>
#include <helpers/ShapeUtils.h>
#include <helpers/PointersManager.h>

namespace sd {
namespace ops {
namespace helpers {

//////////////////////////////////////////////////////////////////////////
template <typename T>
static __global__ void fisherYatesCuda(sd::graph::RandomGenerator* rng, void* vx, const Nd4jLong ews, const Nd4jLong len, const int power) {

    T* x = reinterpret_cast<T*>(vx);

    __shared__ T* shmem, temp;
    __shared__ Nd4jLong ind, blockOffset, lenPerBlock;

    if (threadIdx.x == 0) {
        extern __shared__ unsigned char sharedMemory[];
        shmem = reinterpret_cast<T*>(sharedMemory);

        blockOffset = (len * blockIdx.x) >> power;
        lenPerBlock = ((len * (blockIdx.x + 1)) >> power) - blockOffset;
        ind = blockOffset;
    }
    __syncthreads();

    // copy from global memory to shared memory
    if(threadIdx.x < lenPerBlock)
        shmem[threadIdx.x] = x[(blockOffset + threadIdx.x) * ews];
    __syncthreads();

    // *** apply Fisher-Yates shuffle to lenPerBlock number of elements
    if (threadIdx.x == 0) {
        for(Nd4jLong i = lenPerBlock - 1; i > 0; --i) {
            const Nd4jLong j = rng->relativeLong(ind++) % (i + 1);
            if(i != j) {
                temp = shmem[i];
                shmem[i] = shmem[j];
                shmem[j] = temp;
            }
        }
    }
    __syncthreads();

    // copy from shared memory to global memory
    if(threadIdx.x < lenPerBlock)
        x[(blockOffset + threadIdx.x) * ews] = shmem[threadIdx.x];
}

template <typename T>
static __global__ void mergeShuffleCuda(sd::graph::RandomGenerator* rng, void* vx, const Nd4jLong ews, const Nd4jLong len, const int power, const Nd4jLong iterNum) {

    T* x = reinterpret_cast<T*>(vx);

    __shared__ Nd4jLong ind, blockOffset, factor, beg, mid, totLen, iterExp;

    // *** apply mergeShuffle algorithm
    if(threadIdx.x == 0) {

        factor = blockIdx.x << iterNum;
        iterExp = 1 << (iterNum - 1);
        blockOffset = (len * factor) >> power;
        mid = ((len * (factor + iterExp)) >> power) - blockOffset;      // middle
        totLen = ((len * (factor + 2*iterExp)) >> power) - blockOffset;
        ind = iterNum * len + blockOffset;
        beg = 0;                                                        // beginning

        // printf("m %lld, blockIdx.x %lld, factor %lld, blockOffset %lld, mid %lld, totLen %lld \n", m,k,factor,blockOffset,mid,totLen);

        while (true) {
            if(rng->relativeLong(ind++) % 2) {
                if(mid == totLen)
                    break;
                math::nd4j_swap<T>(x[(blockOffset + beg) * ews], x[(blockOffset + mid++) * ews]);
            } else {
                if(beg == mid)
                    break;
            }
            ++beg;
        }

        // Fisher-Yates
        while (beg < totLen) {
            const Nd4jLong e = rng->relativeLong(ind++) % (beg + 1);
            if(beg != e)
                math::nd4j_swap<T>(x[(blockOffset + beg) * ews], x[(blockOffset + e) * ews]);
            ++beg;
        }
    }
}


//////////////////////////////////////////////////////////////////////////
// Fisher-Yates shuffle
template <typename T>
static void fisherYates(sd::graph::RandomGenerator& rng, T* buff, const Nd4jLong& len, const Nd4jLong& ews, Nd4jLong ind) {

    for(Nd4jLong i = len-1; i > 0; --i) {
        const Nd4jLong j = rng.relativeLong(ind++) % (i + 1);
        if(i != j)
            math::nd4j_swap<T>(buff[i*ews], buff[j*ews]);
    }
}

//////////////////////////////////////////////////////////////////////////
template <typename T>
static void randomShuffle_(sd::LaunchContext* context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {

    const int firstDim = input.sizeAt(0);
    int temp;

    if(input.lengthOf() == 1 || firstDim == 1) {

        if(!isInplace)
            output.assign(input);
    }
    else if (shape::isCommonVector(input.shapeInfo(), temp)) {

        NDArray* arr = &input;

        if (!isInplace) {
            output.assign(input);
            arr = &output;
        }

        const Nd4jLong len = arr->lengthOf();

        const int threadsPerBlock = MAX_NUM_THREADS;

        int power = 0;
        while ((len >> power) > threadsPerBlock)
            ++power;

        const int blocksPerGrid = 1 << power;
        const int sharedMem = threadsPerBlock * input.sizeOfT() + 256;

        PointersManager manager(context, "NDArray::randomShuffle cuda");

        sd::graph::RandomGenerator* pRng = reinterpret_cast<sd::graph::RandomGenerator*>(manager.replicatePointer(&rng, sizeof(sd::graph::RandomGenerator)));

        NDArray::prepareSpecialUse({arr}, {arr});
        fisherYatesCuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *context->getCudaStream()>>>(pRng, arr->specialBuffer(), arr->ews(), len, power);
        for (Nd4jLong j = 1, i = 1; j < blocksPerGrid; j += j, ++i)
            mergeShuffleCuda<T><<<blocksPerGrid/(2*j), threadsPerBlock, 256, *context->getCudaStream()>>>(pRng, arr->specialBuffer(), arr->ews(), len, power, i);
        NDArray::registerSpecialUse({arr}, {arr});

        manager.synchronize();

        rng.rewindH((len + 1) * power);
    }
    else {

        auto dimsToExclude = ShapeUtils::evalDimsToExclude(input.rankOf(), {0});

        if(isInplace) {

            auto subArrsList = input.allTensorsAlongDimension(dimsToExclude);

            // Fisher-Yates shuffle
            for(int i = firstDim - 1; i > 0; --i) {
                const int j = rng.relativeInt(i) % (i + 1);
                if(i != j)
                    subArrsList.at(i)->swapUnsafe(*subArrsList.at(j));
            }
        }
        else {

            auto subArrsListIn = input.allTensorsAlongDimension(dimsToExclude);
            auto subArrsListOut = output.allTensorsAlongDimension(dimsToExclude);

            std::vector<int> indices(firstDim);
            std::iota(indices.begin(), indices.end(), 0); // 0,1,2,3, ... firstDim-1

            // shuffle indices
            fisherYates<int>(rng, indices.data(), firstDim, 1, 0);

            auto func = PRAGMA_THREADS_FOR {

                for (auto i = start; i < stop; ++i)
                    subArrsListOut.at(i)->assign(subArrsListIn.at(indices[i]));
            };

            samediff::Threads::parallel_for(func, 0, firstDim);
        }

        rng.rewindH(firstDim-1);
    }
}

/////////////////////////////////////////////////////////////////////////
void randomShuffle(sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {
    BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (context, input, output, rng, isInplace), LIBND4J_TYPES);
}

// BUILD_SINGLE_TEMPLATE(template void randomShuffle_, (sd::LaunchContext* context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace), LIBND4J_TYPES);


}
}
}
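Both kernels above rely on the same fixed-point chunking as the CPU path: chunk i of 2^power starts at (len * i) >> power, so per-block offsets can be computed independently without a prefix sum. A small standalone check that these offsets tile [0, len) exactly, even when len is not a multiple of the chunk count:

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t len = 1000003;                 // deliberately not a power of two
    const int power = 5;
    const int64_t numChunks = int64_t(1) << power;

    int64_t covered = 0;
    for (int64_t i = 0; i < numChunks; ++i) {
        const int64_t offset  = (len * i) >> power;
        const int64_t currLen = ((len * (i + 1)) >> power) - offset;
        assert(offset == covered);               // chunks are contiguous...
        covered += currLen;
    }
    assert(covered == len);                      // ...and cover the whole buffer
    std::printf("%lld chunks tile %lld elements\n",
                (long long)numChunks, (long long)covered);
}
```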
@@ -300,129 +300,6 @@ void tileBP(sd::LaunchContext * context, const NDArray& gradO /*input*/, NDArray
    manager.synchronize();
}

template <typename T>
static __global__ void swapShuffleKernel(T* input, Nd4jLong const* shape, Nd4jLong firstDim, sd::graph::RandomGenerator* rng) {
    auto tid = blockIdx.x * blockDim.x;
    auto step = blockDim.x * gridDim.x;

    for (int i = firstDim - 1 - tid - threadIdx.x; i > 0; i -= step) {
        int r = rng->relativeInt(i) % i;
        if (i != r) {
            const auto iOffset = shape::getIndexOffset(i, shape);
            const auto rOffset = shape::getIndexOffset(r, shape);
            T e0 = input[iOffset];
            T e1 = input[rOffset];
            //math::nd4j_swap<T>(input(i), input(r));
            input[iOffset] = e1;
            input[rOffset] = e0;
        }
    }
}

template <typename T>
static __global__ void fillShuffleKernel(T* input, Nd4jLong const* inputShape, T* output, Nd4jLong const* outputShape, Nd4jLong firstDim, int* indices, sd::graph::RandomGenerator* rng) {

    // PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance().tadThreshold())
    auto tid = blockIdx.x * blockDim.x;
    auto step = blockDim.x * gridDim.x;

    for(int i = firstDim - 1 - tid - threadIdx.x; i > 0; i -= step) {
        int r = rng->relativeInt(i) % i;
        output[shape::getIndexOffset(i, outputShape)] = input[shape::getIndexOffset(indices[r], inputShape)];
        if(i != r) {
            output[shape::getIndexOffset(r, outputShape)] = input[shape::getIndexOffset(indices[i], inputShape)];
            // output.p(r, input.e<T>(indices[i]));
            // math::nd4j_swap<int>(indices[i], indices[r]);
            atomicExch(&indices[i], indices[r]);
        }
    }

}

//////////////////////////////////////////////////////////////////////////
template <typename T>
void randomShuffle_(sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {

    // check edge cases first
    int temp;
    const int firstDim = input.sizeAt(0);
    auto stream = context->getCudaStream();
    NDArray::prepareSpecialUse({&output}, {&input});
    if(input.lengthOf() == 1 || firstDim == 1) {
        if(!isInplace)
            output.assign(input);
    }
    else if (input.isVector() || shape::isLikeVector(input.shapeInfo(), temp)) {

        // apply Fisher-Yates shuffle
        sd::graph::RandomGenerator* dRandom = nullptr;
        cudaMalloc(&dRandom, sizeof(sd::graph::RandomGenerator));
        cudaMemcpy(dRandom, &rng, sizeof(sd::graph::RandomGenerator), cudaMemcpyHostToDevice);
        T* inputBuf = reinterpret_cast<T*>(input.specialBuffer());
        if(isInplace) {
            swapShuffleKernel<T><<<128, 256, 1024, *stream>>>(inputBuf, input.specialShapeInfo(), firstDim, dRandom);
        }
        else {
            std::vector<int> indices(firstDim);
            std::iota(indices.begin(), indices.end(), 0);
            cudaMemcpy(output.specialBuffer(), input.specialBuffer(), sizeof(T), cudaMemcpyDeviceToDevice);
            //output.p<T>(Nd4jLong(0), input.e<T>(0));
            PointersManager pointersManager(context, "helper::randomShuffle_");
            int* indicesDev = reinterpret_cast<int*>(pointersManager.replicatePointer(indices.data(), indices.size() * sizeof(int)));
            T* outputBuf = reinterpret_cast<T*>(output.specialBuffer());
            fillShuffleKernel<T><<<128, 256, 1024, *stream>>>(inputBuf, input.specialShapeInfo(), outputBuf, output.specialShapeInfo(), firstDim, indicesDev, dRandom);
            pointersManager.synchronize();
        }
        // rng.rewindH(firstDim - 1);
        cudaFree(dRandom);
    }
    else {

        // evaluate sub-arrays list of input array through all dimensions excluding first one
        std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input.rankOf(), {0});
        auto subArrsListIn = input.allTensorsAlongDimension(dimensions);

        // apply Fisher-Yates shuffle
        if(isInplace) {
            for(int i = firstDim - 1; i > 0; --i) {
                int r = rng.relativeInt(i) % i;

                if(i != r)
                    subArrsListIn.at(i)->swapUnsafe(*subArrsListIn.at(r));
            }
        }
        else {
            // evaluate sub-arrays list of output array through all dimensions excluding first one
            auto subArrsListOut = output.allTensorsAlongDimension(dimensions);
            std::vector<int> indices(firstDim);
            std::iota(indices.begin(), indices.end(), 0);
            bool isZeroShuffled = false;

            for(int i = firstDim - 1; i > 0; --i) {
                int r = rng.relativeInt(i) % i;
                subArrsListOut.at(i)->assign(subArrsListIn.at(indices[r]));
                if(r == 0)
                    isZeroShuffled = true;

                if(i != r) {
                    subArrsListOut.at(r)->assign(subArrsListIn.at(indices[i]));
                    math::nd4j_swap<int>(indices[i], indices[r]);
                }
            }
            if(!isZeroShuffled)
                subArrsListOut.at(0)->assign(subArrsListIn.at(0));
        }
        rng.rewindH(firstDim-1);
    }
    NDArray::registerSpecialUse({&output}, {&input});

}

void randomShuffle(sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {
    BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (context, input, output, rng, isInplace), LIBND4J_TYPES);
}

BUILD_SINGLE_TEMPLATE(template void randomShuffle_, (sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace), LIBND4J_TYPES);


//////////////////////////////////////////////////////////////////////////
void eye(sd::LaunchContext * context, NDArray& output) {
@@ -0,0 +1,278 @@
/*******************************************************************************
 * Copyright (c) 2020 Konduit K.K.
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

// Created by Abdelrauf 2020


#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <system/platform_boilerplate.h>
#include <ops/declarable/helpers/convolutions.h>
#include <cstdint>
#include <helpers/LoopsCoordsHelper.h>

#include "armcomputeUtils.h"


namespace sd {
namespace ops {
namespace platforms {


Arm_DataType getArmType ( const DataType &dType){
    Arm_DataType ret;
    switch (dType){
        case HALF :
            ret = Arm_DataType::F16;
            break;
        case FLOAT32 :
            ret = Arm_DataType::F32;
            break;
        case DOUBLE :
            ret = Arm_DataType::F64;
            break;
        case INT8 :
            ret = Arm_DataType::S8;
            break;
        case INT16 :
            ret = Arm_DataType::S16;
            break;
        case INT32 :
            ret = Arm_DataType::S32;
            break;
        case INT64 :
            ret = Arm_DataType::S64;
            break;
        case UINT8 :
            ret = Arm_DataType::U8;
            break;
        case UINT16 :
            ret = Arm_DataType::U16;
            break;
        case UINT32 :
            ret = Arm_DataType::U32;
            break;
        case UINT64 :
            ret = Arm_DataType::U64;
            break;
        case BFLOAT16 :
            ret = Arm_DataType::BFLOAT16;
            break;
        default:
            ret = Arm_DataType::UNKNOWN;
    };

    return ret;
}

bool isArmcomputeFriendly(const NDArray& arr) {
    auto dType = getArmType(arr.dataType());
    int rank = (int)(arr.rankOf());
    return dType != Arm_DataType::UNKNOWN &&
           rank <= arm_compute::MAX_DIMS &&
           arr.ordering() == 'c' &&
           arr.ews() == 1 &&
           shape::strideDescendingCAscendingF(arr.shapeInfo()) == true;
}

Arm_TensorInfo getArmTensorInfo(int rank, Nd4jLong* bases, sd::DataType ndArrayType, arm_compute::DataLayout layout) {
    constexpr int numChannels = 1;
    auto dType = getArmType(ndArrayType);

    Arm_TensorShape shape;
    shape.set_num_dimensions(rank);
    for (int i = 0, j = rank - 1; i < rank; i++, j--) {
        shape[i] = static_cast<uint32_t>(bases[j]);
    }
    // fill the rest unused with 1
    for (int i = rank; i < arm_compute::MAX_DIMS; i++) {
        shape[i] = 1;
    }

    return Arm_TensorInfo(shape, numChannels, dType, layout);
}

Arm_TensorInfo getArmTensorInfo(const NDArray& arr,
                                arm_compute::DataLayout layout) {
    auto dType = getArmType(arr.dataType());

    //
    constexpr int numChannels = 1;
    int rank = (int)(arr.rankOf());
    auto bases = arr.shapeOf();
    auto arrStrides = arr.stridesOf();

    // https://arm-software.github.io/ComputeLibrary/v20.05/_dimensions_8h_source.xhtml
    // note: underhood it is stored as std::array<T, num_max_dimensions> _id;
    // TensorShape is derived from Dimensions<uint32_t>
    // as well as Strides : public Dimensions<uint32_t>
    Arm_TensorShape shape;
    Arm_Strides strides;
    shape.set_num_dimensions(rank);
    strides.set_num_dimensions(rank);
    size_t element_size = arm_compute::data_size_from_type(dType);
    for (int i = 0, j = rank - 1; i < rank; i++, j--) {
        shape[i] = static_cast<uint32_t>(bases[j]);
        strides[i] = static_cast<uint32_t>(arrStrides[j]) * element_size;
    }
    // fill the rest unused with 1
    for (int i = rank; i < arm_compute::MAX_DIMS; i++) {
        shape[i] = 1;
    }
    size_t total_size;
    size_t size_ind = rank - 1;
    total_size = shape[size_ind] * strides[size_ind];

    Arm_TensorInfo info;
    info.init(shape, numChannels, dType, strides, 0, total_size);
    info.set_data_layout(layout);

    return info;
}

Arm_Tensor getArmTensor(const NDArray& arr, arm_compute::DataLayout layout) {
    // - Ownership of the backing memory is not transferred to the tensor itself.
    // - The tensor mustn't be memory managed.
    // - Padding requirements should be accounted by the client code.
    //   In other words, if padding is required by the tensor after the function
    //   configuration step, then the imported backing memory should account for it.
    //   Padding can be checked through the TensorInfo::padding() interface.

    // Import existing pointer as backing memory
    auto info = getArmTensorInfo(arr, layout);
    Arm_Tensor tensor;
    tensor.allocator()->init(info);
    void* buff = (void*)arr.buffer();
    tensor.allocator()->import_memory(buff);
    return tensor;
}

void copyFromTensor(const Arm_Tensor& inTensor, NDArray& output) {
    //only for C order
    if (output.ordering() != 'c') return;
    auto shapeInfo = output.shapeInfo();
    auto bases = &(shapeInfo[1]);
    Nd4jLong rank = shapeInfo[0];
    auto strides = output.stridesOf();
    int width = bases[rank - 1];
    uint8_t* outputBuffer = (uint8_t*)output.buffer();
    size_t offset = 0;
    arm_compute::Window window;
    arm_compute::Iterator tensor_it(&inTensor, window);

    int element_size = inTensor.info()->element_size();
    window.use_tensor_dimensions(inTensor.info()->tensor_shape(), /* first_dimension =*/arm_compute::Window::DimY);

//    if (output.ews() == 1) {
    auto copySize = width * element_size;
    auto dest = outputBuffer;
    arm_compute::execute_window_loop(window, [&](const arm_compute::Coordinates& id)
    {
        auto src = tensor_it.ptr();
        memcpy(dest, src, copySize);
        dest += copySize;
    },
    tensor_it);
//    }
//    else {
//        Nd4jLong coords[MAX_RANK] = {};
//        if(strides[rank-1]!=1){
//            throw std::runtime_error( "not implemented for subarrays whose last stride is not 1");
//            //TODO: implement to work with all subarrays properly
//        }
//        arm_compute::execute_window_loop(window, [&](const arm_compute::Coordinates& id)
//        {
//            auto src = tensor_it.ptr();
//            auto dest = outputBuffer + offset * element_size;
//            memcpy(dest, src, width * element_size);
//            offset = sd::inc_coords(bases, strides, coords, offset, rank, 1);
//        },
//        tensor_it);
//    }
}

void copyToTensor(const NDArray& input, Arm_Tensor& outTensor) {
    //only for C order
    if (input.ordering() != 'c') return;
    auto shapeInfo = input.shapeInfo();
    auto bases = &(shapeInfo[1]);
    Nd4jLong rank = shapeInfo[0];
    auto strides = input.stridesOf();
    uint8_t *inputBuffer = (uint8_t*)input.buffer();
    int width = bases[rank - 1];
    size_t offset = 0;
    arm_compute::Window window;
    arm_compute::Iterator tensor_it(&outTensor, window);
    int element_size = outTensor.info()->element_size();

    window.use_tensor_dimensions(outTensor.info()->tensor_shape(), /* first_dimension =*/arm_compute::Window::DimY);

//    if (input.ews() == 1) {
    auto copySize = width * element_size;
    auto src = inputBuffer;
    arm_compute::execute_window_loop(window, [&](const arm_compute::Coordinates& id)
    {
        auto dest = tensor_it.ptr();
        memcpy(dest, src, copySize);
        src += copySize;
    },
    tensor_it);
//    }
//    else {
//        Nd4jLong coords[MAX_RANK] = {};
//        if(strides[rank-1]!=1){
//            throw std::runtime_error( "not implemented for subarrays whose last stride is not 1");
//            //TODO: implement to work with all subarrays properly
//        }
//        arm_compute::execute_window_loop(window, [&](const arm_compute::Coordinates& id)
//        {
//            auto dest = tensor_it.ptr();
//            auto src = inputBuffer + offset * element_size;
//            offset = sd::inc_coords(bases, strides, coords, offset, rank, 1);
//        },
//        tensor_it);
//    }
}


// armcompute should be built with debug option
void print_tensor(Arm_ITensor& tensor, const char* msg) {
    auto info = tensor.info();
    auto padding = info->padding();
    std::cout << msg << "\ntotal: " << info->total_size() << "\n";

    for (int i = 0; i < arm_compute::MAX_DIMS; i++) {
        std::cout << info->dimension(i) << ",";
    }
    std::cout << std::endl;
    for (int i = 0; i < arm_compute::MAX_DIMS; i++) {
        std::cout << info->strides_in_bytes()[i] << ",";
    }
    std::cout << "\npadding: l " << padding.left << ", r " << padding.right
              << ", t " << padding.top << ", b " << padding.bottom << std::endl;

#ifdef ARM_COMPUTE_ASSERTS_ENABLED
    //note it did not print correctly for NHWC
    std::cout << msg << ":\n";
    tensor.print(std::cout);
    std::cout << std::endl;
#endif
}

}
}
}
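getArmTensorInfo above reverses the dimension order because NDArray keeps C-order shapes outermost-first while ACL's Dimensions index 0 is the innermost (fastest-varying) axis, and ACL strides are expressed in bytes rather than elements. A standalone sketch of just that conversion (plain C++; the vectors are stand-ins for Arm_TensorShape/Arm_Strides, not the real arm_compute types):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // A float NDArray with C-order shape [N=1, C=3, H=4, W=5].
    const std::vector<int64_t> shape   = {1, 3, 4, 5};
    const std::vector<int64_t> strides = {60, 20, 5, 1};    // element strides, C order
    const size_t elemSize = sizeof(float);
    const int rank = (int)shape.size();

    std::vector<uint32_t> aclShape(rank), aclStrides(rank);
    for (int i = 0, j = rank - 1; i < rank; ++i, --j) {
        aclShape[i]   = (uint32_t)shape[j];                 // reverse: ACL dim 0 = innermost
        aclStrides[i] = (uint32_t)(strides[j] * elemSize);  // and convert to byte strides
    }

    // total_size as computed above: outermost extent times its byte stride.
    const size_t totalSize = (size_t)aclShape[rank - 1] * aclStrides[rank - 1];
    std::printf("ACL shape [W,H,C,N] = [%u,%u,%u,%u], total %zu bytes\n",
                aclShape[0], aclShape[1], aclShape[2], aclShape[3], totalSize); // 240 bytes
}
```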
@@ -0,0 +1,133 @@
/*******************************************************************************
 * Copyright (c) 2020 Konduit K.K.
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/


#ifndef DEV_TESTSARMCOMPUTEUTILS_H
#define DEV_TESTSARMCOMPUTEUTILS_H


#include <legacy/NativeOps.h>
#include <array/NDArray.h>
#include <graph/Context.h>
#include <ops/declarable/PlatformHelper.h>
#include <system/platform_boilerplate.h>
#include <arm_compute/runtime/NEON/NEFunctions.h>
#include <arm_compute/core/Types.h>
#include <arm_compute/core/TensorInfo.h>
#include <arm_compute/core/TensorShape.h>
#include <arm_compute/core/Strides.h>
#include <arm_compute/core/Helpers.h>
#include <arm_compute/core/ITensor.h>
#include <arm_compute/core/Types.h>
#include <arm_compute/core/Validate.h>
#include <arm_compute/core/Window.h>
#include <arm_compute/runtime/Tensor.h>
#include <arm_compute/runtime/TensorAllocator.h>
#include <iostream>

using namespace samediff;


namespace sd {
namespace ops {
namespace platforms {

    using Arm_DataType = arm_compute::DataType;
    using Arm_Tensor = arm_compute::Tensor;
    using Arm_ITensor = arm_compute::ITensor;
    using Arm_TensorInfo = arm_compute::TensorInfo;
    using Arm_TensorShape = arm_compute::TensorShape;
    using Arm_Strides = arm_compute::Strides;

    /**
     * Here we actually declare our platform helpers
     */

    DECLARE_PLATFORM(maxpool2d, ENGINE_CPU);

    DECLARE_PLATFORM(avgpool2d, ENGINE_CPU);

    //utils
    Arm_DataType getArmType(const sd::DataType& dType);

    Arm_TensorInfo getArmTensorInfo(int rank, Nd4jLong* bases, sd::DataType ndArrayType, arm_compute::DataLayout layout = arm_compute::DataLayout::UNKNOWN);

    Arm_TensorInfo getArmTensorInfo(const NDArray& arr, arm_compute::DataLayout layout = arm_compute::DataLayout::UNKNOWN);

    Arm_Tensor getArmTensor(const NDArray& arr, arm_compute::DataLayout layout = arm_compute::DataLayout::UNKNOWN);

    void copyFromTensor(const Arm_Tensor& inTensor, NDArray& output);
    void copyToTensor(const NDArray& input, Arm_Tensor& outTensor);
    void print_tensor(Arm_ITensor& tensor, const char* msg);
    bool isArmcomputeFriendly(const NDArray& arr);


    template<typename F>
    class ArmFunction {
    public:

        template<typename ...Args>
        void configure(NDArray *input, NDArray *output, arm_compute::DataLayout layout, Args&& ...args) {

            auto inInfo = getArmTensorInfo(*input, layout);
            auto outInfo = getArmTensorInfo(*output, layout);
            in.allocator()->init(inInfo);
            out.allocator()->init(outInfo);
            armFunction.configure(&in, &out, std::forward<Args>(args)...);
            if (in.info()->has_padding()) {
                //allocate and copy
                in.allocator()->allocate();
                //copy
                copyToTensor(*input, in);
            }
            else {
                //import buffer
                void* buff = input->buffer();
                in.allocator()->import_memory(buff);
            }
            if (out.info()->has_padding()) {
                //store pointer to our array to copy after run
                out.allocator()->allocate();
                outNd = output;
            }
            else {
                //import
                void* buff = output->buffer();
                out.allocator()->import_memory(buff);
            }

        }

        void run() {
            armFunction.run();
            if (outNd) {
                copyFromTensor(out, *outNd);
            }
        }

    private:
        Arm_Tensor in;
        Arm_Tensor out;
        NDArray *outNd = nullptr;
        F armFunction{};
    };
}
}
}


#endif //DEV_TESTSARMCOMPUTEUTILS_H
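The ArmFunction template wraps ACL's configure-then-run flow and, per tensor, chooses between zero-copy import_memory and an allocate-plus-copy fallback when the configured function demands padded layouts. A toy model of that decision flow with dummy types standing in for the arm_compute ones (illustrative only, not the real API):

```cpp
#include <cstdio>
#include <cstring>
#include <vector>

using Kernel = void (*)(const float*, float*, size_t);

// Dummy stand-in for an arm_compute::Tensor: either imports a user buffer
// (zero copy) or owns a staging buffer that must be copied.
struct FakeTensor {
    std::vector<float> own;
    float* external = nullptr;
    float* data() { return external ? external : own.data(); }
};

// Toy model of ArmFunction: configure() picks import vs copy, run() executes
// and copies back only when a staging output buffer was used.
struct ArmFunctionSketch {
    FakeTensor in, out;
    float* outUser = nullptr;
    size_t n = 0;
    Kernel fn = nullptr;

    void configure(Kernel k, float* input, float* output, size_t len, bool needsPadding) {
        fn = k; n = len;
        if (needsPadding) {
            in.own.assign(input, input + len); // allocate + copy, like copyToTensor()
            out.own.resize(len);
            outUser = output;                  // copy back after run, like copyFromTensor()
        } else {
            in.external = input;               // like allocator()->import_memory(buff)
            out.external = output;
        }
    }
    void run() {
        fn(in.data(), out.data(), n);
        if (outUser) std::memcpy(outUser, out.data(), n * sizeof(float));
    }
};

static void negate(const float* a, float* b, size_t n) {
    for (size_t i = 0; i < n; ++i) b[i] = -a[i];
}

int main() {
    float src[4] = {1, 2, 3, 4}, dst[4] = {};
    ArmFunctionSketch f;
    f.configure(negate, src, dst, 4, /*needsPadding=*/true);
    f.run();
    std::printf("%g %g %g %g\n", dst[0], dst[1], dst[2], dst[3]); // -1 -2 -3 -4
}
```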
@ -0,0 +1,106 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
// Created by Abdelrauf (rauf@konduit.ai) 2020
|
||||
|
||||
#include <ops/declarable/PlatformHelper.h>
|
||||
#include <ops/declarable/OpRegistrator.h>
|
||||
#include <system/platform_boilerplate.h>
|
||||
#include <ops/declarable/helpers/convolutions.h>
|
||||
|
||||
|
||||
#include "armcomputeUtils.h"
|
||||
|
||||
|
||||
namespace sd {
|
||||
namespace ops {
|
||||
namespace platforms {
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
PLATFORM_IMPL(avgpool2d, ENGINE_CPU) {
|
||||
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
|
||||
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
|
||||
|
||||
const auto kH = INT_ARG(0);
|
||||
const auto kW = INT_ARG(1);
|
||||
const auto sH = INT_ARG(2);
|
||||
const auto sW = INT_ARG(3);
|
||||
auto pH = INT_ARG(4);
|
||||
auto pW = INT_ARG(5);
|
||||
const auto dH = INT_ARG(6);
|
||||
const auto dW = INT_ARG(7);
|
||||
const auto paddingMode = INT_ARG(8);
|
||||
const auto extraParam0 = INT_ARG(9);
|
||||
const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
|
||||
|
||||
REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D ARMCOMPUTE op: input should have rank of 4, but got %i instead", input->rankOf());
|
||||
REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D ARMCOMPUTE op: dilation must not be zero, but got instead {%i, %i}", dH, dW);
|
||||
|
||||
bool exclude_padding= (extraParam0 == 0) ? true : false;
|
||||
|
||||
auto dataLayout = isNCHW ? arm_compute::DataLayout::NCHW : arm_compute::DataLayout::NHWC;
|
||||
|
||||
// Calculate individual paddings
|
||||
unsigned int pad_left, pad_top, pad_right, pad_bottom;
|
||||
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
||||
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
|
||||
|
||||
if(paddingMode){
|
||||
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
||||
}
|
||||
pad_left = pW;
|
||||
pad_top = pH;
|
||||
pad_right = (oW - 1) * sW - iW + kW - pW ;
|
||||
pad_bottom = (oH - 1) * sH - iH + kH - pH ;
|
||||
|
||||
#if 0
|
||||
nd4j_printf("avgpool kH = %d, kW = %d, sH = %d, sW = %d , pH = %d , pW = %d, dH = %d, dW = %d, paddingMode = %d , isNCHW %d exclude pad %d \n" , kH , kW , sH , sW , pH
|
||||
, pW , dH , dW , paddingMode,isNCHW?1:0 ,exclude_padding?1:0);
|
||||
#endif
|
||||
auto poolPad = arm_compute::PadStrideInfo(sW, sH, pad_left,pad_right, pad_top, pad_bottom, arm_compute::DimensionRoundingType::FLOOR);
|
||||
auto poolInfo = arm_compute::PoolingLayerInfo(arm_compute::PoolingType::AVG, arm_compute::Size2D(kW, kH), dataLayout, poolPad, exclude_padding);
|
||||
ArmFunction<arm_compute::NEPoolingLayer> pool;
|
||||
pool.configure(input,output, dataLayout, poolInfo);
|
||||
|
||||
pool.run(); // run function
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
PLATFORM_CHECK(avgpool2d, ENGINE_CPU) {
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
const int dH = INT_ARG(6);
|
||||
const int dW = INT_ARG(7);
|
||||
// Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32
|
||||
auto dTypeInput = getArmType(input->dataType());
|
||||
auto dTypeOutput = getArmType(output->dataType());
|
||||
bool is_supported = dH==1 && dW==1 && isArmcomputeFriendly(*input) && isArmcomputeFriendly(*output)
|
||||
&& (dTypeInput ==Arm_DataType::F32)
|
||||
&& (dTypeOutput ==Arm_DataType::F32);
|
||||
return is_supported;
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
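A note on the pad arithmetic above, for readers checking it: pad_right and pad_bottom are simply the pooling output-size relation solved for the missing pad. With dilation 1,

    $o_W = \lfloor (i_W + p_W + p_{right} - k_W)/s_W \rfloor + 1 \;\Rightarrow\; p_{right} = (o_W - 1)\,s_W - i_W + k_W - p_W$

Worked case with assumed values (illustration only, not taken from this change): $i_W = 5$, $k_W = 2$, $s_W = 1$ in SAME mode gives $o_W = 5$; calcPadding2D puts $p_W = 0$ on the left, so $p_{right} = (5-1)\cdot 1 - 5 + 2 - 0 = 1$, i.e. the single missing column is padded on the right.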
@ -0,0 +1,106 @@
/*******************************************************************************
 * Copyright (c) 2019 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

// Created by Abdelrauf 2020

#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <system/platform_boilerplate.h>
#include <ops/declarable/helpers/convolutions.h>

#include "armcomputeUtils.h"


namespace sd {
namespace ops {
namespace platforms {

//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(maxpool2d, ENGINE_CPU) {

    auto input = INPUT_VARIABLE(0);
    auto output = OUTPUT_VARIABLE(0);

    REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D ARMCOMPUTE OP: input array should have rank of 4, but got %i instead", input->rankOf());

    // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
    const int kH = INT_ARG(0);
    const int kW = INT_ARG(1);
    const int sH = INT_ARG(2);
    const int sW = INT_ARG(3);
    int pH = INT_ARG(4);
    int pW = INT_ARG(5);
    const int dH = INT_ARG(6);
    const int dW = INT_ARG(7);
    const int paddingMode = INT_ARG(8);
    // const int extraParam0 = INT_ARG(9);
    const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1;       // INT_ARG(10): 1-NHWC, 0-NCHW

    REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D ARMCOMPUTE op: dilation must not be zero, but got instead {%i, %i}", dH, dW);

    auto dataLayout = isNCHW ? arm_compute::DataLayout::NCHW : arm_compute::DataLayout::NHWC;

    // Calculate individual paddings
    unsigned int pad_left, pad_top, pad_right, pad_bottom;
    int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
    int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

    if (paddingMode) {
        ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
    }
    pad_left   = pW;
    pad_top    = pH;
    pad_right  = (oW - 1) * sW - iW + kW - pW;
    pad_bottom = (oH - 1) * sH - iH + kH - pH;
#if 0
    nd4j_printf("maxpool kH = %d, kW = %d, sH = %d, sW = %d , pH = %d , pW = %d, dH = %d, dW = %d, paddingMode = %d , isNCHW %d \n", kH, kW, sH, sW, pH,
                pW, dH, dW, paddingMode, isNCHW ? 1 : 0);
#endif

    auto poolPad = arm_compute::PadStrideInfo(sW, sH, pad_left, pad_right, pad_top, pad_bottom, arm_compute::DimensionRoundingType::FLOOR);
    auto poolInfo = arm_compute::PoolingLayerInfo(arm_compute::PoolingType::MAX, arm_compute::Size2D(kW, kH), dataLayout, poolPad);
    ArmFunction<arm_compute::NEPoolingLayer> pool;

    pool.configure(input, output, dataLayout, poolInfo);

    pool.run(); // run function

    return Status::OK();
}

//////////////////////////////////////////////////////////////////////////
PLATFORM_CHECK(maxpool2d, ENGINE_CPU) {
    auto input = INPUT_VARIABLE(0);
    auto output = OUTPUT_VARIABLE(0);
    const int dH = INT_ARG(6);
    const int dW = INT_ARG(7);
    // Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32
    auto dTypeInput = getArmType(input->dataType());
    auto dTypeOutput = getArmType(output->dataType());
    bool is_supported = dH == 1 && dW == 1 && isArmcomputeFriendly(*input) && isArmcomputeFriendly(*output)
                        && (dTypeInput == Arm_DataType::F32)
                        && (dTypeOutput == Arm_DataType::F32);
    return is_supported;
}

}
}
}
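maxpool2d mirrors avgpool2d almost line for line; the meaningful differences are PoolingType::MAX in the PoolingLayerInfo and the absence of the exclude-padding flag, since extraParam0 is unused for max pooling.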
@ -3963,9 +3963,6 @@ namespace simdOps {
        }
#endif

#ifndef __clang__
#pragma omp declare simd uniform(extraParamsRef)
#endif
        op_def static Y merge(X old, X opOutput, X *extraParamsRef) {
            return update(old, opOutput, extraParamsRef);
        }
@ -0,0 +1,185 @@
#!/bin/bash
TARGET=armv7-a
BLAS_TARGET_NAME=ARMV7
ARMCOMPUTE_TARGET=armv7a
#BASE_DIR=${HOME}/pi
#https://stackoverflow.com/questions/59895/how-to-get-the-source-directory-of-a-bash-script-from-within-the-script-itself
SOURCE="${BASH_SOURCE[0]}"
ARMCOMPUTE_DEBUG=1
LIBND4J_BUILD_MODE=Release
while [ -h "$SOURCE" ]; do # resolve $SOURCE until the file is no longer a symlink
    DIR="$( cd -P "$( dirname "$SOURCE" )" >/dev/null 2>&1 && pwd )"
    SOURCE="$(readlink "$SOURCE")"
    [[ $SOURCE != /* ]] && SOURCE="$DIR/$SOURCE" # if $SOURCE was a relative symlink, we need to resolve it relative to the path where the symlink file was located
done
BASE_DIR="$( cd -P "$( dirname "$SOURCE" )" >/dev/null 2>&1 && pwd )"
CMAKE=cmake #/snap/bin/cmake

mkdir -p ${BASE_DIR}/helper_bin/

CROSS_COMPILER_URL=https://sourceforge.net/projects/raspberry-pi-cross-compilers/files/Raspberry%20Pi%20GCC%20Cross-Compiler%20Toolchains/Buster/GCC%208.3.0/Raspberry%20Pi%203A%2B%2C%203B%2B%2C%204/cross-gcc-8.3.0-pi_3%2B.tar.gz/download
CROSS_COMPILER_DIR=${BASE_DIR}/helper_bin/cross_compiler

SCONS_LOCAL_URL=http://prdownloads.sourceforge.net/scons/scons-local-3.1.1.tar.gz
SCONS_LOCAL_DIR=${BASE_DIR}/helper_bin/scons_local

THIRD_PARTY=${BASE_DIR}/third_party_libs

ARMCOMPUTE_GIT_URL=https://github.com/ARM-software/ComputeLibrary.git
ARMCOMPUTE_TAG=v20.05
ARMCOMPUTE_DIR=${THIRD_PARTY}/arm_compute_dir

OPENBLAS_GIT_URL="https://github.com/xianyi/OpenBLAS.git"
OPENBLAS_DIR=${THIRD_PARTY}/OpenBLAS

LIBND4J_SRC_DIR=${BASE_DIR}

LIBND4J_BUILD_DIR=${BASE_DIR}/build_pi

# tar option used for some downloads
XRTACT_STRIP="--strip-components=1"

HAS_ARMCOMPUTE=1
mkdir -p ${BASE_DIR}
mkdir -p ${THIRD_PARTY}

# change directory to base
cd $BASE_DIR

function message {
    echo "BUILDER:::: ${@}"
}

function check_requirements {
    for i in "${@}"
    do
        if [ ! -e "$i" ]; then
            message "missing: ${i}"
            exit -2
        fi
    done
}

function download_extract {
    # $1 is the url, $2 is the destination dir, $3 is an extra tar argument
    if [ ! -f ${2}_file ]; then
        message "download"
        wget --quiet --show-progress -O ${2}_file ${1}
    fi

    message "extract"
    mkdir -p ${2}
    command="tar -xzf ${2}_file --directory=${2} ${3}"
    message $command
    $command

    check_requirements "${2}"
}

function git_check {
    # $1 is the url, $2 is the destination dir, $3 is an optional tag or branch
    command="git clone --quiet ${1} ${2}"
    message "$command"
    $command
    if [ -n "$3" ]; then
        cd ${2}
        command="git checkout ${3}"
        message "$command"
        $command
        cd ${BASE_DIR}
    fi
    check_requirements "${2}"
}

if [ ! -d ${CROSS_COMPILER_DIR} ]; then
    #out file
    message "download CROSS_COMPILER"
    download_extract ${CROSS_COMPILER_URL} ${CROSS_COMPILER_DIR} ${XRTACT_STRIP}
fi

# useful exports
export PI_FOLDER=${CROSS_COMPILER_DIR}
export RPI_BIN=${PI_FOLDER}/bin/arm-linux-gnueabihf
export PI_SYS_ROOT=${PI_FOLDER}/arm-linux-gnueabihf/libc
export LD_LIBRARY_PATH=${PI_FOLDER}/lib:$LD_LIBRARY_PATH
export CC=${RPI_BIN}-gcc
export FC=${RPI_BIN}-gfortran
export CXX=${RPI_BIN}-g++
export CPP=${RPI_BIN}-cpp
export RANLIB=${RPI_BIN}-gcc-ranlib
export LD="${RPI_BIN}-ld"
export AR="${RPI_BIN}-ar"

# let's build OpenBLAS
if [ ! -d "${OPENBLAS_DIR}" ]; then
    message "download OpenBLAS"
    git_check "${OPENBLAS_GIT_URL}" "${OPENBLAS_DIR}"
fi

if [ ! -f "${THIRD_PARTY}/lib/libopenblas.so" ]; then
    message "build and install OpenBLAS"
    cd ${OPENBLAS_DIR}

    command="make TARGET=${BLAS_TARGET_NAME} HOSTCC=gcc CC=${CC} USE_THREAD=0 NOFORTRAN=1 CFLAGS=--sysroot=${PI_SYS_ROOT} LDFLAGS=\"-L${PI_SYS_ROOT}/../lib/ -lm\" &>/dev/null"
    message $command
    eval $command
    message "install it"
    command="make PREFIX=${THIRD_PARTY} install"
    message $command
    $command
    cd $BASE_DIR
fi
check_requirements ${THIRD_PARTY}/lib/libopenblas.so

if [ ! -d ${SCONS_LOCAL_DIR} ]; then
    #out file
    message "download Scons local"
    download_extract ${SCONS_LOCAL_URL} ${SCONS_LOCAL_DIR}
fi
check_requirements ${SCONS_LOCAL_DIR}/scons.py

if [ ! -d "${ARMCOMPUTE_DIR}" ]; then
    message "download ArmCompute Source"
    git_check ${ARMCOMPUTE_GIT_URL} "${ARMCOMPUTE_DIR}" "tags/${ARMCOMPUTE_TAG}"
fi

# build armcompute
if [ ! -f "${ARMCOMPUTE_DIR}/build/libarm_compute-static.a" ]; then
    message "build arm compute"
    cd ${ARMCOMPUTE_DIR}
    command="CC=gcc CXX=g++ python3 ${SCONS_LOCAL_DIR}/scons.py Werror=1 -j$(nproc) toolchain_prefix=${RPI_BIN}- debug=${ARMCOMPUTE_DEBUG} neon=1 opencl=0 extra_cxx_flags=-fPIC os=linux build=cross_compile arch=${ARMCOMPUTE_TARGET} &>/dev/null"
    message $command
    eval $command
    cd ${BASE_DIR}
fi
check_requirements "${ARMCOMPUTE_DIR}/build/libarm_compute-static.a" "${ARMCOMPUTE_DIR}/build/libarm_compute_core-static.a"

message "build cmake for LIBND4J. output: ${LIBND4J_BUILD_DIR}"

TOOLCHAIN=${LIBND4J_SRC_DIR}/cmake/rpi.cmake
cmake_cmd="${CMAKE} -G \"Unix Makefiles\" -B${LIBND4J_BUILD_DIR} -S${LIBND4J_SRC_DIR} -DCMAKE_BUILD_TYPE=${LIBND4J_BUILD_MODE} -DCMAKE_TOOLCHAIN_FILE=${TOOLCHAIN} -DCMAKE_VERBOSE_MAKEFILE:BOOL=ON -DSD_ALL_OPS=true -DSD_CPU=true -DSD_LIBRARY_NAME=nd4jcpu -DSD_BUILD_TESTS=ON -DSD_ARM_BUILD=true -DOPENBLAS_PATH=${THIRD_PARTY} -DSD_ARCH=${TARGET} -DARMCOMPUTE_ROOT=${ARMCOMPUTE_DIR} -DHELPERS_armcompute=${HAS_ARMCOMPUTE}"
message $cmake_cmd
eval $cmake_cmd

# build
message "let's build"

cd ${LIBND4J_BUILD_DIR}
make -j $(nproc)
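The script is self-locating (it resolves BASH_SOURCE through symlinks), so it can be invoked from any working directory; downloads land in helper_bin/ and third_party_libs/ next to the script, and the cross-compiled libnd4j output goes to build_pi/.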
@ -52,14 +52,19 @@ elseif(WIN32)
        set(CMAKE_CXX_FLAGS " -fPIC")
    endif()
else()
    set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3")
    set(CMAKE_CXX_FLAGS " -fPIC")
    set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3")
    IF(${SD_ARCH} MATCHES "arm*")
        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=${SD_ARCH}")
    else()
        set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3")

        if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64*")
            set(CMAKE_CXX_FLAGS " ${CMAKE_CXX_FLAGS} -mcpu=native")
        else()
            set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native -mtune=native")
        endif()

    endif()
    if (SD_CPU AND SD_SANITIZE)
        set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -fsanitize=address")
    else()
@ -130,7 +135,7 @@ if (SD_CPU)
    endif()

    add_executable(runtests ${TEST_SOURCES})
    target_link_libraries(runtests samediff_obj ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES} ${CPU_FEATURES} gtest gtest_main)
    target_link_libraries(runtests samediff_obj ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES} ${CPU_FEATURES} ${ARMCOMPUTE_LIBRARIES} gtest gtest_main)
elseif(SD_CUDA)

    add_executable(runtests ${TEST_SOURCES})
@ -1113,7 +1113,10 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_6) {
    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto z = result.at(0);

#if 0
    exp.printIndexedBuffer("Expected");
    z->printIndexedBuffer("Z");
#endif
    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

@ -1132,7 +1135,10 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_7) {
    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto z = result.at(0);

#if 0
    exp.printIndexedBuffer("Expected");
    z->printIndexedBuffer("Z");
#endif
    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

@ -1151,7 +1157,10 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_8) {
    ASSERT_EQ(ND4J_STATUS_OK, result.status());

    auto z = result.at(0);

#if 0
    exp.printIndexedBuffer("Expected");
    z->printIndexedBuffer("Z");
#endif
    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));
}

@ -1204,7 +1213,10 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_10) {
    auto* output = results.at(0);

    ASSERT_EQ(Status::OK(), results.status());

#if 0
    expOutput.printIndexedBuffer("expOutput");
    output->printIndexedBuffer("output");
#endif
    ASSERT_TRUE(expOutput.isSameShape(output));
    ASSERT_TRUE(expOutput.equalsTo(output));
}
@ -244,7 +244,8 @@ TEST_F(DeclarableOpsTests19, test_threshold_encode_decode) {
#ifdef _RELEASE
TEST_F(DeclarableOpsTests19, test_threshold_encode_decode_2) {
    // [2,1,135079944,1,1,8192,1,99]
    auto initial = NDArrayFactory::create<float>('c', {1, 135079944});
    constexpr int sizeX = 10*1000*1000;
    auto initial = NDArrayFactory::create<float>('c', {1, sizeX});
    initial = 1.0f;
    auto exp = initial.dup();
    auto neg = initial.like();

@ -254,7 +255,7 @@ TEST_F(DeclarableOpsTests19, test_threshold_encode_decode_2) {
    auto enc_result = enc.evaluate({&initial}, {0.5f});
    auto encoded = enc_result.at(1);

    ASSERT_EQ(135079944 + 4, encoded->lengthOf());
    ASSERT_EQ(sizeX + 4, encoded->lengthOf());
    ASSERT_NE(exp, initial);
    /*
    for (int e = 0; e < initial.lengthOf(); e++) {

@ -419,3 +420,4 @@ TEST_F(DeclarableOpsTests19, test_squeeze_1) {
    auto status = op.execute({&x}, {&e}, {axis});
    ASSERT_EQ(Status::OK(), status);
}
|
@ -1557,8 +1557,6 @@ TEST_F(DeclarableOpsTests5, trace_test1) {
|
|||
// exp.printIndexedBuffer("EXP TRACE");
|
||||
// output->printIndexedBuffer("OUT TRACE");
|
||||
ASSERT_TRUE(exp.equalsTo(output));
|
||||
|
||||
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
|
@ -1575,8 +1573,6 @@ TEST_F(DeclarableOpsTests5, trace_test2) {
|
|||
ASSERT_EQ(Status::OK(), results.status());
|
||||
ASSERT_TRUE(exp.isSameShape(output));
|
||||
ASSERT_TRUE(exp.equalsTo(output));
|
||||
|
||||
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
|
@ -1593,8 +1589,6 @@ TEST_F(DeclarableOpsTests5, trace_test3) {
|
|||
ASSERT_EQ(Status::OK(), results.status());
|
||||
ASSERT_TRUE(exp.isSameShape(output));
|
||||
ASSERT_TRUE(exp.equalsTo(output));
|
||||
|
||||
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
|
@ -1611,8 +1605,6 @@ TEST_F(DeclarableOpsTests5, trace_test4) {
|
|||
ASSERT_EQ(Status::OK(), results.status());
|
||||
ASSERT_TRUE(exp.isSameShape(output));
|
||||
ASSERT_TRUE(exp.equalsTo(output));
|
||||
|
||||
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
|
@ -1629,8 +1621,6 @@ TEST_F(DeclarableOpsTests5, trace_test5) {
|
|||
ASSERT_EQ(Status::OK(), results.status());
|
||||
ASSERT_TRUE(exp.isSameShape(output));
|
||||
ASSERT_TRUE(exp.equalsTo(output));
|
||||
|
||||
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
|
@ -1638,22 +1628,15 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test1) {

    auto input = NDArrayFactory::create<double>('c', {2, 2, 2});
    input.linspace(1);
    NDArray exp1 = input.dup();
    NDArray exp2('c',{2,2,2}, {5,6,7,8, 1,2,3,4}, sd::DataType::DOUBLE);

    sd::ops::random_shuffle op;
    auto results = op.evaluate({&input});
    auto output = results.at(0);

    bool haveZeros = false;
    for(int i = 0; i < output->lengthOf(); ++i)
        if(output->e<float>(i) == (float)0.)
            haveZeros = true;

    ASSERT_EQ(Status::OK(), results.status());
    ASSERT_TRUE(input.isSameShape(output));
    ASSERT_TRUE(!input.equalsTo(output));
    ASSERT_TRUE(!haveZeros);

    ASSERT_TRUE(output->equalsTo(exp1) || output->equalsTo(exp2));
}

//////////////////////////////////////////////////////////////////////

@ -1661,16 +1644,14 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test2) {

    auto input = NDArrayFactory::create<double>('c', {1, 3, 2});
    input.linspace(1);
    NDArray exp1 = input.dup();

    sd::ops::random_shuffle op;
    auto results = op.evaluate({&input});
    auto output = results.at(0);

    ASSERT_EQ(Status::OK(), results.status());
    ASSERT_TRUE(input.isSameShape(output));
    ASSERT_TRUE(input.equalsTo(output));

    ASSERT_TRUE(output->equalsTo(exp1));
}

//////////////////////////////////////////////////////////////////////

@ -1678,129 +1659,132 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test3) {

    auto input = NDArrayFactory::create<double>('c', {3, 2, 1});
    input.linspace(1);
    NDArray exp1 = input.dup();
    NDArray exp2('c',{3,2,1}, {1,2, 5,6, 3,4}, sd::DataType::DOUBLE);
    NDArray exp3('c',{3,2,1}, {3,4, 1,2, 5,6}, sd::DataType::DOUBLE);
    NDArray exp4('c',{3,2,1}, {3,4, 5,6, 1,2}, sd::DataType::DOUBLE);
    NDArray exp5('c',{3,2,1}, {5,6, 1,2, 3,4}, sd::DataType::DOUBLE);
    NDArray exp6('c',{3,2,1}, {5,6, 3,4, 1,2}, sd::DataType::DOUBLE);

    sd::ops::random_shuffle op;
    auto results = op.evaluate({&input});
    auto output = results.at(0);

    bool haveZeros = false;
    for(int i = 0; i < output->lengthOf(); ++i)
        if(output->e<float>(i) == (float)0.)
            haveZeros = true;

    ASSERT_EQ(Status::OK(), results.status());
    ASSERT_TRUE(input.isSameShape(output));
    ASSERT_TRUE(!input.equalsTo(output));
    ASSERT_TRUE(!haveZeros);

}
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, random_shuffle_test04) {
    auto input = NDArrayFactory::create<double>('c', {4});
    input.linspace(1);

    sd::ops::random_shuffle op;
    //NDArray* output;
    auto results = op.evaluate({&input}, {}, {}, {}, {}, true);

    ASSERT_EQ(Status::OK(), results.status());
    auto output = &input; //results.at(0);
    bool haveZeros = false;
    for(int i = 0; i < output->lengthOf(); ++i)
        if(output->e<float>(i) == (float)0.)
            haveZeros = true;

    ASSERT_TRUE(input.isSameShape(output));
    //ASSERT_TRUE(!input.equalsTo(output));
    ASSERT_TRUE(!haveZeros);

    ASSERT_TRUE(input.equalsTo(exp1) || input.equalsTo(exp2) || input.equalsTo(exp3)
             || input.equalsTo(exp4) || input.equalsTo(exp5) || input.equalsTo(exp6));
}

//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, random_shuffle_test4) {
    auto input = NDArrayFactory::create<double>('c', {4});

    auto input = NDArrayFactory::create<double>('c', {3, 2, 1});
    input.linspace(1);
    NDArray exp1 = input.dup();
    NDArray exp2('c',{3,2,1}, {1,2, 5,6, 3,4}, sd::DataType::DOUBLE);
    NDArray exp3('c',{3,2,1}, {3,4, 1,2, 5,6}, sd::DataType::DOUBLE);
    NDArray exp4('c',{3,2,1}, {3,4, 5,6, 1,2}, sd::DataType::DOUBLE);
    NDArray exp5('c',{3,2,1}, {5,6, 1,2, 3,4}, sd::DataType::DOUBLE);
    NDArray exp6('c',{3,2,1}, {5,6, 3,4, 1,2}, sd::DataType::DOUBLE);

    sd::ops::random_shuffle op;
    //NDArray* output;
    auto results = op.evaluate({&input});
    ASSERT_EQ(Status::OK(), results.status());
    auto output = results.at(0);
    bool haveZeros = false;
    for(int i = 0; i < output->lengthOf(); ++i)
        if(output->e<float>(i) == (float)0.)
            haveZeros = true;

    ASSERT_TRUE(input.isSameShape(output));
    //ASSERT_TRUE(!input.equalsTo(output));
    ASSERT_TRUE(!haveZeros);

    ASSERT_EQ(Status::OK(), results.status());
    ASSERT_TRUE(output->equalsTo(exp1) || output->equalsTo(exp2) || output->equalsTo(exp3)
             || output->equalsTo(exp4) || output->equalsTo(exp5) || output->equalsTo(exp6));
}

//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, random_shuffle_test5) {

    auto input = NDArrayFactory::create<double>('c', {4,1});
    auto input = NDArrayFactory::create<int>('c', {4});
    input.linspace(1);

    sd::ops::random_shuffle op;
    auto results = op.evaluate({&input});
    auto results = op.evaluate({&input}, {}, {}, {}, {}, false);
    auto output = results.at(0);

    bool haveZeros = false;
    for(int i = 0; i < output->lengthOf(); ++i)
        if(output->e<float>(i) == (float)0.)
            haveZeros = true;
    // output->printBuffer();

    ASSERT_EQ(Status::OK(), results.status());
    ASSERT_TRUE(input.isSameShape(output));
    ASSERT_TRUE(!input.equalsTo(output));
    ASSERT_TRUE(!haveZeros);

    // ASSERT_TRUE(!output->equalsTo(input));

    bool hasDublicates = false;
    for(int i = 0; i < output->lengthOf() - 1; ++i)
        for(int j = i+1; j < output->lengthOf(); ++j)
            if(output->t<int>(i) == output->t<int>(j)) {
                hasDublicates = true;
                i = output->lengthOf();
                break;
            }
    ASSERT_TRUE(!hasDublicates);
}

//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, random_shuffle_test6) {

    auto input = NDArrayFactory::create<double>('c', {4,1,1});
    auto input = NDArrayFactory::create<int>('c', {4,1,1});
    input.linspace(1);

    sd::ops::random_shuffle op;
    auto results = op.evaluate({&input});
    auto results = op.evaluate({&input}, {}, {}, {}, {}, false);
    auto output = results.at(0);

    bool haveZeros = false;
    for(int i = 0; i < output->lengthOf(); ++i)
        if(output->e<float>(i) == (float)0.)
            haveZeros = true;

    ASSERT_EQ(Status::OK(), results.status());
    ASSERT_TRUE(input.isSameShape(output));
    ASSERT_TRUE(!input.equalsTo(output));
    ASSERT_TRUE(!haveZeros);

    // ASSERT_TRUE(!output->equalsTo(input));

    bool hasDublicates = false;
    for(int i = 0; i < output->lengthOf() - 1; ++i)
        for(int j = i+1; j < output->lengthOf(); ++j)
            if(output->t<int>(i) == output->t<int>(j)) {
                hasDublicates = true;
                i = output->lengthOf();
                break;
            }
    ASSERT_TRUE(!hasDublicates);
}

//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, random_shuffle_test7) {

    auto input = NDArrayFactory::create<double>('c', {1,4});
    auto input = NDArrayFactory::create<int>('c', {16010});
    input.linspace(1);
    auto exp = NDArrayFactory::create<double>('c', {1,4}, {1, 2, 3, 4});

    sd::ops::random_shuffle op;
    auto results = op.evaluate({&input});
    auto results = op.evaluate({&input}, {}, {}, {}, {}, false);
    auto output = results.at(0);

    // output->printBuffer();
    ASSERT_EQ(Status::OK(), results.status());
    ASSERT_TRUE(input.isSameShape(output));
    ASSERT_TRUE(input.equalsTo(output));
    ASSERT_TRUE(!output->equalsTo(input));

    auto vec1 = input.getBufferAsVector<int>();
    auto vec2 = output->getBufferAsVector<int>();
    std::sort(vec2.begin(), vec2.end());
    ASSERT_TRUE(std::equal(vec1.begin(), vec1.end(), vec2.begin()));
}

//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests5, random_shuffle_test8) {
    auto input = NDArrayFactory::create<int>('c', {1,4,1});
    input.linspace(1);
    NDArray inCopy = input.dup();

    sd::ops::random_shuffle op;
    auto results = op.evaluate({&input}, {}, {}, {}, {}, false);
    ASSERT_EQ(Status::OK(), results.status());
    ASSERT_TRUE(input.equalsTo(inCopy));

}

TEST_F(DeclarableOpsTests5, random_shuffle_test9) {

    auto x = NDArrayFactory::create<int>('c', {4}, {1, 2, 3, 4});
    auto z = x.ulike();

    sd::ops::random_shuffle op;
    auto status = op.execute({&x}, {&z});
    ASSERT_EQ(Status::OK(), status);

    auto vec = z.getBufferAsVector<int>();
    std::sort(vec.begin(), vec.end());
    ASSERT_EQ(std::vector<int>({1, 2, 3, 4}), vec);
}

////////////////////////////////////////////////////////////////////////////////////////
@ -251,11 +251,10 @@ TEST_F(DeclarableOpsTests9, concat_test1) {
    auto result = op.evaluate({&x0, &x1, &x2}, {}, {1});
    ASSERT_EQ(ND4J_STATUS_OK, result.status());
    auto output = result.at(0);
    // output->printCurrentBuffer<float>(false);

    ASSERT_TRUE(exp.isSameShape(output));
    ASSERT_TRUE(exp.equalsTo(output));

}

////////////////////////////////////////////////////////////////////////////////
@ -317,7 +317,7 @@ void fill_random(sd::NDArray& arr) {

    }
}

void testLegacy(bool random) {
#if 0
    int bases[] = { 3, 2, 4, 5, 7 };

@ -364,7 +364,7 @@ int k = 4;
#endif
    auto dim = NDArrayFactory::create<int>(dimension);

#if 1
#if 1
    nd4j_printf("C(N:%d K:%d) \n", N, k);
    dim.printIndexedBuffer("Dimension");
    for (int xind : dimension) {

@ -385,7 +385,7 @@ for (int e = 0; e < Loop; e++) {
    auto outerTime = std::chrono::duration_cast<std::chrono::microseconds>(timeEnd - timeStart).count();
    values.emplace_back(outerTime);
}

std::sort(values.begin(), values.end());

nd4j_printf("Time: %lld us;\n", values[values.size() / 2]);

@ -411,7 +411,7 @@ void testNewReduction(bool random, bool checkCorrectness = false , char order ='
constexpr int N = 5;

#endif

for (int i = 0; i < N; i++) {
    arr_dimensions.push_back(bases[i]);
}

@ -451,7 +451,7 @@ void testNewReduction(bool random, bool checkCorrectness = false , char order ='
#endif
    auto dim = NDArrayFactory::create<int>(dimension);

#if 1
#if 1
    nd4j_printf("C(N:%d K:%d) \n", N, k);
    dim.printIndexedBuffer("Dimension");
    for (int xind : dimension) {

@ -477,14 +477,14 @@ void testNewReduction(bool random, bool checkCorrectness = false , char order ='
    //check for the correctness
    NDArray exp = output_bases.size() > 0 ? NDArrayFactory::create<Nd4jLong>('c', output_bases) : NDArrayFactory::create<Nd4jLong>(0);
    original_argmax(x, dimension, exp);

#if 0// defined(DEBUG)
    x.printIndexedBuffer("X");
    exp.printIndexedBuffer("Expected");
    z->printIndexedBuffer("Z");
#endif

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));
}

@ -505,7 +505,7 @@ TEST_F(PlaygroundTests, ArgMaxPerfLinspace) {
    testNewReduction(false, test_corr);
}
#endif

TEST_F(PlaygroundTests, ArgMaxPerfRandom) {
    testNewReduction(true, test_corr);
}

@ -513,7 +513,7 @@ TEST_F(PlaygroundTests, ArgMaxPerfRandomOrderF) {
TEST_F(PlaygroundTests, ArgMaxPerfRandomOrderF) {
    testNewReduction(true, test_corr, 'f');
}

#if !defined(DEBUG)
TEST_F(PlaygroundTests, ArgMaxPerfLegacyLinspace) {
    testLegacy(false);

@ -1062,39 +1062,6 @@ TEST_F(PlaygroundTests, my) {
    delete variableSpace;
}

TEST_F(PlaygroundTests, my) {

    int N = 100;
    int bS=16, iH=128,iW=128, iC=32,oC=64, kH=4,kW=4, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
    int oH=128,oW=128;

    int paddingMode = 1;    // 1-SAME, 0-VALID;
    int dataFormat = 1;     // 1-NHWC, 0-NCHW

    // NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32);
    // NDArray output('c', {bS, oC, oH, oW}, sd::DataType::FLOAT32);
    NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32);
    NDArray output('c', {bS, oH, oW, oC}, sd::DataType::FLOAT32);
    // NDArray weights('c', {kH, kW, iC, oC}, sd::DataType::FLOAT32);  // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW]
    NDArray weights('c', {oC, iC, kH, kW}, sd::DataType::FLOAT32);
    NDArray bias('c', {oC}, sd::DataType::FLOAT32);

    input = 5.;
    weights = 3.;
    bias = 1.;

    sd::ops::conv2d op;
    auto err = op.execute({&input, &weights, &bias}, {&output}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});

    auto timeStart = std::chrono::system_clock::now();
    for (int i = 0; i < N; ++i)
        err = op.execute({&input, &weights, &bias}, {&output}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
    auto timeEnd = std::chrono::system_clock::now();
    auto time = std::chrono::duration_cast<std::chrono::microseconds> ((timeEnd - timeStart) / N).count();

    printf("time: %i \n", time);
}

///////////////////////////////////////////////////////////////////
TEST_F(PlaygroundTests, lstmLayerCellBp_1) {

@ -1690,6 +1657,52 @@ TEST_F(DeclarableOpsTests15, gru_bp_1) {
    const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
}

#include<ops/declarable/helpers/transforms.h>
//////////////////////////////////////////////////////////////////////
TEST_F(PlaygroundTests, my) {

    const int N = 10;

    NDArray input('c', {8000000}, sd::DataType::INT32);
    input.linspace(1);
    NDArray output = input.dup();

    sd::graph::RandomGenerator rng;

    sd::ops::helpers::randomShuffle(input.getContext(), input, output, rng, true);

    // auto timeStart = std::chrono::system_clock::now();
    // for (int i = 0; i < N; ++i)
    //     sd::ops::helpers::randomShuffle(input.getContext(), input, output, rng, true);
    // auto timeEnd = std::chrono::system_clock::now();
    // auto time = std::chrono::duration_cast<std::chrono::microseconds> ((timeEnd - timeStart) / N).count();
    // printf("time: %i \n", time);

    // bool hasDublicates = false;
    // for(int i = 0; i < output.lengthOf() - 1; ++i)
    //     for(int j = i+1; j < output.lengthOf(); ++j)
    //         if(output.t<int>(i) == output.t<int>(j)) {
    //             hasDublicates = true;
    //             i = output.lengthOf();
    //             break;
    //         }

    ASSERT_TRUE(!input.equalsTo(output));

    bool hasDublicates = false;
    for(int i = 0; i < input.lengthOf() - 1; ++i)
        for(int j = i+1; j < input.lengthOf(); ++j)
            if(input.t<int>(i) == input.t<int>(j)) {
                hasDublicates = true;
                i = input.lengthOf();
                break;
            }
    ASSERT_TRUE(!hasDublicates);
}

}

*/
@ -1,93 +0,0 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

#ifndef LIBND4J_SESSIONLOCALTESTS_H
#define LIBND4J_SESSIONLOCALTESTS_H

#include "testlayers.h"
#include <array/NDArrayFactory.h>
#include <graph/SessionLocalStorage.h>

using namespace sd::graph;

class SessionLocalTests : public testing::Test {
public:

};

TEST_F(SessionLocalTests, BasicTests_1) {
    VariableSpace variableSpace;
    SessionLocalStorage storage(&variableSpace, nullptr);

    if (omp_get_max_threads() <= 1)
        return;

    PRAGMA_OMP_PARALLEL_FOR_THREADS(4)
    for (int e = 0; e < 4; e++) {
        storage.startSession();
    }

    ASSERT_EQ(4, storage.numberOfSessions());

    PRAGMA_OMP_PARALLEL_FOR_THREADS(4)
    for (int e = 0; e < 4; e++) {
        storage.endSession();
    }

    ASSERT_EQ(0, storage.numberOfSessions());
}

TEST_F(SessionLocalTests, BasicTests_2) {
    VariableSpace variableSpace;
    SessionLocalStorage storage(&variableSpace, nullptr);

    if (omp_get_max_threads() <= 1)
        return;

    auto alpha = sd::NDArrayFactory::create_<float>('c', {5,5});
    alpha->assign(0.0);

    variableSpace.putVariable(-1, alpha);

    PRAGMA_OMP_PARALLEL_FOR_THREADS(4)
    for (int e = 0; e < 4; e++) {
        storage.startSession();

        auto varSpace = storage.localVariableSpace();

        auto arr = varSpace->getVariable(-1)->getNDArray();
        arr->applyScalar(sd::scalar::Add, (float) e+1, *arr);
    }

    float lastValue = 0.0f;
    for (int e = 1; e <= 4; e++) {
        auto varSpace = storage.localVariableSpace((Nd4jLong) e);

        auto arr = varSpace->getVariable(-1)->getNDArray();

        //nd4j_printf("Last value: %f; Current value: %f\n", lastValue, arr->e(0));

        ASSERT_NE(lastValue, arr->e<float>(0));
        lastValue = arr->e<float>(0);
    }
}

#endif //LIBND4J_SESSIONLOCALTESTS_H
@ -45,6 +45,21 @@ if ("${BUILD_MKLDNN}")
    set(MKLDNN dnnl)
endif()

if (${HELPERS_armcompute})
    find_package(ARMCOMPUTE REQUIRED)

    if(ARMCOMPUTE_FOUND)
        message("Found ARMCOMPUTE: ${ARMCOMPUTE_LIBRARIES}")
        set(HAVE_ARMCOMPUTE 1)
        # Add preprocessor definition for ARM Compute NEON
        add_definitions(-DARMCOMPUTENEON_ENABLED)
        # build our library with neon support
        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mfpu=neon")
        include_directories(${ARMCOMPUTE_INCLUDE})
    endif()

endif()

# Download and unpack flatbuffers at configure time
configure_file(../../CMakeLists.txt.in flatbuffers-download/CMakeLists.txt)
execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" .

@ -217,6 +232,10 @@ if ("${BUILD_MKLDNN}")
    file(GLOB_RECURSE CUSTOMOPS_PLATFORM_SOURCES false ../../include/ops/declarable/platform/mkldnn/*.cpp)
endif()

if(HAVE_ARMCOMPUTE)
    file(GLOB_RECURSE CUSTOMOPS_ARMCOMPUTE_SOURCES false ../include/ops/declarable/platform/armcompute/*.cpp ../include/ops/declarable/platform/armcompute/armcomputeUtils.h)
endif()

message("CPU backend")
add_definitions(-D__CPUBLAS__=true)

@ -276,8 +295,9 @@ endforeach(TMP_PATH)

add_executable(runtests ${LOOPS_SOURCES} ${LEGACY_SOURCES} ${EXEC_SOURCES} ${HELPERS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
        ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_PLATFORM_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES}
        ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_PLATFORM_SOURCES}
        ${CUSTOMOPS_ARMCOMPUTE_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES}
        ${OPS_SOURCES} ${TEST_SOURCES} ${PERF_SOURCES})

target_link_libraries(runtests gtest ${MKLDNN} gtest_main ${BLAS_LIBRARIES})
target_link_libraries(runtests gtest ${MKLDNN} ${ARMCOMPUTE_LIBRARIES} gtest_main ${BLAS_LIBRARIES})
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -25,7 +25,7 @@
    </parent>
    <modelVersion>4.0.0</modelVersion>

    <groupId>org.eclipse</groupId>
    <groupId>org.nd4j</groupId>
    <artifactId>python4j-parent</artifactId>
    <packaging>pom</packaging>
    <modules>

@ -41,10 +41,14 @@
            <scope>provided</scope>
        </dependency>
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-api</artifactId>
            <version>1.6.6</version>
        </dependency>
        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
            <version>${logback.version}</version>
            <scope>test</scope>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>junit</groupId>

@ -62,5 +66,10 @@
            <artifactId>jsr305</artifactId>
            <version>3.0.2</version>
        </dependency>
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-api</artifactId>
            <version>1.6.6</version>
        </dependency>
    </dependencies>
</project>
@ -21,7 +21,7 @@
    xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <parent>
        <artifactId>python4j-parent</artifactId>
        <groupId>org.eclipse</groupId>
        <groupId>org.nd4j</groupId>
        <version>1.0.0-SNAPSHOT</version>
    </parent>
    <packaging>jar</packaging>

@ -39,6 +39,5 @@
        <artifactId>cpython-platform</artifactId>
        <version>${cpython-platform.version}</version>
    </dependency>

    </dependencies>
</project>
@ -15,7 +15,7 @@
 ******************************************************************************/

package org.eclipse.python4j;
package org.nd4j.python4j;

import org.bytedeco.cpython.PyObject;
@ -14,13 +14,15 @@
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.eclipse.python4j;
package org.nd4j.python4j;

import javax.lang.model.SourceVersion;

import java.io.Closeable;
import java.util.HashSet;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicBoolean;

/**

@ -46,6 +48,31 @@ public class PythonContextManager {
        init();
    }

    public static class Context implements Closeable {
        private final String name;
        private final String previous;
        private final boolean temp;

        public Context() {
            name = "temp_" + UUID.randomUUID().toString().replace("-", "_");
            temp = true;
            previous = getCurrentContext();
            setContext(name);
        }

        public Context(String name) {
            this.name = name;
            temp = false;
            previous = getCurrentContext();
            setContext(name);
        }

        @Override
        public void close() {
            setContext(previous);
            if (temp) deleteContext(name);
        }
    }

    private static void init() {
        if (init.get()) return;
        new PythonExecutioner();

@ -76,7 +103,18 @@ public class PythonContextManager {
    }

    private static boolean validateContextName(String s) {
        return SourceVersion.isIdentifier(s) && !s.startsWith(COLLAPSED_KEY);
        for (int i = 0; i < s.length(); i++) {
            char c = s.toLowerCase().charAt(i);
            if (i == 0) {
                if (c >= '0' && c <= '9') {
                    return false;
                }
            }
            if (!(c == '_' || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9'))) {
                return false;
            }
        }
        return true;
    }

    private static String getContextPrefix(String contextName) {

@ -190,6 +228,7 @@ public class PythonContextManager {
        setContext(tempContext);
        deleteContext(currContext);
        setContext(currContext);
        deleteContext(tempContext);
    }

    /**
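The new Context class is Closeable, so scoped isolation can be written with try-with-resources. A minimal sketch, assuming the executioner calls shown elsewhere in this change (PythonExecutioner.exec and PythonGIL.lock()):

    try (PythonGIL gil = PythonGIL.lock();
         PythonContextManager.Context ctx = new PythonContextManager.Context()) {
        // runs against the anonymous temporary context's globals
        PythonExecutioner.exec("x = 5");
    }   // previous context restored; the temporary context is deleted on close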
@ -14,7 +14,7 @@
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.eclipse.python4j;
package org.nd4j.python4j;

/**
@ -15,7 +15,7 @@
 ******************************************************************************/

package org.eclipse.python4j;
package org.nd4j.python4j;

import org.bytedeco.cpython.PyObject;
@ -42,7 +42,6 @@ public class PythonExecutioner {
    private final static String DEFAULT_PYTHON_PATH_PROPERTY = "org.eclipse.python4j.path";
    private final static String JAVACPP_PYTHON_APPEND_TYPE = "org.eclipse.python4j.path.append";
    private final static String DEFAULT_APPEND_TYPE = "before";

    static {
        init();
    }

@ -55,6 +54,11 @@ public class PythonExecutioner {
        initPythonPath();
        PyEval_InitThreads();
        Py_InitializeEx(0);
        for (PythonType type: PythonTypes.get()) {
            type.init();
        }
        // Constructors of custom types may contain initialization code that should
        // run on the main thread.
    }

    /**

@ -110,6 +114,8 @@ public class PythonExecutioner {
        getVariables(Arrays.asList(pyVars));
    }

    /**
     * Gets the variable with the given name from the interpreter.
     *

@ -205,9 +211,9 @@ public class PythonExecutioner {
     *
     * @return
     */
    public static List<PythonVariable> getAllVariables() {
    public static PythonVariables getAllVariables() {
        PythonGIL.assertThreadSafe();
        List<PythonVariable> ret = new ArrayList<>();
        PythonVariables ret = new PythonVariables();
        PyObject main = PyImport_ImportModule("__main__");
        PyObject globals = PyModule_GetDict(main);
        PyObject keys = PyDict_Keys(globals);

@ -259,7 +265,7 @@ public class PythonExecutioner {
     * @param inputs
     * @return
     */
    public static List<PythonVariable> execAndReturnAllVariables(String code, List<PythonVariable> inputs) {
    public static PythonVariables execAndReturnAllVariables(String code, List<PythonVariable> inputs) {
        setVariables(inputs);
        simpleExec(getWrappedCode(code));
        return getAllVariables();

@ -271,7 +277,7 @@ public class PythonExecutioner {
     * @param code
     * @return
     */
    public static List<PythonVariable> execAndReturnAllVariables(String code) {
    public static PythonVariables execAndReturnAllVariables(String code) {
        simpleExec(getWrappedCode(code));
        return getAllVariables();
    }

@ -279,25 +285,22 @@ public class PythonExecutioner {
    private static synchronized void initPythonPath() {
        try {
            String path = System.getProperty(DEFAULT_PYTHON_PATH_PROPERTY);

            List<File> packagesList = new ArrayList<>();
            packagesList.addAll(Arrays.asList(cachePackages()));
            for (PythonType type: PythonTypes.get()) {
                packagesList.addAll(Arrays.asList(type.packages()));
            }
            //// TODO: fix in javacpp
            packagesList.add(new File(python.cachePackage(), "site-packages"));

            File[] packages = packagesList.toArray(new File[0]);

            if (path == null) {
                File[] packages = cachePackages();

                //// TODO: fix in javacpp
                File sitePackagesWindows = new File(python.cachePackage(), "site-packages");
                File[] packages2 = new File[packages.length + 1];
                for (int i = 0; i < packages.length; i++) {
                    //System.out.println(packages[i].getAbsolutePath());
                    packages2[i] = packages[i];
                }
                packages2[packages.length] = sitePackagesWindows;
                //System.out.println(sitePackagesWindows.getAbsolutePath());
                packages = packages2;
                //////////

                Py_SetPath(packages);
            } else {
                StringBuffer sb = new StringBuffer();
                File[] packages = cachePackages();

                JavaCppPathType pathAppendValue = JavaCppPathType.valueOf(System.getProperty(JAVACPP_PYTHON_APPEND_TYPE, DEFAULT_APPEND_TYPE).toUpperCase());
                switch (pathAppendValue) {
                    case BEFORE:
@ -15,7 +15,7 @@
 ******************************************************************************/

package org.eclipse.python4j;
package org.nd4j.python4j;

import org.bytedeco.cpython.PyObject;
import org.bytedeco.javacpp.Pointer;
@ -14,11 +14,10 @@
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.eclipse.python4j;
package org.nd4j.python4j;

import org.bytedeco.cpython.PyThreadState;
import org.omg.SendingContext.RunTime;

import java.util.concurrent.atomic.AtomicBoolean;

@ -90,4 +89,8 @@ public class PythonGIL implements AutoCloseable {
        PyEval_SaveThread();
        PyEval_RestoreThread(mainThreadState);
    }

    public static boolean locked() {
        return acquired.get();
    }
}
@ -14,31 +14,34 @@
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.eclipse.python4j;
package org.nd4j.python4j;

import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;

import javax.annotation.Nonnull;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;

@Data
@NoArgsConstructor
/**
 * PythonJob is the right abstraction for executing multiple python scripts
 * in a multi-threaded, stateful environment. The setup-and-run mode allows your
 * "setup" code (imports, model loading etc.) to be executed only once.
 */
@Data
@Slf4j
public class PythonJob {

    private String code;
    private String name;
    private String context;
    private boolean setupRunMode;
    private final boolean setupRunMode;
    private PythonObject runF;
    private final AtomicBoolean setupDone = new AtomicBoolean(false);

    static {
        new PythonExecutioner();

@ -63,7 +66,6 @@ public class PythonJob {
        if (PythonContextManager.hasContext(context)) {
            throw new PythonException("Unable to create python job " + name + ". Context " + context + " already exists!");
        }
        if (setupRunMode) setup();
    }

@ -71,17 +73,18 @@ public class PythonJob {
     * Clears all variables in current context and calls setup()
     */
    public void clearState() {
        String context = this.context;
        PythonContextManager.setContext("main");
        PythonContextManager.deleteContext(context);
        this.context = context;
        PythonContextManager.setContext(this.context);
        PythonContextManager.reset();
        setupDone.set(false);
        setup();
    }

    public void setup() {
        if (setupDone.get()) return;
        try (PythonGIL gil = PythonGIL.lock()) {
            PythonContextManager.setContext(context);
            PythonObject runF = PythonExecutioner.getVariable("run");

            if (runF == null || runF.isNone() || !Python.callable(runF)) {
                PythonExecutioner.exec(code);
                runF = PythonExecutioner.getVariable("run");

@ -98,10 +101,12 @@ public class PythonJob {
            if (!setupF.isNone()) {
                setupF.call();
            }
            setupDone.set(true);
        }
    }

    public void exec(List<PythonVariable> inputs, List<PythonVariable> outputs) {
        if (setupRunMode) setup();
        try (PythonGIL gil = PythonGIL.lock()) {
            try (PythonGC _ = PythonGC.watch()) {
                PythonContextManager.setContext(context);

@ -139,6 +144,7 @@ public class PythonJob {
    }

    public List<PythonVariable> execAndReturnAllVariables(List<PythonVariable> inputs) {
        if (setupRunMode) setup();
        try (PythonGIL gil = PythonGIL.lock()) {
            try (PythonGC _ = PythonGC.watch()) {
                PythonContextManager.setContext(context);
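A sketch of the setup-and-run contract the new javadoc describes. The constructor and PythonVariable signatures are assumptions here (they are not shown in this diff), so treat this as illustrative only:

    import java.util.ArrayList;
    import java.util.List;

    // assumed ctor: PythonJob(String name, String code, boolean setupRunMode)
    String script =
            "def setup():\n" +
            "    global factor\n" +
            "    factor = 2            # expensive init, runs once\n" +
            "\n" +
            "def run(x):\n" +
            "    return {'y': x * factor}\n";
    PythonJob job = new PythonJob("myJob", script, true);   // setup-and-run mode

    List<PythonVariable> inputs = new ArrayList<>();
    inputs.add(new PythonVariable<>("x", PythonTypes.INT, 21));   // assumed signature
    List<PythonVariable> outputs = new ArrayList<>();
    outputs.add(new PythonVariable<>("y", PythonTypes.INT));
    job.exec(inputs, outputs);   // setup() already ran once; only run() executes per call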
@ -14,7 +14,7 @@
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.eclipse.python4j;
package org.nd4j.python4j;

import org.bytedeco.cpython.PyObject;
@ -147,7 +147,8 @@ public class PythonObject {
|
|||
}
|
||||
PythonObject pyArgs;
|
||||
PythonObject pyKwargs;
|
||||
if (args == null) {
|
||||
|
||||
if (args == null || args.isEmpty()) {
|
||||
pyArgs = new PythonObject(PyTuple_New(0));
|
||||
} else {
|
||||
PythonObject argsList = PythonTypes.convert(args);
|
||||
|
@ -158,6 +159,7 @@ public class PythonObject {
|
|||
} else {
|
||||
pyKwargs = PythonTypes.convert(kwargs);
|
||||
}
|
||||
|
||||
PythonObject ret = new PythonObject(
|
||||
PyObject_Call(
|
||||
nativePythonObject,
|
||||
|
@ -165,7 +167,9 @@ public class PythonObject {
|
|||
pyKwargs == null ? null : pyKwargs.nativePythonObject
|
||||
)
|
||||
);
|
||||
|
||||
PythonGC.keep(ret);
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
|
@ -241,4 +245,48 @@ public class PythonObject {
|
|||
PyObject_SetItem(nativePythonObject, key.nativePythonObject, value.nativePythonObject);
|
||||
}
|
||||
|
||||
|
||||
public PythonObject abs(){
|
||||
return new PythonObject(PyNumber_Absolute(nativePythonObject));
|
||||
}
|
||||
public PythonObject add(PythonObject pythonObject){
|
||||
return new PythonObject(PyNumber_Add(nativePythonObject, pythonObject.nativePythonObject));
|
||||
}
|
||||
public PythonObject sub(PythonObject pythonObject){
|
||||
return new PythonObject(PyNumber_Subtract(nativePythonObject, pythonObject.nativePythonObject));
|
||||
}
|
||||
public PythonObject mod(PythonObject pythonObject){
|
||||
return new PythonObject(PyNumber_Divmod(nativePythonObject, pythonObject.nativePythonObject));
|
||||
}
|
||||
public PythonObject mul(PythonObject pythonObject){
|
||||
return new PythonObject(PyNumber_Multiply(nativePythonObject, pythonObject.nativePythonObject));
|
||||
}
|
||||
public PythonObject trueDiv(PythonObject pythonObject){
|
||||
return new PythonObject(PyNumber_TrueDivide(nativePythonObject, pythonObject.nativePythonObject));
|
||||
}
|
||||
public PythonObject floorDiv(PythonObject pythonObject){
|
||||
return new PythonObject(PyNumber_FloorDivide(nativePythonObject, pythonObject.nativePythonObject));
|
||||
}
|
||||
public PythonObject matMul(PythonObject pythonObject){
|
||||
return new PythonObject(PyNumber_MatrixMultiply(nativePythonObject, pythonObject.nativePythonObject));
|
||||
}
|
||||
|
||||
public void addi(PythonObject pythonObject){
|
||||
PyNumber_InPlaceAdd(nativePythonObject, pythonObject.nativePythonObject);
|
||||
}
|
||||
public void subi(PythonObject pythonObject){
|
||||
PyNumber_InPlaceSubtract(nativePythonObject, pythonObject.nativePythonObject);
|
||||
}
|
||||
public void muli(PythonObject pythonObject){
|
||||
PyNumber_InPlaceMultiply(nativePythonObject, pythonObject.nativePythonObject);
|
||||
}
|
||||
public void trueDivi(PythonObject pythonObject){
|
||||
PyNumber_InPlaceTrueDivide(nativePythonObject, pythonObject.nativePythonObject);
|
||||
}
|
||||
public void floorDivi(PythonObject pythonObject){
|
||||
PyNumber_InPlaceFloorDivide(nativePythonObject, pythonObject.nativePythonObject);
|
||||
}
|
||||
public void matMuli(PythonObject pythonObject){
|
||||
PyNumber_InPlaceMatrixMultiply(nativePythonObject, pythonObject.nativePythonObject);
|
||||
}
|
||||
}
|
|
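Editor's note: a minimal, hedged usage sketch of the arithmetic helpers added above. The values and class name are illustrative, and it assumes the PythonGIL/PythonGC lifecycle and the PythonTypes.INT converter shown elsewhere in this commit; it is not part of the change itself.

    import org.nd4j.python4j.*;

    public class ArithmeticSketch {
        public static void main(String[] args) {
            try (PythonGIL gil = PythonGIL.lock(); PythonGC gc = PythonGC.watch()) {
                PythonObject a = PythonTypes.INT.toPython(7L);    // Python int 7
                PythonObject b = PythonTypes.INT.toPython(3L);    // Python int 3
                PythonObject sum = a.add(b);                      // PyNumber_Add -> 10
                PythonObject quot = a.floorDiv(b);                // PyNumber_FloorDivide -> 2
                a.addi(b);                                        // in-place += via PyNumber_InPlaceAdd
                System.out.println(PythonTypes.INT.toJava(sum));  // 10
                System.out.println(PythonTypes.INT.toJava(quot)); // 2
            }
        }
    }
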
@ -0,0 +1,127 @@
/*******************************************************************************
 * Copyright (c) 2020 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/


package org.nd4j.python4j;

import org.apache.commons.io.IOUtils;
import org.bytedeco.javacpp.Loader;

import java.io.IOException;
import java.nio.charset.StandardCharsets;

public class PythonProcess {
    private static String pythonExecutable = Loader.load(org.bytedeco.cpython.python.class);

    public static String runAndReturn(String... arguments) throws IOException, InterruptedException {
        String[] allArgs = new String[arguments.length + 1];
        for (int i = 0; i < arguments.length; i++) {
            allArgs[i + 1] = arguments[i];
        }
        allArgs[0] = pythonExecutable;
        ProcessBuilder pb = new ProcessBuilder(allArgs);
        Process process = pb.start();
        String out = IOUtils.toString(process.getInputStream(), StandardCharsets.UTF_8);
        process.waitFor();
        return out;
    }

    public static void run(String... arguments) throws IOException, InterruptedException {
        String[] allArgs = new String[arguments.length + 1];
        for (int i = 0; i < arguments.length; i++) {
            allArgs[i + 1] = arguments[i];
        }
        allArgs[0] = pythonExecutable;
        ProcessBuilder pb = new ProcessBuilder(allArgs);
        pb.inheritIO().start().waitFor();
    }

    public static void pipInstall(String packageName) throws PythonException {
        try {
            run("-m", "pip", "install", packageName);
        } catch (Exception e) {
            throw new PythonException("Error installing package " + packageName, e);
        }
    }

    public static void pipInstall(String packageName, String version) {
        pipInstall(packageName + "==" + version);
    }

    public static void pipUninstall(String packageName) throws PythonException {
        try {
            // -y skips pip's interactive confirmation prompt, which would otherwise hang a headless process
            run("-m", "pip", "uninstall", "-y", packageName);
        } catch (Exception e) {
            throw new PythonException("Error uninstalling package " + packageName, e);
        }
    }

    public static void pipInstallFromGit(String gitRepoUrl) {
        if (!gitRepoUrl.contains("://")) {
            gitRepoUrl = "git://" + gitRepoUrl;
        }
        try {
            // "git+" must be part of the same argument as the URL; pip does not accept it as a separate token
            run("-m", "pip", "install", "git+" + gitRepoUrl);
        } catch (Exception e) {
            throw new PythonException("Error installing package from " + gitRepoUrl, e);
        }
    }

    public static String getPackageVersion(String packageName) {
        String out;
        try {
            out = runAndReturn("-m", "pip", "show", packageName);
        } catch (Exception e) {
            throw new PythonException("Error finding version for package " + packageName, e);
        }

        if (!out.contains("Version: ")) {
            throw new PythonException("Can't find package " + packageName);
        }
        String pkgVersion = out.split("Version: ")[1].split(System.lineSeparator())[0];
        return pkgVersion;
    }

    public static boolean isPackageInstalled(String packageName) {
        try {
            String out = runAndReturn("-m", "pip", "show", packageName);
            return !out.isEmpty();
        } catch (Exception e) {
            throw new PythonException("Error checking if package is installed: " + packageName, e);
        }
    }

    public static void pipInstallFromRequirementsTxt(String path) {
        try {
            run("-m", "pip", "install", "-r", path);
        } catch (Exception e) {
            throw new PythonException("Error installing packages from " + path, e);
        }
    }

    public static void pipInstallFromSetupScript(String path, boolean inplace) {
        try {
            run(path, inplace ? "develop" : "install");
        } catch (Exception e) {
            throw new PythonException("Error installing package from " + path, e);
        }
    }
}

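Editor's note: a hedged usage sketch of the PythonProcess helpers above. The package name is illustrative; it assumes network access and that pip ships with the embedded bytedeco CPython distribution.

    import org.nd4j.python4j.PythonProcess;

    public class PipSketch {
        public static void main(String[] args) {
            // "requests" is only an example package name
            if (!PythonProcess.isPackageInstalled("requests")) {
                PythonProcess.pipInstall("requests", "2.23.0"); // resolves to pip install requests==2.23.0
            }
            System.out.println(PythonProcess.getPackageVersion("requests"));
        }
    }
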
@ -14,9 +14,11 @@
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.eclipse.python4j;
package org.nd4j.python4j;


import java.io.File;

public abstract class PythonType<T> {

    private final String name;

@ -43,5 +45,25 @@ public abstract class PythonType<T> {
        return name;
    }

    @Override
    public boolean equals(Object obj) {
        if (!(obj instanceof PythonType)) {
            return false;
        }
        PythonType other = (PythonType) obj;
        return this.getClass().equals(other.getClass()) && this.name.equals(other.name);
    }

    public PythonObject pythonType() {
        return null;
    }

    public File[] packages() {
        return new File[0];
    }

    public void init() { // not to be called from constructor
    }

}

@ -14,11 +14,18 @@
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.eclipse.python4j;
package org.nd4j.python4j;


import org.bytedeco.cpython.PyObject;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.Pointer;
import sun.nio.ch.DirectBuffer;

import java.lang.reflect.Field;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.*;

import static org.bytedeco.cpython.global.python.*;

@ -28,7 +35,7 @@ public class PythonTypes {


    private static List<PythonType> getPrimitiveTypes() {
        return Arrays.<PythonType>asList(STR, INT, FLOAT, BOOL);
        return Arrays.<PythonType>asList(STR, INT, FLOAT, BOOL, BYTES);
    }

    private static List<PythonType> getCollectionTypes() {

@ -36,8 +43,13 @@ public class PythonTypes {
    }

    private static List<PythonType> getExternalTypes() {
        //TODO service loader
        return new ArrayList<>();
        List<PythonType> ret = new ArrayList<>();
        ServiceLoader<PythonType> sl = ServiceLoader.load(PythonType.class);
        Iterator<PythonType> iter = sl.iterator();
        while (iter.hasNext()) {
            ret.add(iter.next());
        }
        return ret;
    }

    public static List<PythonType> get() {

@ -48,15 +60,17 @@ public class PythonTypes {
        return ret;
    }

    public static PythonType get(String name) {
    public static <T> PythonType<T> get(String name) {
        for (PythonType pt : get()) {
            if (pt.getName().equals(name)) { // TODO use map instead?
                return pt;
            }
        }
        throw new PythonException("Unknown python type: " + name);
    }


    public static PythonType getPythonTypeForJavaObject(Object javaObject) {
        for (PythonType pt : get()) {
            if (pt.accepts(javaObject)) {

@ -66,7 +80,7 @@ public class PythonTypes {
        throw new PythonException("Unable to find python type for java type: " + javaObject.getClass());
    }

    public static PythonType getPythonTypeForPythonObject(PythonObject pythonObject) {
    public static <T> PythonType<T> getPythonTypeForPythonObject(PythonObject pythonObject) {
        PyObject pyType = PyObject_Type(pythonObject.getNativePythonObject());
        try {
            String pyTypeStr = PythonTypes.STR.toJava(new PythonObject(pyType, false));

@ -75,6 +89,14 @@ public class PythonTypes {
                String pyTypeStr2 = "<class '" + pt.getName() + "'>";
                if (pyTypeStr.equals(pyTypeStr2)) {
                    return pt;
                } else {
                    try (PythonGC gc = PythonGC.watch()) {
                        PythonObject pyType2 = pt.pythonType();
                        if (pyType2 != null && Python.isinstance(pythonObject, pyType2)) {
                            return pt;
                        }
                    }
                }
            }
            throw new PythonException("Unable to find converter for python object of type " + pyTypeStr);

@ -212,12 +234,53 @@ public class PythonTypes {

    public static final PythonType<List> LIST = new PythonType<List>("list", List.class) {

        @Override
        public boolean accepts(Object javaObject) {
            return (javaObject instanceof List || javaObject.getClass().isArray());
        }

        @Override
        public List adapt(Object javaObject) {
            if (javaObject instanceof List) {
                return (List) javaObject;
            } else if (javaObject instanceof Object[]) {
                return Arrays.asList((Object[]) javaObject);
            } else if (javaObject.getClass().isArray()) {
                List<Object> ret = new ArrayList<>();
                if (javaObject instanceof Object[]) {
                    Object[] arr = (Object[]) javaObject;
                    return new ArrayList<>(Arrays.asList(arr));
                } else if (javaObject instanceof short[]) {
                    short[] arr = (short[]) javaObject;
                    for (short x : arr) ret.add(x);
                    return ret;
                } else if (javaObject instanceof int[]) {
                    int[] arr = (int[]) javaObject;
                    for (int x : arr) ret.add(x);
                    return ret;
                } else if (javaObject instanceof byte[]) {
                    byte[] arr = (byte[]) javaObject;
                    for (int x : arr) ret.add(x);
                    return ret;
                } else if (javaObject instanceof long[]) {
                    long[] arr = (long[]) javaObject;
                    for (long x : arr) ret.add(x);
                    return ret;
                } else if (javaObject instanceof float[]) {
                    float[] arr = (float[]) javaObject;
                    for (float x : arr) ret.add(x);
                    return ret;
                } else if (javaObject instanceof double[]) {
                    double[] arr = (double[]) javaObject;
                    for (double x : arr) ret.add(x);
                    return ret;
                } else if (javaObject instanceof boolean[]) {
                    boolean[] arr = (boolean[]) javaObject;
                    for (boolean x : arr) ret.add(x);
                    return ret;
                } else {
                    throw new PythonException("Unsupported array type: " + javaObject.getClass().toString());
                }
            } else {
                throw new PythonException("Cannot cast object of type " + javaObject.getClass().getName() + " to List");
            }

@ -327,7 +390,13 @@ public class PythonTypes {
            }
            Object v = javaObject.get(k);
            PythonObject pyVal;
            pyVal = PythonTypes.convert(v);
            if (v instanceof PythonObject) {
                pyVal = (PythonObject) v;
            } else if (v instanceof PyObject) {
                pyVal = new PythonObject((PyObject) v);
            } else {
                pyVal = PythonTypes.convert(v);
            }
            int errCode = PyDict_SetItem(pyDict, pyKey.getNativePythonObject(), pyVal.getNativePythonObject());
            if (errCode != 0) {
                String keyStr = pyKey.toString();

@ -341,4 +410,127 @@ public class PythonTypes {
        return new PythonObject(pyDict);
    }
};


    public static final PythonType<byte[]> BYTES = new PythonType<byte[]>("bytes", byte[].class) {
        @Override
        public byte[] toJava(PythonObject pythonObject) {
            try (PythonGC gc = PythonGC.watch()) {
                if (!(Python.isinstance(pythonObject, Python.bytesType()))) {
                    throw new PythonException("Expected bytes. Received: " + pythonObject);
                }
                PythonObject pySize = Python.len(pythonObject);
                byte[] ret = new byte[pySize.toInt()];
                for (int i = 0; i < ret.length; i++) {
                    ret[i] = (byte) pythonObject.get(i).toInt();
                }
                return ret;
            }
        }

        @Override
        public PythonObject toPython(byte[] javaObject) {
            try (PythonGC gc = PythonGC.watch()) {
                PythonObject ret = Python.bytes(LIST.toPython(LIST.adapt(javaObject)));
                PythonGC.keep(ret);
                return ret;
            }
        }

        @Override
        public boolean accepts(Object javaObject) {
            return javaObject instanceof byte[];
        }

        @Override
        public byte[] adapt(Object javaObject) {
            if (javaObject instanceof byte[]) {
                return (byte[]) javaObject;
            }
            throw new PythonException("Cannot cast object of type " + javaObject.getClass().getName() + " to byte[]");
        }

    };

    /**
     * Crashes on Adopt OpenJDK
     * Use implementation in python4j-numpy instead for zero-copy byte buffers.
     */
//    public static final PythonType<BytePointer> MEMORYVIEW = new PythonType<BytePointer>("memoryview", BytePointer.class) {
//        @Override
//        public BytePointer toJava(PythonObject pythonObject) {
//            try (PythonGC gc = PythonGC.watch()) {
//                if (!(Python.isinstance(pythonObject, Python.memoryviewType()))) {
//                    throw new PythonException("Expected memoryview. Received: " + pythonObject);
//                }
//                PythonObject pySize = Python.len(pythonObject);
//                PythonObject ctypes = Python.importModule("ctypes");
//                PythonObject charType = ctypes.attr("c_char");
//                PythonObject charArrayType = new PythonObject(PyNumber_Multiply(charType.getNativePythonObject(),
//                        pySize.getNativePythonObject()));
//                PythonObject fromBuffer = charArrayType.attr("from_buffer");
//                if (pythonObject.attr("readonly").toBoolean()) {
//                    pythonObject = Python.bytearray(pythonObject);
//                }
//                PythonObject arr = fromBuffer.call(pythonObject);
//                PythonObject cast = ctypes.attr("cast");
//                PythonObject voidPtrType = ctypes.attr("c_void_p");
//                PythonObject voidPtr = cast.call(arr, voidPtrType);
//                long address = voidPtr.attr("value").toLong();
//                long size = pySize.toLong();
//                try {
//                    Field addressField = Buffer.class.getDeclaredField("address");
//                    addressField.setAccessible(true);
//                    Field capacityField = Buffer.class.getDeclaredField("capacity");
//                    capacityField.setAccessible(true);
//                    ByteBuffer buff = ByteBuffer.allocateDirect(0).order(ByteOrder.nativeOrder());
//                    addressField.setLong(buff, address);
//                    capacityField.setInt(buff, (int) size);
//                    BytePointer ret = new BytePointer(buff);
//                    ret.limit(size);
//                    return ret;
//
//                } catch (Exception e) {
//                    throw new RuntimeException(e);
//                }
//
//            }
//        }
//
//        @Override
//        public PythonObject toPython(BytePointer javaObject) {
//            long address = javaObject.address();
//            long size = javaObject.limit();
//            try (PythonGC gc = PythonGC.watch()) {
//                PythonObject ctypes = Python.importModule("ctypes");
//                PythonObject charType = ctypes.attr("c_char");
//                PythonObject pySize = new PythonObject(size);
//                PythonObject charArrayType = new PythonObject(PyNumber_Multiply(charType.getNativePythonObject(),
//                        pySize.getNativePythonObject()));
//                PythonObject fromAddress = charArrayType.attr("from_address");
//                PythonObject arr = fromAddress.call(new PythonObject(address));
//                PythonObject memoryView = Python.memoryview(arr).attr("cast").call("b");
//                PythonGC.keep(memoryView);
//                return memoryView;
//            }
//
//        }
//
//        @Override
//        public boolean accepts(Object javaObject) {
//            return javaObject instanceof Pointer || javaObject instanceof DirectBuffer;
//        }
//
//        @Override
//        public BytePointer adapt(Object javaObject) {
//            if (javaObject instanceof BytePointer) {
//                return (BytePointer) javaObject;
//            } else if (javaObject instanceof Pointer) {
//                return new BytePointer((Pointer) javaObject);
//            } else if (javaObject instanceof DirectBuffer) {
//                return new BytePointer((ByteBuffer) javaObject);
//            } else {
//                throw new PythonException("Cannot cast object of type " + javaObject.getClass().getName() + " to BytePointer");
//            }
//        }
//    };

}

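Editor's note: the ServiceLoader lookup added to getExternalTypes() above means third-party converters can be dropped onto the classpath. Below is a hedged sketch of such a converter; the class and package names are hypothetical, and it assumes the abstract PythonType methods (toJava, toPython, accepts, adapt) shown in this commit. Registration follows the standard ServiceLoader convention, the same one the numpy module uses later in this commit: list the fully qualified class name in META-INF/services/org.nd4j.python4j.PythonType.

    package com.example; // hypothetical package

    import org.nd4j.python4j.PythonObject;
    import org.nd4j.python4j.PythonType;
    import org.nd4j.python4j.PythonTypes;

    // Converts any Java CharSequence via the existing STR converter.
    public class CharSequenceType extends PythonType<CharSequence> {

        public CharSequenceType() {
            super("str", CharSequence.class); // maps onto Python's str on the Python side
        }

        @Override
        public CharSequence toJava(PythonObject pythonObject) {
            return PythonTypes.STR.toJava(pythonObject);
        }

        @Override
        public PythonObject toPython(CharSequence javaObject) {
            return PythonTypes.STR.toPython(javaObject.toString());
        }

        @Override
        public boolean accepts(Object javaObject) {
            return javaObject instanceof CharSequence;
        }

        @Override
        public CharSequence adapt(Object javaObject) {
            return (CharSequence) javaObject;
        }
    }
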
@ -14,7 +14,7 @@
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.eclipse.python4j;
package org.nd4j.python4j;

@lombok.Data
public class PythonVariable<T> {

@ -0,0 +1,47 @@
/*******************************************************************************
 * Copyright (c) 2020 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.nd4j.python4j;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
 * Some syntax sugar for lookup by name
 */
public class PythonVariables extends ArrayList<PythonVariable> {
    public PythonVariable get(String variableName) {
        for (PythonVariable pyVar : this) {
            if (pyVar.getName().equals(variableName)) {
                return pyVar;
            }
        }
        return null;
    }

    public <T> boolean add(String variableName, PythonType<T> variableType, Object value) {
        return this.add(new PythonVariable<>(variableName, variableType, value));
    }

    public PythonVariables(PythonVariable... variables) {
        this(Arrays.asList(variables));
    }

    public PythonVariables(List<PythonVariable> list) {
        super();
        addAll(list);
    }
}

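Editor's note: a hedged sketch of the name-based lookup this class adds. Variable names and values are illustrative; it assumes the PythonVariable constructors used by the tests in this commit and that the primitive converters adapt plain Java numbers.

    import org.nd4j.python4j.PythonTypes;
    import org.nd4j.python4j.PythonVariable;
    import org.nd4j.python4j.PythonVariables;

    public class VariablesSketch {
        public static void main(String[] args) {
            PythonVariables vars = new PythonVariables(
                    new PythonVariable<>("a", PythonTypes.INT, 1),
                    new PythonVariable<>("b", PythonTypes.STR, "hello"));
            vars.add("c", PythonTypes.FLOAT, 2.0);        // name-based add from this commit
            System.out.println(vars.get("b").getValue()); // lookup by name -> "hello"
        }
    }
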
@ -15,9 +15,12 @@
 ******************************************************************************/


import org.eclipse.python4j.*;
import org.junit.Assert;
import org.junit.Test;
import org.nd4j.python4j.PythonContextManager;
import org.nd4j.python4j.PythonExecutioner;
import org.nd4j.python4j.PythonTypes;
import org.nd4j.python4j.PythonVariable;

import javax.annotation.concurrent.NotThreadSafe;
import java.util.*;

@ -15,9 +15,9 @@
 ******************************************************************************/


import org.eclipse.python4j.PythonException;
import org.eclipse.python4j.PythonObject;
import org.eclipse.python4j.PythonTypes;
import org.nd4j.python4j.PythonException;
import org.nd4j.python4j.PythonObject;
import org.nd4j.python4j.PythonTypes;
import org.junit.Assert;
import org.junit.Test;

@ -16,9 +16,9 @@
 ******************************************************************************/


import org.eclipse.python4j.Python;
import org.eclipse.python4j.PythonContextManager;
import org.eclipse.python4j.PythonExecutioner;
import org.nd4j.python4j.Python;
import org.nd4j.python4j.PythonContextManager;
import org.nd4j.python4j.PythonExecutioner;
import org.junit.Assert;
import org.junit.Test;
import javax.annotation.concurrent.NotThreadSafe;

@ -14,9 +14,9 @@
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

import org.eclipse.python4j.Python;
import org.eclipse.python4j.PythonGC;
import org.eclipse.python4j.PythonObject;
import org.nd4j.python4j.Python;
import org.nd4j.python4j.PythonGC;
import org.nd4j.python4j.PythonObject;
import org.junit.Assert;
import org.junit.Test;

@ -49,6 +49,6 @@ public class PythonGCTest {
        PythonObject pyObjCount3 = Python.len(getObjects.call());
        long objCount3 = pyObjCount3.toLong();
        diff = objCount3 - objCount2;
        Assert.assertEquals(2, diff); // 2 objects created during function call
        Assert.assertTrue(diff <= 2); // 2 objects created during function call
    }
}

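Editor's note: the relaxed assertion above checks the PythonGC contract rather than an exact object count. A hedged sketch of that contract, with illustrative objects, assuming the Python.list() helper used by the numpy GC test later in this commit:

    import org.nd4j.python4j.Python;
    import org.nd4j.python4j.PythonGC;
    import org.nd4j.python4j.PythonObject;

    public class GCSketch {
        public static void main(String[] args) {
            PythonObject kept;
            try (PythonGC gc = PythonGC.watch()) {
                PythonObject temp = Python.list(); // reclaimed when the scope closes
                kept = Python.list();
                PythonGC.keep(kept);               // explicitly survives the scope
            }
            System.out.println(kept);              // still a valid Python object
        }
    }
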
@ -14,10 +14,10 @@
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

import org.eclipse.python4j.PythonContextManager;
import org.eclipse.python4j.PythonJob;
import org.eclipse.python4j.PythonTypes;
import org.eclipse.python4j.PythonVariable;
import org.nd4j.python4j.PythonContextManager;
import org.nd4j.python4j.PythonJob;
import org.nd4j.python4j.PythonTypes;
import org.nd4j.python4j.PythonVariable;
import org.junit.Test;

import java.util.ArrayList;

@ -30,7 +30,7 @@ import static org.junit.Assert.assertEquals;
public class PythonJobTest {

    @Test
    public void testPythonJobBasic() throws Exception {
    public void testPythonJobBasic() {
        PythonContextManager.deleteNonMainContexts();

        String code = "c = a + b";

@ -65,7 +65,7 @@ public class PythonJobTest {
    }

    @Test
    public void testPythonJobReturnAllVariables() throws Exception {
    public void testPythonJobReturnAllVariables() {
        PythonContextManager.deleteNonMainContexts();

        String code = "c = a + b";

@ -101,7 +101,7 @@ public class PythonJobTest {


    @Test
    public void testMultiplePythonJobsParallel() throws Exception {
    public void testMultiplePythonJobsParallel() {
        PythonContextManager.deleteNonMainContexts();
        String code1 = "c = a + b";
        PythonJob job1 = new PythonJob("job1", code1, false);

@ -150,7 +150,7 @@ public class PythonJobTest {


    @Test
    public void testPythonJobSetupRun() throws Exception {
    public void testPythonJobSetupRun() {

        PythonContextManager.deleteNonMainContexts();
        String code = "five=None\n" +

@ -189,7 +189,7 @@ public class PythonJobTest {

    }

    @Test
    public void testPythonJobSetupRunAndReturnAllVariables() throws Exception {
    public void testPythonJobSetupRunAndReturnAllVariables() {
        PythonContextManager.deleteNonMainContexts();
        String code = "five=None\n" +
                "c=None\n" +

@ -225,7 +225,7 @@ public class PythonJobTest {
    }

    @Test
    public void testMultiplePythonJobsSetupRunParallel() throws Exception {
    public void testMultiplePythonJobsSetupRunParallel() {
        PythonContextManager.deleteNonMainContexts();

        String code1 = "five=None\n" +

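Editor's note: for orientation, a minimal sketch of the PythonJob lifecycle these tests exercise. The job name, code string, and values are illustrative; the constructor and exec signatures are the ones used by the tests above.

    import org.nd4j.python4j.*;

    import java.util.ArrayList;
    import java.util.List;

    public class JobSketch {
        public static void main(String[] args) {
            // setupRunMode=false: the code string runs as-is on every exec()
            PythonJob job = new PythonJob("demo", "c = a + b", false);

            List<PythonVariable> inputs = new ArrayList<>();
            inputs.add(new PythonVariable<>("a", PythonTypes.INT, 2));
            inputs.add(new PythonVariable<>("b", PythonTypes.INT, 3));

            List<PythonVariable> outputs = new ArrayList<>();
            outputs.add(new PythonVariable<>("c", PythonTypes.INT));

            job.exec(inputs, outputs);
            System.out.println(outputs.get(0).getValue()); // 5
        }
    }
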
@ -14,10 +14,9 @@
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

import org.eclipse.python4j.*;
import org.nd4j.python4j.*;
import org.junit.Assert;
import org.junit.Test;

import javax.annotation.concurrent.NotThreadSafe;
import java.util.ArrayList;
import java.util.Arrays;

@ -15,12 +15,13 @@
 ******************************************************************************/


import org.eclipse.python4j.PythonException;
import org.eclipse.python4j.PythonObject;
import org.eclipse.python4j.PythonTypes;
import org.nd4j.python4j.*;
import org.junit.Assert;
import org.junit.Test;

import java.util.ArrayList;
import java.util.List;

public class PythonPrimitiveTypesTest {

    @Test

@ -78,5 +79,18 @@ public class PythonPrimitiveTypesTest {

        Assert.assertEquals(b, b3);
    }

    @Test
    public void testBytes() {
        byte[] bytes = new byte[]{97, 98, 99};
        List<PythonVariable> inputs = new ArrayList<>();
        inputs.add(new PythonVariable<>("buff", PythonTypes.BYTES, bytes));
        List<PythonVariable> outputs = new ArrayList<>();
        outputs.add(new PythonVariable<>("s1", PythonTypes.STR));
        outputs.add(new PythonVariable<>("buff2", PythonTypes.BYTES));
        String code = "s1 = ''.join(chr(c) for c in buff)\nbuff2=b'def'";
        PythonExecutioner.exec(code, inputs, outputs);
        Assert.assertEquals("abc", outputs.get(0).getValue());
        Assert.assertArrayEquals(new byte[]{100, 101, 102}, (byte[]) outputs.get(1).getValue());
    }

}

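Editor's note: a hedged sketch of the new BYTES converter used directly, without going through PythonExecutioner. Values are illustrative; it assumes GIL management is handled by the caller, as elsewhere in this commit.

    import org.nd4j.python4j.PythonGIL;
    import org.nd4j.python4j.PythonObject;
    import org.nd4j.python4j.PythonTypes;

    public class BytesSketch {
        public static void main(String[] args) {
            try (PythonGIL gil = PythonGIL.lock()) {
                byte[] data = {100, 101, 102};
                PythonObject py = PythonTypes.BYTES.toPython(data); // Python b'def'
                byte[] back = PythonTypes.BYTES.toJava(py);         // element-wise copy back
                System.out.println(new String(back));               // def
            }
        }
    }
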
@ -4,7 +4,7 @@
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <parent>
        <artifactId>python4j-parent</artifactId>
        <groupId>org.eclipse</groupId>
        <groupId>org.nd4j</groupId>
        <version>1.0.0-SNAPSHOT</version>
    </parent>
    <modelVersion>4.0.0</modelVersion>

@ -28,15 +28,50 @@
            <version>${nd4j.version}</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>python4j-core</artifactId>
            <version>1.0.0-SNAPSHOT</version>
        </dependency>
    </dependencies>

    <profiles>
        <profile>
            <id>test-nd4j-native</id>
            <dependencies>
                <dependency>
                    <groupId>org.nd4j</groupId>
                    <artifactId>nd4j-native</artifactId>
                    <version>${nd4j.version}</version>
                    <scope>test</scope>
                </dependency>
                <dependency>
                    <groupId>org.deeplearning4j</groupId>
                    <artifactId>dl4j-test-resources</artifactId>
                    <version>${nd4j.version}</version>
                    <scope>test</scope>
                </dependency>
            </dependencies>
        </profile>

        <profile>
            <id>test-nd4j-cuda-10.2</id>
            <dependencies>
                <dependency>
                    <groupId>org.nd4j</groupId>
                    <artifactId>nd4j-cuda-10.2</artifactId>
                    <version>${nd4j.version}</version>
                    <scope>test</scope>
                </dependency>
                <dependency>
                    <groupId>org.deeplearning4j</groupId>
                    <artifactId>dl4j-test-resources</artifactId>
                    <version>${nd4j.version}</version>
                    <scope>test</scope>
                </dependency>
            </dependencies>
        </profile>
    </profiles>


</project>

@ -0,0 +1,299 @@
/*******************************************************************************
 * Copyright (c) 2020 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/


package org.nd4j.python4j;

import lombok.extern.slf4j.Slf4j;
import org.bytedeco.cpython.PyObject;
import org.bytedeco.cpython.PyTypeObject;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.SizeTPointer;
import org.bytedeco.numpy.PyArrayObject;
import org.bytedeco.numpy.global.numpy;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.nativeblas.NativeOpsHolder;

import java.io.File;
import java.util.*;
import java.util.concurrent.atomic.AtomicBoolean;

import static org.bytedeco.cpython.global.python.*;
import static org.bytedeco.cpython.global.python.Py_DecRef;
import static org.bytedeco.numpy.global.numpy.*;
import static org.bytedeco.numpy.global.numpy.NPY_ARRAY_CARRAY;
import static org.bytedeco.numpy.global.numpy.PyArray_Type;

@Slf4j
public class NumpyArray extends PythonType<INDArray> {

    public static final NumpyArray INSTANCE;
    private static final AtomicBoolean init = new AtomicBoolean(false);
    private static final Map<String, DataBuffer> cache = new HashMap<>();

    static {
        new PythonExecutioner();
        INSTANCE = new NumpyArray();
    }

    @Override
    public File[] packages() {
        try {
            return new File[]{numpy.cachePackage()};
        } catch (Exception e) {
            throw new PythonException(e);
        }
    }

    public synchronized void init() {
        if (init.get()) return;
        init.set(true);
        if (PythonGIL.locked()) {
            throw new PythonException("Cannot initialize numpy - GIL already acquired.");
        }
        int err = numpy._import_array();
        if (err < 0) {
            System.out.println("Numpy import failed!");
            throw new PythonException("Numpy import failed!");
        }
    }

    public NumpyArray() {
        super("numpy.ndarray", INDArray.class);
    }

    @Override
    public INDArray toJava(PythonObject pythonObject) {
        log.info("Converting PythonObject to INDArray...");
        PyObject np = PyImport_ImportModule("numpy");
        PyObject ndarray = PyObject_GetAttrString(np, "ndarray");
        if (PyObject_IsInstance(pythonObject.getNativePythonObject(), ndarray) != 1) {
            Py_DecRef(ndarray);
            Py_DecRef(np);
            throw new PythonException("Object is not a numpy array! Use Python.ndarray() to convert object to a numpy array.");
        }
        Py_DecRef(ndarray);
        Py_DecRef(np);
        PyArrayObject npArr = new PyArrayObject(pythonObject.getNativePythonObject());
        long[] shape = new long[PyArray_NDIM(npArr)];
        SizeTPointer shapePtr = PyArray_SHAPE(npArr);
        if (shapePtr != null)
            shapePtr.get(shape, 0, shape.length);
        long[] strides = new long[shape.length];
        SizeTPointer stridesPtr = PyArray_STRIDES(npArr);
        if (stridesPtr != null)
            stridesPtr.get(strides, 0, strides.length);
        int npdtype = PyArray_TYPE(npArr);

        DataType dtype;
        switch (npdtype) {
            case NPY_DOUBLE:
                dtype = DataType.DOUBLE;
                break;
            case NPY_FLOAT:
                dtype = DataType.FLOAT;
                break;
            case NPY_SHORT:
                dtype = DataType.SHORT;
                break;
            case NPY_INT:
                dtype = DataType.INT32;
                break;
            case NPY_LONG:
                dtype = DataType.INT64;
                break;
            case NPY_UINT:
                dtype = DataType.UINT32;
                break;
            case NPY_BYTE:
                dtype = DataType.INT8;
                break;
            case NPY_UBYTE:
                dtype = DataType.UINT8;
                break;
            case NPY_BOOL:
                dtype = DataType.BOOL;
                break;
            case NPY_HALF:
                dtype = DataType.FLOAT16;
                break;
            case NPY_LONGLONG:
                dtype = DataType.INT64;
                break;
            case NPY_USHORT:
                dtype = DataType.UINT16;
                break;
            case NPY_ULONG:
            case NPY_ULONGLONG:
                dtype = DataType.UINT64;
                break;
            default:
                throw new PythonException("Unsupported array data type: " + npdtype);
        }
        long size = 1;
        for (int i = 0; i < shape.length; size *= shape[i++]) ;

        INDArray ret;
        long address = PyArray_DATA(npArr).address();
        String key = address + "_" + size + "_" + dtype;
        DataBuffer buff = cache.get(key);
        if (buff == null) {
            try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
                Pointer ptr = NativeOpsHolder.getInstance().getDeviceNativeOps().pointerForAddress(address);
                ptr = ptr.limit(size);
                ptr = ptr.capacity(size);
                buff = Nd4j.createBuffer(ptr, size, dtype);
                cache.put(key, buff);
            }
        }
        int elemSize = buff.getElementSize();
        long[] nd4jStrides = new long[strides.length];
        for (int i = 0; i < strides.length; i++) {
            nd4jStrides[i] = strides[i] / elemSize;
        }
        ret = Nd4j.create(buff, shape, nd4jStrides, 0, Shape.getOrder(shape, nd4jStrides, 1), dtype);
        Nd4j.getAffinityManager().tagLocation(ret, AffinityManager.Location.HOST);
        log.info("Done.");
        return ret;
    }

    @Override
    public PythonObject toPython(INDArray indArray) {
        log.info("Converting INDArray to PythonObject...");
        DataType dataType = indArray.dataType();
        DataBuffer buff = indArray.data();
        String key = buff.pointer().address() + "_" + buff.length() + "_" + dataType;
        cache.put(key, buff);
        int numpyType;
        String ctype;
        switch (dataType) {
            case DOUBLE:
                numpyType = NPY_DOUBLE;
                ctype = "c_double";
                break;
            case FLOAT:
            case BFLOAT16:
                numpyType = NPY_FLOAT;
                ctype = "c_float";
                break;
            case SHORT:
                numpyType = NPY_SHORT;
                ctype = "c_short";
                break;
            case INT:
                numpyType = NPY_INT;
                ctype = "c_int";
                break;
            case LONG:
                numpyType = NPY_INT64;
                ctype = "c_int64";
                break;
            case UINT16:
                numpyType = NPY_USHORT;
                ctype = "c_uint16";
                break;
            case UINT32:
                numpyType = NPY_UINT;
                ctype = "c_uint";
                break;
            case UINT64:
                numpyType = NPY_UINT64;
                ctype = "c_uint64";
                break;
            case BOOL:
                numpyType = NPY_BOOL;
                ctype = "c_bool";
                break;
            case BYTE:
                numpyType = NPY_BYTE;
                ctype = "c_byte";
                break;
            case UBYTE:
                numpyType = NPY_UBYTE;
                ctype = "c_ubyte";
                break;
            case HALF:
                numpyType = NPY_HALF;
                ctype = "c_short";
                break;
            default:
                throw new RuntimeException("Unsupported dtype: " + dataType);
        }

        long[] shape = indArray.shape();
        INDArray inputArray = indArray;
        if (dataType == DataType.BFLOAT16) {
            log.warn("Creating copy of array as bfloat16 is not supported by numpy.");
            inputArray = indArray.castTo(DataType.FLOAT);
        }

        //Sync to host memory in the case of CUDA, before passing the host memory pointer to Python
        Nd4j.getAffinityManager().ensureLocation(inputArray, AffinityManager.Location.HOST);

        // PyArray_Type() call causes jvm crash in linux cpu if GIL is acquired by non main thread.
        // Using Interpreter for now:

//        try (PythonContextManager.Context context = new PythonContextManager.Context("__np_array_converter")) {
//            log.info("Starting exec...");
//            String code = "import ctypes\nimport numpy as np\n" +
//                    "cArr = (ctypes." + ctype + "*" + indArray.length() + ")" +
//                    ".from_address(" + indArray.data().pointer().address() + ")\n" +
//                    "npArr = np.frombuffer(cArr, dtype=" + ((numpyType == NPY_HALF) ? "'half'" : "ctypes." + ctype) +
//                    ").reshape(" + Arrays.toString(indArray.shape()) + ")";
//            PythonExecutioner.exec(code);
//            log.info("exec done.");
//            PythonObject ret = PythonExecutioner.getVariable("npArr");
//            Py_IncRef(ret.getNativePythonObject());
//            return ret;
//
//        }
        log.info("NUMPY: PyArray_Type()");
        PyTypeObject pyTypeObject = PyArray_Type();


        log.info("NUMPY: PyArray_New()");
        PyObject npArr = PyArray_New(pyTypeObject, shape.length, new SizeTPointer(shape),
                numpyType, null,
                inputArray.data().addressPointer(),
                0, NPY_ARRAY_CARRAY, null);
        log.info("Done.");
        return new PythonObject(npArr);
    }

    @Override
    public boolean accepts(Object javaObject) {
        return javaObject instanceof INDArray;
    }

    @Override
    public INDArray adapt(Object javaObject) {
        if (javaObject instanceof INDArray) {
            return (INDArray) javaObject;
        }
        throw new PythonException("Cannot cast object of type " + javaObject.getClass().getName() + " to INDArray");
    }
}

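Editor's note: a sketch of what the converter above enables: handing an INDArray to Python as a numpy array that shares the same buffer, so in-place Python mutations are visible on the Java side. Shape and values are illustrative; it mirrors the inplace-execution test later in this commit.

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.python4j.*;

    import java.util.ArrayList;
    import java.util.List;

    public class NumpySketch {
        public static void main(String[] args) {
            INDArray x = Nd4j.ones(DataType.FLOAT, 2, 3);

            List<PythonVariable> inputs = new ArrayList<>();
            inputs.add(new PythonVariable<>("x", NumpyArray.INSTANCE, x));
            List<PythonVariable> outputs = new ArrayList<>();
            outputs.add(new PythonVariable<>("x", NumpyArray.INSTANCE));

            // "x *= 2" mutates the shared buffer: x itself changes, no copy is made
            PythonExecutioner.exec("x *= 2", inputs, outputs);
            System.out.println(x); // all values are now 2.0
        }
    }
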
@ -0,0 +1 @@
org.nd4j.python4j.NumpyArray

@ -0,0 +1,169 @@
/*******************************************************************************
 * Copyright (c) 2020 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/


import org.nd4j.python4j.*;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.nativeblas.OpaqueDataBuffer;

import javax.annotation.concurrent.NotThreadSafe;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;

@NotThreadSafe
@RunWith(Parameterized.class)
public class PythonNumpyBasicTest {
    private DataType dataType;
    private long[] shape;

    public PythonNumpyBasicTest(DataType dataType, long[] shape, String dummyArg) {
        this.dataType = dataType;
        this.shape = shape;
    }

    @Parameterized.Parameters(name = "{index}: Testing with DataType={0}, shape={2}")
    public static Collection params() {
        DataType[] types = new DataType[] {
                DataType.BOOL,
                DataType.FLOAT16,
                DataType.BFLOAT16,
                DataType.FLOAT,
                DataType.DOUBLE,
                DataType.INT8,
                DataType.INT16,
                DataType.INT32,
                DataType.INT64,
                DataType.UINT8,
                DataType.UINT16,
                DataType.UINT32,
                DataType.UINT64
        };

        long[][] shapes = new long[][]{
                new long[]{2, 3},
                new long[]{3},
                new long[]{1},
                new long[]{} // scalar
        };


        List<Object[]> ret = new ArrayList<>();
        for (DataType type : types) {
            for (long[] shape : shapes) {
                ret.add(new Object[]{type, shape, Arrays.toString(shape)});
            }
        }
        return ret;
    }

    @Test
    public void testConversion() {
        INDArray arr = Nd4j.zeros(dataType, shape);
        PythonObject npArr = PythonTypes.convert(arr);
        INDArray arr2 = PythonTypes.<INDArray>getPythonTypeForPythonObject(npArr).toJava(npArr);
        if (dataType == DataType.BFLOAT16) {
            arr = arr.castTo(DataType.FLOAT);
        }
        Assert.assertEquals(arr, arr2);
    }


    @Test
    public void testExecution() {
        List<PythonVariable> inputs = new ArrayList<>();
        INDArray x = Nd4j.ones(dataType, shape);
        INDArray y = Nd4j.zeros(dataType, shape);
        INDArray z = (dataType == DataType.BOOL) ? x : x.mul(y.add(2));
        z = (dataType == DataType.BFLOAT16) ? z.castTo(DataType.FLOAT) : z;
        PythonType<INDArray> arrType = PythonTypes.get("numpy.ndarray");
        inputs.add(new PythonVariable<>("x", arrType, x));
        inputs.add(new PythonVariable<>("y", arrType, y));
        List<PythonVariable> outputs = new ArrayList<>();
        PythonVariable<INDArray> output = new PythonVariable<>("z", arrType);
        outputs.add(output);
        String code = (dataType == DataType.BOOL) ? "z = x" : "z = x * (y + 2)";
        if (shape.length == 0) { // scalar special case
            code += "\nimport numpy as np\nz = np.asarray(float(z), dtype=x.dtype)";
        }
        PythonExecutioner.exec(code, inputs, outputs);
        INDArray z2 = output.getValue();

        Assert.assertEquals(z.dataType(), z2.dataType());
        Assert.assertEquals(z, z2);

    }


    @Test
    public void testInplaceExecution() {
        if (dataType == DataType.BOOL || dataType == DataType.BFLOAT16) return;
        if (shape.length == 0) return;
        List<PythonVariable> inputs = new ArrayList<>();
        INDArray x = Nd4j.ones(dataType, shape);
        INDArray y = Nd4j.zeros(dataType, shape);
        INDArray z = x.mul(y.add(2));
//        Nd4j.getAffinityManager().ensureLocation(z, AffinityManager.Location.HOST);
        PythonType<INDArray> arrType = PythonTypes.get("numpy.ndarray");
        inputs.add(new PythonVariable<>("x", arrType, x));
        inputs.add(new PythonVariable<>("y", arrType, y));
        List<PythonVariable> outputs = new ArrayList<>();
        PythonVariable<INDArray> output = new PythonVariable<>("x", arrType);
        outputs.add(output);
        String code = "x *= y + 2";
        PythonExecutioner.exec(code, inputs, outputs);
        INDArray z2 = output.getValue();
        Assert.assertEquals(x.dataType(), z2.dataType());
        Assert.assertEquals(z.dataType(), z2.dataType());
        Assert.assertEquals(x, z2);
        Assert.assertEquals(z, z2);
        Assert.assertEquals(x.data().pointer().address(), z2.data().pointer().address());
        if ("CUDA".equalsIgnoreCase(Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"))) {
            Assert.assertEquals(getDeviceAddress(x), getDeviceAddress(z2));
        }

    }

    private static long getDeviceAddress(INDArray array) {
        if (!"CUDA".equalsIgnoreCase(Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"))) {
            throw new IllegalStateException("Cannot get device pointer for non-CUDA device");
        }

        //Use reflection here as OpaqueDataBuffer is only available on BaseCudaDataBuffer and BaseCpuDataBuffer - not DataBuffer/BaseDataBuffer
        // due to it being defined in nd4j-native-api, not nd4j-api
        try {
            Class<?> c = Class.forName("org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer");
            Method m = c.getMethod("getOpaqueDataBuffer");
            OpaqueDataBuffer db = (OpaqueDataBuffer) m.invoke(array.data());
            long address = db.specialBuffer().address();
            return address;
        } catch (Throwable t) {
            throw new RuntimeException("Error getting OpaqueDataBuffer", t);
        }
    }

}

@ -0,0 +1,96 @@
/*******************************************************************************
 * Copyright (c) 2020 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/


import org.nd4j.python4j.PythonException;
import org.nd4j.python4j.PythonObject;
import org.nd4j.python4j.PythonTypes;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;

import javax.annotation.concurrent.NotThreadSafe;
import java.util.*;


@NotThreadSafe
@RunWith(Parameterized.class)
public class PythonNumpyCollectionsTest {
    private DataType dataType;

    public PythonNumpyCollectionsTest(DataType dataType) {
        this.dataType = dataType;
    }

    @Parameterized.Parameters(name = "{index}: Testing with DataType={0}")
    public static DataType[] params() {
        return new DataType[]{
                DataType.BOOL,
                DataType.FLOAT16,
                //DataType.BFLOAT16,
                DataType.FLOAT,
                DataType.DOUBLE,
                DataType.INT8,
                DataType.INT16,
                DataType.INT32,
                DataType.INT64,
                DataType.UINT8,
                DataType.UINT16,
                DataType.UINT32,
                DataType.UINT64
        };
    }

    @Test
    public void testPythonDictFromMap() throws PythonException {
        Map map = new HashMap();
        map.put("a", 1);
        map.put(1, "a");
        map.put("arr", Nd4j.ones(dataType, 2, 3));
        map.put("list1", Arrays.asList(1, 2.0, 3, 4f, Nd4j.zeros(dataType, 3, 2)));
        Map innerMap = new HashMap();
        innerMap.put("b", 2);
        innerMap.put(2, "b");
        innerMap.put(5, Nd4j.ones(dataType, 5));
        map.put("innermap", innerMap);
        map.put("list2", Arrays.asList(4, "5", innerMap, false, true));
        PythonObject dict = PythonTypes.convert(map);
        Map map2 = PythonTypes.DICT.toJava(dict);
        Assert.assertEquals(map.toString(), map2.toString());
    }

    @Test
    public void testPythonListFromList() throws PythonException {
        List<Object> list = new ArrayList<>();
        list.add(1);
        list.add("2");
        list.add(Nd4j.ones(dataType, 2, 3));
        list.add(Arrays.asList("a",
                Nd4j.ones(dataType, 1, 2), 1.0, 2f, 10, true, false,
                Nd4j.zeros(dataType, 3, 2)));
        Map map = new HashMap();
        map.put("a", 1);
        map.put(1, "a");
        map.put(5, Nd4j.ones(dataType, 4, 5));
        map.put("list1", Arrays.asList(1, 2.0, 3, 4f, Nd4j.zeros(dataType, 3, 1)));
        list.add(map);
        PythonObject dict = PythonTypes.convert(list);
        List list2 = PythonTypes.LIST.toJava(dict);
        Assert.assertEquals(list.toString(), list2.toString());
    }
}

@ -0,0 +1,55 @@
/*******************************************************************************
 * Copyright (c) 2020 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

import org.nd4j.python4j.Python;
import org.nd4j.python4j.PythonGC;
import org.nd4j.python4j.PythonObject;
import org.junit.Assert;
import org.junit.Test;
import org.nd4j.linalg.factory.Nd4j;

import javax.annotation.concurrent.NotThreadSafe;


@NotThreadSafe
public class PythonNumpyGCTest {

    @Test
    public void testGC() {
        PythonObject gcModule = Python.importModule("gc");
        PythonObject getObjects = gcModule.attr("get_objects");
        PythonObject pyObjCount1 = Python.len(getObjects.call());
        long objCount1 = pyObjCount1.toLong();
        PythonObject pyList = Python.list();
        pyList.attr("append").call(new PythonObject(Nd4j.linspace(1, 10, 10)));
        pyList.attr("append").call(1.0);
        pyList.attr("append").call(true);
        PythonObject pyObjCount2 = Python.len(getObjects.call());
        long objCount2 = pyObjCount2.toLong();
        long diff = objCount2 - objCount1;
        Assert.assertTrue(diff > 2);
        try (PythonGC gc = PythonGC.watch()) {
            PythonObject pyList2 = Python.list();
            pyList2.attr("append").call(new PythonObject(Nd4j.linspace(1, 10, 10)));
            pyList2.attr("append").call(1.0);
            pyList2.attr("append").call(true);
        }
        PythonObject pyObjCount3 = Python.len(getObjects.call());
        long objCount3 = pyObjCount3.toLong();
        diff = objCount3 - objCount2;
        Assert.assertTrue(diff <= 2); // 2 objects created during function call
    }
}

@ -0,0 +1,22 @@
import org.nd4j.python4j.NumpyArray;
import org.nd4j.python4j.Python;
import org.nd4j.python4j.PythonGC;
import org.nd4j.python4j.PythonObject;
import org.junit.Assert;
import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class PythonNumpyImportTest {

    @Test
    public void testNumpyImport() {
        try (PythonGC gc = PythonGC.watch()) {
            PythonObject np = Python.importModule("numpy");
            PythonObject zeros = np.attr("zeros").call(5);
            INDArray arr = NumpyArray.INSTANCE.toJava(zeros);
            Assert.assertEquals(arr, Nd4j.zeros(DataType.DOUBLE, 5));
        }
    }
}

@ -0,0 +1,303 @@
/*******************************************************************************
 * Copyright (c) 2020 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.python4j.*;

import java.util.ArrayList;
import java.util.List;

import static org.junit.Assert.assertEquals;


@javax.annotation.concurrent.NotThreadSafe
@RunWith(Parameterized.class)
public class PythonNumpyJobTest {
    private DataType dataType;

    public PythonNumpyJobTest(DataType dataType) {
        this.dataType = dataType;
    }

    @Parameterized.Parameters(name = "{index}: Testing with DataType={0}")
    public static DataType[] params() {
        return new DataType[]{
                DataType.BOOL,
                DataType.FLOAT16,
                DataType.BFLOAT16,
                DataType.FLOAT,
                DataType.DOUBLE,
                DataType.INT8,
                DataType.INT16,
                DataType.INT32,
                DataType.INT64,
                DataType.UINT8,
                DataType.UINT16,
                DataType.UINT32,
                DataType.UINT64
        };
    }

    @Test
    public void testNumpyJobBasic() {
        PythonContextManager.deleteNonMainContexts();
        List<PythonVariable> inputs = new ArrayList<>();
        INDArray x = Nd4j.ones(dataType, 2, 3);
        INDArray y = Nd4j.zeros(dataType, 2, 3);
        INDArray z = (dataType == DataType.BOOL) ? x : x.mul(y.add(2));
        z = (dataType == DataType.BFLOAT16) ? z.castTo(DataType.FLOAT) : z;
        PythonType<INDArray> arrType = PythonTypes.get("numpy.ndarray");
        inputs.add(new PythonVariable<>("x", arrType, x));
        inputs.add(new PythonVariable<>("y", arrType, y));
        List<PythonVariable> outputs = new ArrayList<>();
        PythonVariable<INDArray> output = new PythonVariable<>("z", arrType);
        outputs.add(output);
        String code = (dataType == DataType.BOOL) ? "z = x" : "z = x * (y + 2)";

        PythonJob job = new PythonJob("job1", code, false);

        job.exec(inputs, outputs);

        INDArray z2 = output.getValue();

        if (dataType == DataType.BFLOAT16) {
            z2 = z2.castTo(DataType.FLOAT);
        }

        Assert.assertEquals(z, z2);

    }

    @Test
    public void testNumpyJobReturnAllVariables() {
        PythonContextManager.deleteNonMainContexts();
        List<PythonVariable> inputs = new ArrayList<>();
        INDArray x = Nd4j.ones(dataType, 2, 3);
        INDArray y = Nd4j.zeros(dataType, 2, 3);
        INDArray z = (dataType == DataType.BOOL) ? x : x.mul(y.add(2));
        PythonType<INDArray> arrType = PythonTypes.get("numpy.ndarray");
        inputs.add(new PythonVariable<>("x", arrType, x));
        inputs.add(new PythonVariable<>("y", arrType, y));
        String code = (dataType == DataType.BOOL) ? "z = x" : "z = x * (y + 2)";

        PythonJob job = new PythonJob("job1", code, false);
        List<PythonVariable> outputs = job.execAndReturnAllVariables(inputs);

        INDArray x2 = (INDArray) outputs.get(0).getValue();
        INDArray y2 = (INDArray) outputs.get(1).getValue();
        INDArray z2 = (INDArray) outputs.get(2).getValue();

        if (dataType == DataType.BFLOAT16) {
            x = x.castTo(DataType.FLOAT);
            y = y.castTo(DataType.FLOAT);
            z = z.castTo(DataType.FLOAT);
        }
        Assert.assertEquals(x, x2);
        Assert.assertEquals(y, y2);
        Assert.assertEquals(z, z2);

    }


    @Test
    public void testMultipleNumpyJobsParallel() {
        PythonContextManager.deleteNonMainContexts();
        String code1 = (dataType == DataType.BOOL) ? "z = x" : "z = x + y";
        PythonJob job1 = new PythonJob("job1", code1, false);

        String code2 = (dataType == DataType.BOOL) ? "z = y" : "z = x - y";
        PythonJob job2 = new PythonJob("job2", code2, false);

        List<PythonVariable> inputs = new ArrayList<>();
        INDArray x = Nd4j.ones(dataType, 2, 3);
        INDArray y = Nd4j.zeros(dataType, 2, 3);
        INDArray z1 = (dataType == DataType.BOOL) ? x : x.add(y);
        z1 = (dataType == DataType.BFLOAT16) ? z1.castTo(DataType.FLOAT) : z1;
        INDArray z2 = (dataType == DataType.BOOL) ? y : x.sub(y);
        z2 = (dataType == DataType.BFLOAT16) ? z2.castTo(DataType.FLOAT) : z2;
        PythonType<INDArray> arrType = PythonTypes.get("numpy.ndarray");
        inputs.add(new PythonVariable<>("x", arrType, x));
        inputs.add(new PythonVariable<>("y", arrType, y));


        List<PythonVariable> outputs = new ArrayList<>();

        outputs.add(new PythonVariable<>("z", arrType));

        job1.exec(inputs, outputs);

        assertEquals(z1, outputs.get(0).getValue());


        job2.exec(inputs, outputs);

        assertEquals(z2, outputs.get(0).getValue());

    }


    @Test
    public synchronized void testNumpyJobSetupRun() {
        if (dataType == DataType.BOOL) return;
        PythonContextManager.deleteNonMainContexts();
        String code = "five=None\n" +
                "def setup():\n" +
                "    global five\n" +
                "    five = 5\n\n" +
                "def run(a, b):\n" +
                "    c = a + b + five\n" +
                "    return {'c':c}\n\n";

        PythonJob job = new PythonJob("job1", code, true);

        List<PythonVariable> inputs = new ArrayList<>();
        inputs.add(new PythonVariable<>("a", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(2)));
        inputs.add(new PythonVariable<>("b", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3)));

        List<PythonVariable> outputs = new ArrayList<>();
        outputs.add(new PythonVariable<>("c", NumpyArray.INSTANCE));
        job.exec(inputs, outputs);


        assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16) ? DataType.FLOAT : dataType, 2, 3).mul(10),
                outputs.get(0).getValue());


        inputs = new ArrayList<>();
        inputs.add(new PythonVariable<>("a", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3)));
        inputs.add(new PythonVariable<>("b", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(4)));


        outputs = new ArrayList<>();
        outputs.add(new PythonVariable<>("c", NumpyArray.INSTANCE));

        job.exec(inputs, outputs);

        assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16) ? DataType.FLOAT : dataType, 2, 3).mul(12),
                outputs.get(0).getValue());


    }

    @Test
    public void testNumpyJobSetupRunAndReturnAllVariables() {
        if (dataType == DataType.BOOL) return;
        PythonContextManager.deleteNonMainContexts();
        String code = "five=None\n" +
                "c=None\n" +
                "def setup():\n" +
                "    global five\n" +
                "    five = 5\n\n" +
                "def run(a, b):\n" +
                "    global c\n" +
                "    c = a + b + five\n";
        PythonJob job = new PythonJob("job1", code, true);

        List<PythonVariable> inputs = new ArrayList<>();
        inputs.add(new PythonVariable<>("a", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(2)));
        inputs.add(new PythonVariable<>("b", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3)));

        List<PythonVariable> outputs = job.execAndReturnAllVariables(inputs);

        assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16) ? DataType.FLOAT : dataType, 2, 3).mul(10),
                outputs.get(1).getValue());


        inputs = new ArrayList<>();
        inputs.add(new PythonVariable<>("a", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3)));
        inputs.add(new PythonVariable<>("b", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(4)));


        outputs = job.execAndReturnAllVariables(inputs);


        assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16) ? DataType.FLOAT : dataType, 2, 3).mul(12),
                outputs.get(1).getValue());


    }

    @Test
    public void testMultipleNumpyJobsSetupRunParallel(){
|
||||
if (dataType == DataType.BOOL)return;
|
||||
PythonContextManager.deleteNonMainContexts();
|
||||
|
||||
String code1 = "five=None\n" +
|
||||
"def setup():\n" +
|
||||
" global five\n"+
|
||||
" five = 5\n\n" +
|
||||
"def run(a, b):\n" +
|
||||
" c = a + b + five\n"+
|
||||
" return {'c':c}\n\n";
|
||||
PythonJob job1 = new PythonJob("job1", code1, true);
|
||||
|
||||
String code2 = "five=None\n" +
|
||||
"def setup():\n" +
|
||||
" global five\n"+
|
||||
" five = 5\n\n" +
|
||||
"def run(a, b):\n" +
|
||||
" c = a + b - five\n"+
|
||||
" return {'c':c}\n\n";
|
||||
PythonJob job2 = new PythonJob("job2", code2, true);
|
||||
|
||||
List<PythonVariable> inputs = new ArrayList<>();
|
||||
inputs.add(new PythonVariable<>("a", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(2)));
|
||||
inputs.add(new PythonVariable<>("b", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3)));
|
||||
|
||||
|
||||
List<PythonVariable> outputs = new ArrayList<>();
|
||||
outputs.add(new PythonVariable<>("c", NumpyArray.INSTANCE));
|
||||
|
||||
job1.exec(inputs, outputs);
|
||||
|
||||
assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16)? DataType.FLOAT: dataType, 2, 3).mul(10),
|
||||
outputs.get(0).getValue());
|
||||
|
||||
|
||||
job2.exec(inputs, outputs);
|
||||
|
||||
assertEquals(Nd4j.zeros((dataType == DataType.BFLOAT16)? DataType.FLOAT: dataType, 2, 3),
|
||||
outputs.get(0).getValue());
|
||||
|
||||
|
||||
inputs = new ArrayList<>();
|
||||
inputs.add(new PythonVariable<>("a", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3)));
|
||||
inputs.add(new PythonVariable<>("b", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(4)));
|
||||
|
||||
outputs = new ArrayList<>();
|
||||
outputs.add(new PythonVariable<>("c", NumpyArray.INSTANCE));
|
||||
|
||||
|
||||
job1.exec(inputs, outputs);
|
||||
|
||||
assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16)? DataType.FLOAT: dataType, 2, 3).mul(12),
|
||||
outputs.get(0).getValue());
|
||||
|
||||
|
||||
job2.exec(inputs, outputs);
|
||||
|
||||
assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16)? DataType.FLOAT: dataType, 2, 3).mul(2),
|
||||
outputs.get(0).getValue());
|
||||
|
||||
|
||||
}
|
||||
|
||||
}
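The two constructor modes exercised above differ only in their execution contract. A minimal usage sketch, distilled from the tests (the class and variable names here are illustrative, not part of the python4j API):

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.python4j.*;

    import java.util.ArrayList;
    import java.util.List;

    public class PythonJobUsageSketch {
        public static void main(String[] args) {
            // Plain mode (third ctor argument false): the code string runs once per
            // exec(), with the declared inputs bound as Python globals.
            PythonJob plain = new PythonJob("plainJob", "z = x * 2", false);

            // Setup/run mode (third ctor argument true): setup() is called once per
            // job, run(**inputs) on every exec(); the dict run() returns (or the
            // globals it sets) supplies the declared outputs.
            PythonJob setupRun = new PythonJob("setupRunJob",
                    "def setup():\n" +
                    "    pass\n\n" +
                    "def run(x):\n" +
                    "    return {'z': x * 2}\n", true);

            List<PythonVariable> inputs = new ArrayList<>();
            inputs.add(new PythonVariable<>("x", NumpyArray.INSTANCE, Nd4j.ones(DataType.FLOAT, 2, 3)));
            List<PythonVariable> outputs = new ArrayList<>();
            outputs.add(new PythonVariable<>("z", NumpyArray.INSTANCE));

            plain.exec(inputs, outputs);     // z is read back from the job's globals
            setupRun.exec(inputs, outputs);  // z is read back from run()'s return value
        }
    }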
@@ -0,0 +1,194 @@
/*******************************************************************************
 * Copyright (c) 2020 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

import org.nd4j.python4j.*;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import javax.annotation.concurrent.NotThreadSafe;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;


@NotThreadSafe
@RunWith(Parameterized.class)
public class PythonNumpyMultiThreadTest {
    private DataType dataType;

    public PythonNumpyMultiThreadTest(DataType dataType) {
        this.dataType = dataType;
    }

    @Parameterized.Parameters(name = "{index}: Testing with DataType={0}")
    public static DataType[] params() {
        return new DataType[]{
//                DataType.BOOL,
//                DataType.FLOAT16,
//                DataType.BFLOAT16,
                DataType.FLOAT,
                DataType.DOUBLE,
//                DataType.INT8,
//                DataType.INT16,
                DataType.INT32,
                DataType.INT64,
//                DataType.UINT8,
//                DataType.UINT16,
//                DataType.UINT32,
//                DataType.UINT64
        };
    }


    @Test
    public void testMultiThreading1() throws Throwable {
        final List<Throwable> exceptions = Collections.synchronizedList(new ArrayList<Throwable>());
        Runnable runnable = new Runnable() {
            @Override
            public void run() {
                try (PythonGIL gil = PythonGIL.lock()) {
                    try (PythonGC gc = PythonGC.watch()) {
                        List<PythonVariable> inputs = new ArrayList<>();
                        inputs.add(new PythonVariable<>("x", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3)));
                        inputs.add(new PythonVariable<>("y", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(4)));
                        PythonVariable out = new PythonVariable<>("z", NumpyArray.INSTANCE);
                        String code = "z = x + y";
                        PythonExecutioner.exec(code, inputs, Collections.singletonList(out));
                        Assert.assertEquals(Nd4j.ones(dataType, 2, 3).mul(7), out.getValue());
                    }
                } catch (Throwable e) {
                    exceptions.add(e);
                }
            }
        };

        int numThreads = 10;
        Thread[] threads = new Thread[numThreads];
        for (int i = 0; i < threads.length; i++) {
            threads[i] = new Thread(runnable);
        }
        for (int i = 0; i < threads.length; i++) {
            threads[i].start();
        }
        Thread.sleep(100);
        for (int i = 0; i < threads.length; i++) {
            threads[i].join();
        }
        if (!exceptions.isEmpty()) {
            throw (exceptions.get(0));
        }
    }

    @Test
    public void testMultiThreading2() throws Throwable {
        final List<Throwable> exceptions = Collections.synchronizedList(new ArrayList<Throwable>());
        Runnable runnable = new Runnable() {
            @Override
            public void run() {
                try (PythonGIL gil = PythonGIL.lock()) {
                    try (PythonGC gc = PythonGC.watch()) {
                        PythonContextManager.reset();
                        List<PythonVariable> inputs = new ArrayList<>();
                        inputs.add(new PythonVariable<>("x", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3)));
                        inputs.add(new PythonVariable<>("y", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(4)));
                        String code = "z = x + y";
                        List<PythonVariable> outputs = PythonExecutioner.execAndReturnAllVariables(code, inputs);
                        Assert.assertEquals(Nd4j.ones(dataType, 2, 3).mul(3), outputs.get(0).getValue());
                        Assert.assertEquals(Nd4j.ones(dataType, 2, 3).mul(4), outputs.get(1).getValue());
                        Assert.assertEquals(Nd4j.ones(dataType, 2, 3).mul(7), outputs.get(2).getValue());
                    }
                } catch (Throwable e) {
                    exceptions.add(e);
                }
            }
        };

        int numThreads = 10;
        Thread[] threads = new Thread[numThreads];
        for (int i = 0; i < threads.length; i++) {
            threads[i] = new Thread(runnable);
        }
        for (int i = 0; i < threads.length; i++) {
            threads[i].start();
        }
        Thread.sleep(100);
        for (int i = 0; i < threads.length; i++) {
            threads[i].join();
        }
        if (!exceptions.isEmpty()) {
            throw (exceptions.get(0));
        }
    }

    @Test
    public void testMultiThreading3() throws Throwable {
        PythonContextManager.deleteNonMainContexts();

        String code = "c = a + b";
        final PythonJob job = new PythonJob("job1", code, false);

        final List<Throwable> exceptions = Collections.synchronizedList(new ArrayList<Throwable>());

        class JobThread extends Thread {
            private INDArray a, b, c;

            public JobThread(INDArray a, INDArray b, INDArray c) {
                this.a = a;
                this.b = b;
                this.c = c;
            }

            @Override
            public void run() {
                try {
                    PythonVariable<INDArray> out = new PythonVariable<>("c", NumpyArray.INSTANCE);
                    job.exec(Arrays.<PythonVariable>asList(new PythonVariable<>("a", NumpyArray.INSTANCE, a),
                            new PythonVariable<>("b", NumpyArray.INSTANCE, b)),
                            Collections.<PythonVariable>singletonList(out));
                    Assert.assertEquals(c, out.getValue());
                } catch (Exception e) {
                    exceptions.add(e);
                }
            }
        }
        int numThreads = 10;
        JobThread[] threads = new JobThread[numThreads];
        for (int i = 0; i < threads.length; i++) {
            threads[i] = new JobThread(Nd4j.zeros(dataType, 2, 3).add(i), Nd4j.zeros(dataType, 2, 3).add(i + 3),
                    Nd4j.zeros(dataType, 2, 3).add(2 * i + 3));
        }

        for (int i = 0; i < threads.length; i++) {
            threads[i].start();
        }
        Thread.sleep(100);
        for (int i = 0; i < threads.length; i++) {
            threads[i].join();
        }

        if (!exceptions.isEmpty()) {
            throw (exceptions.get(0));
        }
    }
}
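Both runnables above repeat the same locking discipline. A reusable helper is sketched below (the class is hypothetical, not part of python4j); the point is the nesting order — take the GIL first, then scope garbage collection to the unit of work:

    import org.nd4j.python4j.PythonGC;
    import org.nd4j.python4j.PythonGIL;

    public class WithPython {
        // Runs 'work' with interpreter access serialized by the GIL; Python objects
        // created inside the block are collected when PythonGC.watch() closes.
        public static void run(Runnable work) {
            try (PythonGIL gil = PythonGIL.lock()) {
                try (PythonGC gc = PythonGC.watch()) {
                    work.run();
                }
            }
        }
    }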
@@ -1,5 +1,5 @@
 /*******************************************************************************
- * Copyright (c) 2015-2019 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at

@@ -14,15 +14,22 @@
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

-package org.deeplearning4j.rl4j.learning.sync.qlearning;
-
-import org.deeplearning4j.rl4j.network.dqn.IDQN;
+import org.junit.Assert;
+import org.junit.Test;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.python4j.NumpyArray;
+import org.nd4j.python4j.PythonTypes;

-/**
- * An interface that is an extension of {@link QNetworkSource} for all implementations capable of supplying a target Q-Network
- *
- * @author Alexandre Boulanger
- */
-public interface TargetQNetworkSource extends QNetworkSource {
-    IDQN getTargetQNetwork();
+import javax.annotation.concurrent.NotThreadSafe;
+
+@NotThreadSafe
+public class PythonNumpyServiceLoaderTest {
+
+    @Test
+    public void testServiceLoader(){
+        Assert.assertEquals(NumpyArray.INSTANCE, PythonTypes.<INDArray>get("numpy.ndarray"));
+        Assert.assertEquals(NumpyArray.INSTANCE, PythonTypes.getPythonTypeForJavaObject(Nd4j.zeros(1)));
+    }
+}
@@ -17,7 +17,6 @@ package org.deeplearning4j.rl4j.agent.update;

 import lombok.Getter;
 import org.deeplearning4j.rl4j.learning.sync.Transition;
-import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
 import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN;
 import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.ITDTargetAlgorithm;
 import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.StandardDQN;

@@ -28,13 +27,10 @@ import java.util.List;

 // Temporary class that will be replaced with a more generic class that delegates gradient computation
 // and network update to sub components.
-public class DQNNeuralNetUpdateRule implements IUpdateRule<Transition<Integer>>, TargetQNetworkSource {
+public class DQNNeuralNetUpdateRule implements IUpdateRule<Transition<Integer>> {

     @Getter
     private final IDQN qNetwork;

-    @Getter
-    private IDQN targetQNetwork;
+    private final IDQN targetQNetwork;
     private final int targetUpdateFrequency;

     private final ITDTargetAlgorithm<Integer> tdTargetAlgorithm;

@@ -47,8 +43,8 @@ public class DQNNeuralNetUpdateRule implements IUpdateRule<Transition<Integer>>,
         this.targetQNetwork = qNetwork.clone();
         this.targetUpdateFrequency = targetUpdateFrequency;
         tdTargetAlgorithm = isDoubleDQN
-                ? new DoubleDQN(this, gamma, errorClamp)
-                : new StandardDQN(this, gamma, errorClamp);
+                ? new DoubleDQN(qNetwork, targetQNetwork, gamma, errorClamp)
+                : new StandardDQN(qNetwork, targetQNetwork, gamma, errorClamp);
     }

     @Override

@@ -56,7 +52,7 @@ public class DQNNeuralNetUpdateRule implements IUpdateRule<Transition<Integer>>,
         DataSet targets = tdTargetAlgorithm.computeTDTargets(trainingBatch);
         qNetwork.fit(targets.getFeatures(), targets.getLabels());
         if(++updateCount % targetUpdateFrequency == 0) {
-            targetQNetwork = qNetwork.clone();
+            targetQNetwork.copy(qNetwork);
         }
     }
 }
@@ -16,8 +16,7 @@

 package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;

-import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
-import org.deeplearning4j.rl4j.network.dqn.IDQN;
+import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
 import org.nd4j.linalg.api.ndarray.INDArray;

 /**

@@ -28,7 +27,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
  */
 public abstract class BaseDQNAlgorithm extends BaseTDTargetAlgorithm {

-    private final TargetQNetworkSource qTargetNetworkSource;
+    private final IOutputNeuralNet targetQNetwork;

     /**
      * In literature, this corresponds to Q{net}(s(t+1), a)

@@ -40,23 +39,21 @@ public abstract class BaseDQNAlgorithm extends BaseTDTargetAlgorithm {
      */
     protected INDArray targetQNetworkNextObservation;

-    protected BaseDQNAlgorithm(TargetQNetworkSource qTargetNetworkSource, double gamma) {
-        super(qTargetNetworkSource, gamma);
-        this.qTargetNetworkSource = qTargetNetworkSource;
+    protected BaseDQNAlgorithm(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma) {
+        super(qNetwork, gamma);
+        this.targetQNetwork = targetQNetwork;
     }

-    protected BaseDQNAlgorithm(TargetQNetworkSource qTargetNetworkSource, double gamma, double errorClamp) {
-        super(qTargetNetworkSource, gamma, errorClamp);
-        this.qTargetNetworkSource = qTargetNetworkSource;
+    protected BaseDQNAlgorithm(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma, double errorClamp) {
+        super(qNetwork, gamma, errorClamp);
+        this.targetQNetwork = targetQNetwork;
     }

     @Override
     protected void initComputation(INDArray observations, INDArray nextObservations) {
         super.initComputation(observations, nextObservations);

-        qNetworkNextObservation = qNetworkSource.getQNetwork().output(nextObservations);
-
-        IDQN targetQNetwork = qTargetNetworkSource.getTargetQNetwork();
+        qNetworkNextObservation = qNetwork.output(nextObservations);
         targetQNetworkNextObservation = targetQNetwork.output(nextObservations);
     }
 }
@@ -17,7 +17,7 @@
 package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;

 import org.deeplearning4j.rl4j.learning.sync.Transition;
-import org.deeplearning4j.rl4j.learning.sync.qlearning.QNetworkSource;
+import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.api.DataSet;

@@ -30,7 +30,7 @@ import java.util.List;
  */
 public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Integer> {

-    protected final QNetworkSource qNetworkSource;
+    protected final IOutputNeuralNet qNetwork;
     protected final double gamma;

     private final double errorClamp;

@@ -38,12 +38,12 @@ public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Integer>

     /**
      *
-     * @param qNetworkSource The source of the Q-Network
+     * @param qNetwork The Q-Network
      * @param gamma The discount factor
      * @param errorClamp Will prevent the new Q-Value from being farther than <i>errorClamp</i> away from the previous value. Double.NaN will disable the clamping.
      */
-    protected BaseTDTargetAlgorithm(QNetworkSource qNetworkSource, double gamma, double errorClamp) {
-        this.qNetworkSource = qNetworkSource;
+    protected BaseTDTargetAlgorithm(IOutputNeuralNet qNetwork, double gamma, double errorClamp) {
+        this.qNetwork = qNetwork;
         this.gamma = gamma;

         this.errorClamp = errorClamp;

@@ -52,12 +52,12 @@ public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Integer>

     /**
      *
-     * @param qNetworkSource The source of the Q-Network
+     * @param qNetwork The Q-Network
      * @param gamma The discount factor
      * Note: Error clamping is disabled with this ctor
      */
-    protected BaseTDTargetAlgorithm(QNetworkSource qNetworkSource, double gamma) {
-        this(qNetworkSource, gamma, Double.NaN);
+    protected BaseTDTargetAlgorithm(IOutputNeuralNet qNetwork, double gamma) {
+        this(qNetwork, gamma, Double.NaN);
     }

     /**

@@ -89,8 +89,7 @@ public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Integer>

         initComputation(observations, nextObservations);

-        INDArray updatedQValues = qNetworkSource.getQNetwork().output(observations);
-
+        INDArray updatedQValues = qNetwork.output(observations);
         for (int i = 0; i < size; ++i) {
             Transition<Integer> transition = transitions.get(i);
             double yTarget = computeTarget(i, transition.getReward(), transition.isTerminal());
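For reference, the clamping behaviour documented on errorClamp above can be sketched as follows (illustrative code, not the literal rl4j implementation; previousQ, the accessor transition.getAction(), and the update into updatedQValues are assumed for the sake of the example):

    // Clamp the new target into [previousQ - errorClamp, previousQ + errorClamp];
    // errorClamp == Double.NaN leaves the target untouched.
    double previousQ = updatedQValues.getDouble(i, transition.getAction());
    if (!Double.isNaN(errorClamp)) {
        yTarget = Math.max(previousQ - errorClamp, Math.min(previousQ + errorClamp, yTarget));
    }
    updatedQValues.putScalar(new int[]{i, transition.getAction()}, yTarget);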
@@ -16,7 +16,7 @@

 package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;

-import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
+import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;

@@ -32,12 +32,12 @@ public class DoubleDQN extends BaseDQNAlgorithm {
     // In literature, this corresponds to: max_{a}Q(s_{t+1}, a)
     private INDArray maxActionsFromQNetworkNextObservation;

-    public DoubleDQN(TargetQNetworkSource qTargetNetworkSource, double gamma) {
-        super(qTargetNetworkSource, gamma);
+    public DoubleDQN(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma) {
+        super(qNetwork, targetQNetwork, gamma);
     }

-    public DoubleDQN(TargetQNetworkSource qTargetNetworkSource, double gamma, double errorClamp) {
-        super(qTargetNetworkSource, gamma, errorClamp);
+    public DoubleDQN(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma, double errorClamp) {
+        super(qNetwork, targetQNetwork, gamma, errorClamp);
     }

     @Override
@@ -16,7 +16,7 @@

 package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;

-import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
+import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;

@@ -32,12 +32,12 @@ public class StandardDQN extends BaseDQNAlgorithm {
     // In literature, this corresponds to: max_{a}Q_{tar}(s_{t+1}, a)
     private INDArray maxActionsFromQTargetNextObservation;

-    public StandardDQN(TargetQNetworkSource qTargetNetworkSource, double gamma) {
-        super(qTargetNetworkSource, gamma);
+    public StandardDQN(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma) {
+        super(qNetwork, targetQNetwork, gamma);
     }

-    public StandardDQN(TargetQNetworkSource qTargetNetworkSource, double gamma, double errorClamp) {
-        super(qTargetNetworkSource, gamma, errorClamp);
+    public StandardDQN(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma, double errorClamp) {
+        super(qNetwork, targetQNetwork, gamma, errorClamp);
     }

     @Override
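The two comments above capture the only difference between the algorithms: where the maximizing action comes from. A sketch of both targets for a single non-terminal transition i, built from the cached outputs named in BaseDQNAlgorithm (illustrative, not the literal rl4j code; terminal transitions use y = r alone):

    double r = transition.getReward();

    // Standard DQN: the target network both selects and evaluates the next action:
    //   y = r + gamma * max_a Q_tar(s', a)
    double yStandard = r + gamma * targetQNetworkNextObservation.getRow(i).maxNumber().doubleValue();

    // Double DQN: the online network selects the action, the target network scores it:
    //   y = r + gamma * Q_tar(s', argmax_a Q_net(s', a))
    int aStar = qNetworkNextObservation.getRow(i).argMax().getInt(0);
    double yDouble = r + gamma * targetQNetworkNextObservation.getDouble(i, aStar);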
@@ -1,28 +1,38 @@
-/*******************************************************************************
- * Copyright (c) 2015-2019 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.deeplearning4j.rl4j.learning.sync.qlearning;
-
-import org.deeplearning4j.rl4j.network.dqn.IDQN;
-
-/**
- * An interface for all implementations capable of supplying a Q-Network
- *
- * @author Alexandre Boulanger
- */
-public interface QNetworkSource {
-    IDQN getQNetwork();
-}
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.deeplearning4j.rl4j.network;
+
+import org.deeplearning4j.rl4j.observation.Observation;
+import org.nd4j.linalg.api.ndarray.INDArray;
+
+/**
+ * An interface defining the output aspect of a {@link NeuralNet}.
+ */
+public interface IOutputNeuralNet {
+    /**
+     * Compute the output for the supplied observation.
+     * @param observation An {@link Observation}
+     * @return The output of the network
+     */
+    INDArray output(Observation observation);
+
+    /**
+     * Compute the output for the supplied batch.
+     * @param batch
+     * @return The output of the network
+     */
+    INDArray output(INDArray batch);
+}
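Any object that can score observations satisfies this contract. A minimal sketch of an implementation (the class is hypothetical and for illustration only; it assumes Observation exposes its backing array via getData(), and simply returns a fixed row of Q-values regardless of input):

    import org.deeplearning4j.rl4j.observation.Observation;
    import org.nd4j.linalg.api.ndarray.INDArray;

    public class ConstantOutputNeuralNet implements IOutputNeuralNet {
        private final INDArray qValues;   // one row of Q-values, length = action count

        public ConstantOutputNeuralNet(INDArray qValues) {
            this.qValues = qValues;
        }

        @Override
        public INDArray output(Observation observation) {
            return output(observation.getData());
        }

        @Override
        public INDArray output(INDArray batch) {
            // Broadcast the fixed row across the batch dimension.
            return qValues.broadcast(batch.size(0), qValues.length());
        }
    }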
@@ -17,6 +17,7 @@
 package org.deeplearning4j.rl4j.network.dqn;

 import org.deeplearning4j.nn.gradient.Gradient;
+import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
 import org.deeplearning4j.rl4j.network.NeuralNet;
 import org.deeplearning4j.rl4j.observation.Observation;
 import org.nd4j.linalg.api.ndarray.INDArray;

@@ -27,7 +28,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
  * This neural net quantifies the value of each action given a state
  *
  */
-public interface IDQN<NN extends IDQN> extends NeuralNet<NN> {
+public interface IDQN<NN extends IDQN> extends NeuralNet<NN>, IOutputNeuralNet {

     boolean isRecurrent();

@@ -37,9 +38,6 @@ public interface IDQN<NN extends IDQN> extends NeuralNet<NN> {

     void fit(INDArray input, INDArray[] labels);

-    INDArray output(INDArray batch);
-    INDArray output(Observation observation);
-
     INDArray[] outputAll(INDArray batch);

     NN clone();
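Because IDQN now inherits output(...) from IOutputNeuralNet, the TD-target algorithms can be driven by anything implementing the smaller interface. A sketch, reusing the hypothetical ConstantOutputNeuralNet from the previous example:

    IOutputNeuralNet online = new ConstantOutputNeuralNet(Nd4j.create(new double[]{1.0, 2.0}));
    IOutputNeuralNet target = new ConstantOutputNeuralNet(Nd4j.create(new double[]{0.5, 0.25}));
    ITDTargetAlgorithm<Integer> algo = new DoubleDQN(online, target, 0.99);
    // computeTDTargets(...) can now be exercised without a real DQN — which is
    // exactly what the Mockito-based tests below do with mocked IOutputNeuralNet.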
@@ -1,10 +1,13 @@
 package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;

 import org.deeplearning4j.rl4j.learning.sync.Transition;
-import org.deeplearning4j.rl4j.learning.sync.support.MockDQN;
-import org.deeplearning4j.rl4j.learning.sync.support.MockTargetQNetworkSource;
+import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
 import org.deeplearning4j.rl4j.observation.Observation;
+import org.junit.Before;
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.junit.MockitoJUnitRunner;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.api.DataSet;
 import org.nd4j.linalg.factory.Nd4j;

@@ -13,16 +16,29 @@ import java.util.ArrayList;
 import java.util.List;

 import static org.junit.Assert.assertEquals;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.when;

+@RunWith(MockitoJUnitRunner.class)
 public class DoubleDQNTest {

+    @Mock
+    IOutputNeuralNet qNetworkMock;
+
+    @Mock
+    IOutputNeuralNet targetQNetworkMock;
+
+
+    @Before
+    public void setup() {
+        when(qNetworkMock.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]);
+    }
+
     @Test
     public void when_isTerminal_expect_rewardValueAtIdx0() {

         // Assemble
-        MockDQN qNetwork = new MockDQN();
-        MockDQN targetQNetwork = new MockDQN();
-        MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
+        when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]);

         List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
             {

@@ -31,7 +47,7 @@ public class DoubleDQNTest {
             }
         };

-        DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5);
+        DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5);

         // Act
         DataSet result = sut.computeTDTargets(transitions);

@@ -46,9 +62,7 @@ public class DoubleDQNTest {
     public void when_isNotTerminal_expect_rewardPlusEstimatedQValue() {

         // Assemble
-        MockDQN qNetwork = new MockDQN();
-        MockDQN targetQNetwork = new MockDQN(-1.0);
-        MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
+        when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> ((INDArray)i.getArguments()[0]).mul(-1.0));

         List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
             {

@@ -57,7 +71,7 @@ public class DoubleDQNTest {
             }
         };

-        DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5);
+        DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5);

         // Act
         DataSet result = sut.computeTDTargets(transitions);

@@ -72,9 +86,7 @@ public class DoubleDQNTest {
     public void when_batchHasMoreThanOne_expect_everySampleEvaluated() {

         // Assemble
-        MockDQN qNetwork = new MockDQN();
-        MockDQN targetQNetwork = new MockDQN(-1.0);
-        MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
+        when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> ((INDArray)i.getArguments()[0]).mul(-1.0));

         List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
             {

@@ -87,7 +99,7 @@ public class DoubleDQNTest {
             }
         };

-        DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5);
+        DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5);

         // Act
         DataSet result = sut.computeTDTargets(transitions);
@@ -1,10 +1,13 @@
 package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;

 import org.deeplearning4j.rl4j.learning.sync.Transition;
-import org.deeplearning4j.rl4j.learning.sync.support.MockDQN;
-import org.deeplearning4j.rl4j.learning.sync.support.MockTargetQNetworkSource;
+import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
 import org.deeplearning4j.rl4j.observation.Observation;
+import org.junit.Before;
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.junit.MockitoJUnitRunner;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.api.DataSet;
 import org.nd4j.linalg.factory.Nd4j;

@@ -12,17 +15,31 @@ import org.nd4j.linalg.factory.Nd4j;
 import java.util.ArrayList;
 import java.util.List;

-import static org.junit.Assert.*;
+import static org.junit.Assert.assertEquals;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.when;

+@RunWith(MockitoJUnitRunner.class)
 public class StandardDQNTest {

+    @Mock
+    IOutputNeuralNet qNetworkMock;
+
+    @Mock
+    IOutputNeuralNet targetQNetworkMock;
+
+
+    @Before
+    public void setup() {
+        when(qNetworkMock.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]);
+        when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]);
+    }
+
+
     @Test
     public void when_isTerminal_expect_rewardValueAtIdx0() {

         // Assemble
-        MockDQN qNetwork = new MockDQN();
-        MockDQN targetQNetwork = new MockDQN();
-        MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);

         List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
             {
                 add(buildTransition(buildObservation(new double[]{1.1, 2.2}),

@@ -30,7 +47,7 @@ public class StandardDQNTest {
             }
         };

-        StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5);
+        StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5);

         // Act
         DataSet result = sut.computeTDTargets(transitions);

@@ -45,10 +62,6 @@ public class StandardDQNTest {
     public void when_isNotTerminal_expect_rewardPlusEstimatedQValue() {

         // Assemble
-        MockDQN qNetwork = new MockDQN();
-        MockDQN targetQNetwork = new MockDQN();
-        MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);

         List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
             {
                 add(buildTransition(buildObservation(new double[]{1.1, 2.2}),

@@ -56,7 +69,7 @@ public class StandardDQNTest {
             }
         };

-        StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5);
+        StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5);

         // Act
         DataSet result = sut.computeTDTargets(transitions);

@@ -71,10 +84,6 @@ public class StandardDQNTest {
     public void when_batchHasMoreThanOne_expect_everySampleEvaluated() {

         // Assemble
-        MockDQN qNetwork = new MockDQN();
-        MockDQN targetQNetwork = new MockDQN();
-        MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);

         List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
             {
                 add(buildTransition(buildObservation(new double[]{1.1, 2.2}),

@@ -86,7 +95,7 @@ public class StandardDQNTest {
             }
         };

-        StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5);
+        StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5);

         // Act
         DataSet result = sut.computeTDTargets(transitions);
@@ -1,26 +0,0 @@
-package org.deeplearning4j.rl4j.learning.sync.support;
-
-import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
-import org.deeplearning4j.rl4j.network.dqn.IDQN;
-
-public class MockTargetQNetworkSource implements TargetQNetworkSource {
-
-
-    private final IDQN qNetwork;
-    private final IDQN targetQNetwork;
-
-    public MockTargetQNetworkSource(IDQN qNetwork, IDQN targetQNetwork) {
-        this.qNetwork = qNetwork;
-        this.targetQNetwork = targetQNetwork;
-    }
-
-    @Override
-    public IDQN getTargetQNetwork() {
-        return targetQNetwork;
-    }
-
-    @Override
-    public IDQN getQNetwork() {
-        return qNetwork;
-    }
-}