commit 43fd64358c

@@ -57,10 +57,8 @@ import org.deeplearning4j.nn.weights.WeightInit;
 import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
 import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
 import org.deeplearning4j.util.ModelSerializer;
-import org.junit.AfterClass;
-import org.junit.Before;
-import org.junit.BeforeClass;
-import org.junit.Test;
+import org.junit.*;
+import org.junit.rules.TemporaryFolder;
 import org.nd4j.linalg.activations.Activation;
 import org.nd4j.linalg.activations.impl.ActivationIdentity;
 import org.nd4j.linalg.api.buffer.DataType;
@@ -82,6 +80,7 @@ import org.nd4j.common.resources.Resources;
 
 import java.io.ByteArrayInputStream;
 import java.io.ByteArrayOutputStream;
+import java.io.File;
 import java.io.IOException;
 import java.util.*;
 
@@ -91,6 +90,9 @@ import static org.junit.Assert.*;
 @Slf4j
 public class TestComputationGraphNetwork extends BaseDL4JTest {
 
+    @Rule
+    public TemporaryFolder testDir = new TemporaryFolder();
+
     private static ComputationGraphConfiguration getIrisGraphConfiguration() {
         return new NeuralNetConfiguration.Builder().seed(12345)
                 .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).graphBuilder()
@@ -2177,4 +2179,40 @@ public class TestComputationGraphNetwork extends BaseDL4JTest {
         INDArray label = Nd4j.createFromArray(1, 0).reshape(1, 2);
         cg.fit(new DataSet(in, label));
     }
+
+    @Test
+    public void testMergeNchw() throws Exception {
+        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
+                .convolutionMode(ConvolutionMode.Same)
+                .graphBuilder()
+                .addInputs("in")
+                .layer("l0", new ConvolutionLayer.Builder()
+                        .nOut(16)
+                        .kernelSize(2,2).stride(1,1)
+                        .build(), "in")
+                .layer("l1", new ConvolutionLayer.Builder()
+                        .nOut(8)
+                        .kernelSize(2,2).stride(1,1)
+                        .build(), "in")
+                .addVertex("merge", new MergeVertex(), "l0", "l1")
+                .layer("out", new CnnLossLayer.Builder().activation(Activation.TANH).lossFunction(LossFunctions.LossFunction.MSE).build(), "merge")
+                .setOutputs("out")
+                .setInputTypes(InputType.convolutional(32, 32, 3, CNN2DFormat.NHWC))
+                .build();
+
+        ComputationGraph cg = new ComputationGraph(conf);
+        cg.init();
+
+        INDArray[] in = new INDArray[]{Nd4j.rand(DataType.FLOAT, 1, 32, 32, 3)};
+        INDArray out = cg.outputSingle(in);
+
+        File dir = testDir.newFolder();
+        File f = new File(dir, "net.zip");
+        cg.save(f);
+
+        ComputationGraph c2 = ComputationGraph.load(f, true);
+        INDArray out2 = c2.outputSingle(in);
+
+        assertEquals(out, out2);
+    }
 }
@@ -66,8 +66,8 @@ import org.nd4j.shade.jackson.annotation.JsonProperty;
  * @author Alex Black
  */
 @Data
-@JsonIgnoreProperties({"mask", "helper", "helperCountFail"})
-@EqualsAndHashCode(exclude = {"mask", "helper", "helperCountFail"})
+@JsonIgnoreProperties({"mask", "helper", "helperCountFail", "initializedHelper"})
+@EqualsAndHashCode(exclude = {"mask", "helper", "helperCountFail", "initializedHelper"})
 @Slf4j
 public class Dropout implements IDropout {
 
@@ -17,6 +17,7 @@
 package org.deeplearning4j.nn.conf.graph;
 
 
+import lombok.Data;
 import lombok.val;
 import org.deeplearning4j.nn.conf.CNN2DFormat;
 import org.deeplearning4j.nn.conf.RNNFormat;
@@ -38,6 +39,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;
  * -> [numExamples,depth1 + depth2,width,height]}<br>
  * @author Alex Black
  */
+@Data
 public class MergeVertex extends GraphVertex {
 
     protected int mergeAxis = 1; //default value for backward compatibility (deserialization of old version JSON) - NCHW and NCW format
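
Note on the mergeAxis field above: channels-first activations [n, c, h, w] concatenate on axis 1, while channels-last activations [n, h, w, c] concatenate on axis 3, which is why the NHWC support added in this commit cannot keep a hard-coded axis. A minimal standalone sketch of that axis choice (plain C++ with hypothetical names, not the DL4J API):

    #include <array>
    #include <cassert>
    #include <cstdio>

    enum class Cnn2dFormat { NCHW, NHWC };

    // NCHW activations [n, c, h, w] keep channels on axis 1;
    // NHWC activations [n, h, w, c] keep channels on axis 3.
    static int mergeAxisFor(Cnn2dFormat f) { return f == Cnn2dFormat::NCHW ? 1 : 3; }

    static std::array<long, 4> mergedShape(std::array<long, 4> a, const std::array<long, 4>& b, Cnn2dFormat f) {
        const int axis = mergeAxisFor(f);
        for (int i = 0; i < 4; ++i)
            if (i != axis)
                assert(a[i] == b[i]);   // all non-channel dims must match
        a[axis] += b[axis];             // channel counts add up
        return a;
    }

    int main() {
        // merging 16-channel and 8-channel conv outputs in NHWC
        auto s = mergedShape({1, 32, 32, 16}, {1, 32, 32, 8}, Cnn2dFormat::NHWC);
        std::printf("[%ld, %ld, %ld, %ld]\n", s[0], s[1], s[2], s[3]);   // [1, 32, 32, 24]
        return 0;
    }
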
@@ -141,7 +141,7 @@ public class GradientSharingTrainingTest extends BaseSparkTest {
         SparkComputationGraph sparkNet = new SparkComputationGraph(sc, conf, tm);
         sparkNet.setCollectTrainingStats(tm.getIsCollectTrainingStats());
 
-        System.out.println(Arrays.toString(sparkNet.getNetwork().params().get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
+//        System.out.println(Arrays.toString(sparkNet.getNetwork().params().get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
         File f = testDir.newFolder();
         DataSetIterator iter = new MnistDataSetIterator(16, true, 12345);
         int count = 0;
@@ -208,10 +208,10 @@ public class GradientSharingTrainingTest extends BaseSparkTest {
         }
 
         INDArray paramsAfter = after.params();
-        System.out.println(Arrays.toString(paramsBefore.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
-        System.out.println(Arrays.toString(paramsAfter.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
-        System.out.println(Arrays.toString(
-                Transforms.abs(paramsAfter.sub(paramsBefore)).get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
+//        System.out.println(Arrays.toString(paramsBefore.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
+//        System.out.println(Arrays.toString(paramsAfter.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
+//        System.out.println(Arrays.toString(
+//                Transforms.abs(paramsAfter.sub(paramsBefore)).get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
         assertNotEquals(paramsBefore, paramsAfter);
 
 
@@ -235,7 +235,7 @@ public class GradientSharingTrainingTest extends BaseSparkTest {
     }
 
 
-    @Test
+    @Test @Ignore //AB https://github.com/eclipse/deeplearning4j/issues/8985
     public void differentNetsTrainingTest() throws Exception {
         int batch = 3;
 
@@ -131,6 +131,23 @@ if(NOT SD_CUDA)
    endif()
 endif()
 
+#arm-compute entry
+if(${HELPERS_armcompute})
+    find_package(ARMCOMPUTE REQUIRED)
+
+    if(ARMCOMPUTE_FOUND)
+        message("Found ARMCOMPUTE: ${ARMCOMPUTE_LIBRARIES}")
+        set(HAVE_ARMCOMPUTE 1)
+        # Add preprocessor definition for ARM Compute NEON
+        add_definitions(-DARMCOMPUTENEON_ENABLED)
+        #build our library with neon support
+        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mfpu=neon")
+        include_directories(${ARMCOMPUTE_INCLUDE})
+        message("----${ARMCOMPUTE_INCLUDE}---")
+    endif()
+
+endif()
+
 # new mkl-dnn entry
 if (${HELPERS_mkldnn})
 
@@ -146,6 +146,10 @@ if (HAVE_MKLDNN)
     file(GLOB_RECURSE CUSTOMOPS_MKLDNN_SOURCES false ../include/ops/declarable/platform/mkldnn/*.cpp ../include/ops/declarable/platform/mkldnn/mkldnnUtils.h)
 endif()
 
+if(HAVE_ARMCOMPUTE)
+    file(GLOB_RECURSE CUSTOMOPS_ARMCOMPUTE_SOURCES false ../include/ops/declarable/platform/armcompute/*.cpp ../include/ops/declarable/platform/armcompute/*.h)
+endif()
+
 if(SD_CUDA)
     message("Build cublas")
     find_package(CUDA)
@@ -243,7 +247,7 @@ if(SD_CUDA)
                 ${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES}
                 ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
                 ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES} ${CUSTOMOPS_CUDNN_SOURCES} ${CUSTOMOPS_MKLDNN_SOURCES}
-                ${CUSTOMOPS_GENERIC_SOURCES}
+                ${CUSTOMOPS_ARMCOMPUTE_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES}
                 )
 
         if (WIN32)
@@ -351,8 +355,8 @@ elseif(SD_CPU)
     add_definitions(-D__CPUBLAS__=true)
     add_library(samediff_obj OBJECT ${LEGACY_SOURCES}
             ${LOOPS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
-            ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_MKLDNN_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES}
-            ${OPS_SOURCES} ${PERF_SOURCES})
+            ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_MKLDNN_SOURCES}
+            ${CUSTOMOPS_ARMCOMPUTE_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES})
     if(IOS)
         add_library(${SD_LIBRARY_NAME} STATIC $<TARGET_OBJECTS:samediff_obj>)
     else()
@@ -378,12 +382,12 @@ elseif(SD_CPU)
     if (NOT BLAS_LIBRARIES)
         set(BLAS_LIBRARIES "")
     endif()
-    target_link_libraries(${SD_LIBRARY_NAME} ${MKLDNN} ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${BLAS_LIBRARIES} ${CPU_FEATURES})
+    target_link_libraries(${SD_LIBRARY_NAME} ${MKLDNN} ${MKLDNN_LIBRARIES} ${ARMCOMPUTE_LIBRARIES} ${OPENBLAS_LIBRARIES} ${BLAS_LIBRARIES} ${CPU_FEATURES})
 
     if ("${SD_ALL_OPS}" AND "${SD_BUILD_MINIFIER}")
         message(STATUS "Building minifier...")
         add_executable(minifier ../minifier/minifier.cpp ../minifier/graphopt.cpp)
-        target_link_libraries(minifier samediff_obj ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES} ${CPU_FEATURES})
+        target_link_libraries(minifier samediff_obj ${MKLDNN_LIBRARIES} ${ARMCOMPUTE_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES} ${CPU_FEATURES})
     endif()
 
     if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND "${CMAKE_CXX_COMPILER_VERSION}" VERSION_LESS 4.9)
@@ -0,0 +1,74 @@
+################################################################################
+# Copyright (c) 2020 Konduit K.K.
+#
+# This program and the accompanying materials are made available under the
+# terms of the Apache License, Version 2.0 which is available at
+# https://www.apache.org/licenses/LICENSE-2.0.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+################################################################################
+
+
+### Find ARM COMPUTE LIBRARY STATIC libraries
+
+SET (COMPUTE_INCLUDE_DIRS
+    /usr/include
+    ${ARMCOMPUTE_ROOT}
+    ${ARMCOMPUTE_ROOT}/include
+    ${ARMCOMPUTE_ROOT}/applications
+    ${ARMCOMPUTE_ROOT}/applications/arm_compute
+)
+
+
+SET (COMPUTE_LIB_DIRS
+    /lib
+    /usr/lib
+    ${ARMCOMPUTE_ROOT}
+    ${ARMCOMPUTE_ROOT}/lib
+    ${ARMCOMPUTE_ROOT}/build
+)
+
+find_path(ARMCOMPUTE_INCLUDE arm_compute/core/CL/ICLKernel.h
+          PATHS ${COMPUTE_INCLUDE_DIRS}
+          NO_DEFAULT_PATH NO_CMAKE_FIND_ROOT_PATH)
+
+find_path(ARMCOMPUTE_INCLUDE arm_compute/core/CL/ICLKernel.h)
+
+find_path(HALF_INCLUDE half/half.hpp)
+find_path(HALF_INCLUDE half/half.hpp
+          PATHS ${ARMCOMPUTE_ROOT}/include
+          NO_DEFAULT_PATH NO_CMAKE_FIND_ROOT_PATH)
+include_directories(SYSTEM ${HALF_INCLUDE})
+
+# Find the Arm Compute libraries if not already specified
+if (NOT DEFINED ARMCOMPUTE_LIBRARIES)
+
+    find_library(ARMCOMPUTE_LIBRARY NAMES arm_compute-static
+                 PATHS ${COMPUTE_LIB_DIRS}
+                 PATH_SUFFIXES "Release"
+                 NO_DEFAULT_PATH NO_CMAKE_FIND_ROOT_PATH)
+
+    find_library(ARMCOMPUTE_CORE_LIBRARY NAMES arm_compute_core-static
+                 PATHS ${COMPUTE_LIB_DIRS}
+                 PATH_SUFFIXES "Release"
+                 NO_DEFAULT_PATH NO_CMAKE_FIND_ROOT_PATH)
+    # In case it wasn't there, try a default search (will work in cases where
+    # the library has been installed into a standard location)
+    find_library(ARMCOMPUTE_LIBRARY NAMES arm_compute-static)
+    find_library(ARMCOMPUTE_CORE_LIBRARY NAMES arm_compute_core-static)
+
+    set(ARMCOMPUTE_LIBRARIES ${ARMCOMPUTE_LIBRARY} ${ARMCOMPUTE_CORE_LIBRARY})
+
+endif()
+
+
+INCLUDE(FindPackageHandleStandardArgs)
+
+FIND_PACKAGE_HANDLE_STANDARD_ARGS(ARMCOMPUTE REQUIRED_VARS ARMCOMPUTE_INCLUDE ARMCOMPUTE_LIBRARIES)
@@ -69,7 +69,7 @@ namespace cnpy {
         }
     };
 
-    struct ND4J_EXPORT npz_t : public std::unordered_map<std::string, NpyArray> {
+    struct ND4J_EXPORT npz_t : public std::map<std::string, NpyArray> {
         void destruct() {
             npz_t::iterator it = this->begin();
             for(; it != this->end(); ++it) (*it).second.destruct();
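
One user-visible effect of the base-class switch above (the motivation is an assumption, the commit itself does not state it): std::map iterates in sorted key order, so walking the arrays of an npz archive becomes deterministic across runs, unlike std::unordered_map. A tiny illustration:

    #include <iostream>
    #include <map>
    #include <string>

    int main() {
        // insertion order deliberately scrambled
        std::map<std::string, int> npz{{"b_array", 1}, {"a_array", 0}, {"c_array", 2}};
        for (const auto& kv : npz)
            std::cout << kv.first << "\n";   // always a_array, b_array, c_array
        return 0;
    }
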
@@ -3,6 +3,8 @@
 
 #cmakedefine HAVE_MKLDNN
 
+#cmakedefine HAVE_ARMCOMPUTE
+
 #cmakedefine MKLDNN_PATH "@MKLDNN_PATH@"
 
 #cmakedefine HAVE_OPENBLAS
@@ -45,18 +45,18 @@ namespace sd {
     DECLARE_TYPES(max_pool_with_argmax) {
         getOpDescriptor()
                 ->setAllowedInputTypes(sd::DataType::ANY)
-                ->setAllowedOutputTypes(0, DataType::INHERIT)
-                ->setAllowedOutputTypes(1, {ALL_INTS});
+                ->setAllowedOutputTypes(0, {ALL_FLOATS, ALL_INTS})
+                ->setAllowedOutputTypes(1, {ALL_INDICES});
 
     }
 
     DECLARE_SHAPE_FN(max_pool_with_argmax) {
-        auto in = inputShape->at(0);
-        auto valuesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in));
-        auto indicesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in, DataType::INT64));
-
-        return SHAPELIST(valuesShape, indicesShape);
+        auto in = inputShape->at(0);
+        auto dtype = block.numD() ? D_ARG(0) : sd::DataType::INT64;
+        auto valuesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in));
+        auto indicesShape = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(in, dtype));
+
+        return SHAPELIST(valuesShape, indicesShape);
     }
 }
 }
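
The shape-function change above makes the dtype of the second (argmax) output configurable via D_ARG(0), defaulting to INT64, instead of being pinned to INT64. A rough single-channel illustration of the values/indices pairing, independent of libnd4j (hypothetical helper; 2x2 window, stride 2, even dimensions assumed):

    #include <cstdio>
    #include <vector>

    // Pools an h x w single-channel buffer with a 2x2 window, stride 2,
    // emitting each window's max value and its flat index; the index type I
    // is a template parameter, mirroring the now-configurable output dtype.
    template <typename T, typename I>
    void maxPoolWithArgmax(const std::vector<T>& in, int h, int w,
                           std::vector<T>& values, std::vector<I>& indices) {
        for (int y = 0; y < h; y += 2)
            for (int x = 0; x < w; x += 2) {
                I best = static_cast<I>(y * w + x);
                for (int dy = 0; dy < 2; ++dy)
                    for (int dx = 0; dx < 2; ++dx) {
                        I idx = static_cast<I>((y + dy) * w + (x + dx));
                        if (in[idx] > in[best]) best = idx;
                    }
                values.push_back(in[best]);
                indices.push_back(best);
            }
    }

    int main() {
        std::vector<float> in = {1, 5, 2, 0,
                                 3, 4, 7, 8,
                                 9, 0, 1, 2,
                                 3, 4, 5, 6};
        std::vector<float> values;
        std::vector<long long> indices;   // could just as well be a 32-bit type now
        maxPoolWithArgmax(in, 4, 4, values, indices);
        for (size_t k = 0; k < values.size(); ++k)
            std::printf("%g@%lld ", values[k], indices[k]);   // 5@1 8@7 9@8 6@15
        return 0;
    }
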
@@ -215,7 +215,9 @@ namespace helpers {
         auto maxValue = T(0); //sd::math::nd4j_abs(compoundBuffer[xInitialIndex]);
         auto result = -1;
         //auto loop = PRAGMA_THREADS_FOR {
-        auto start = column, stop = rowNum, increment = 1;
+        auto start = column;
+        auto stop = rowNum;
+        auto increment = 1;
         for (auto rowCounter = start; rowCounter < stop; rowCounter++) {
             Nd4jLong xPos[] = {rowCounter, column};
             auto xIndex = shape::getOffset(compoundShape, xPos, 0);
@@ -73,7 +73,7 @@ namespace helpers {
     }
 
     void maxPoolingFunctor(sd::LaunchContext * context, sd::graph::Context& block, NDArray* input, NDArray* values, std::vector<int> const& params, NDArray* indices) {
-        BUILD_SINGLE_SELECTOR(input->dataType(), maxPoolingFunctor_, (block, input, values, params, indices), FLOAT_TYPES);
+        BUILD_SINGLE_SELECTOR(input->dataType(), maxPoolingFunctor_, (block, input, values, params, indices), LIBND4J_TYPES);
     }
 
 }
@@ -16,7 +16,8 @@
 
 //
 // @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018
-//
+// implementation is based on following article:
+// "MergeShuffle: A Very Fast, Parallel Random Permutation Algorithm", https://arxiv.org/abs/1508.03167
 
 
 
@@ -31,96 +32,167 @@ namespace ops {
 namespace helpers {
 
 //////////////////////////////////////////////////////////////////////////
+// Fisher-Yates shuffle
 template <typename T>
-void randomShuffle_(NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {
+static void fisherYates(sd::graph::RandomGenerator& rng, T* buff, const Nd4jLong& len, const Nd4jLong& ews, Nd4jLong ind) {
 
-    // check edge cases first
-    int temp;
+    for(Nd4jLong i = len-1; i > 0; --i) {
+        const Nd4jLong j = rng.relativeLong(ind++) % (i + 1);
+        if(i != j)
+            math::nd4j_swap<T>(buff[i*ews], buff[j*ews]);
+    }
+}
+
+//////////////////////////////////////////////////////////////////////////
+// mutual shuffle of two adjacent already shuffled ranges with length len1 and (totLen - len1) correspondingly
+template <typename T>
+static void mergeShuffle(sd::graph::RandomGenerator& rng, T* buff, const Nd4jLong& len1, const Nd4jLong& totLen, const Nd4jLong& ews, Nd4jLong ind) {
+
+    Nd4jLong beg = 0;       // beginning
+    Nd4jLong mid = len1;    // middle
+
+    while (true) {
+        if(rng.relativeLong(ind++) % 2) {
+            if(mid == totLen)
+                break;
+            math::nd4j_swap<T>(buff[ews * beg], buff[ews * mid++]);
+        } else {
+            if(beg == mid)
+                break;
+        }
+        ++beg;
+    }
+
+    // fisherYates
+    while (beg < totLen) {
+        const Nd4jLong j = rng.relativeLong(ind++) % (beg + 1);
+        if(beg != j)
+            math::nd4j_swap<T>(buff[ews * beg], buff[ews * j]);
+        ++beg;
+    }
+}
+
+//////////////////////////////////////////////////////////////////////////
+template <typename T>
+static void randomShuffle_(NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {
+
     const int firstDim = input.sizeAt(0);
+    int temp;
 
     if(input.lengthOf() == 1 || firstDim == 1) {
 
         if(!isInplace)
             output.assign(input);
     }
-    else if (input.isVector() || shape::isLikeVector(input.shapeInfo(), temp)) {
+    else if (shape::isCommonVector(input.shapeInfo(), temp)) {
 
-        // apply Fisher-Yates shuffle
-        if(isInplace) {
-            //PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance().tadThreshold())
-            for(int i = firstDim-1; i > 0; --i) {
-                int r = rng.relativeInt(i) % i;
-                if(i == r)
-                    continue;
-                T t0 = input.t<T>(i);
-                T t1 = input.t<T>(r);
-                //math::nd4j_swap<T>(input(i), input(r));
-                input.r<T>(i) = t1;
-                input.r<T>(r) = t0;
-            }
-        }
-        else {
-            std::vector<int> indices(firstDim);
-            std::iota(indices.begin(), indices.end(), 0);
-            output.p<T>(Nd4jLong(0), input.e<T>(0));
-
-            // FIXME: parallelism!!
-            for(int i = firstDim-1; i > 0; --i) {
-                int r = rng.relativeInt(i) % i;
-                output.r<T>(i) = input.t<T>(indices[r]);
-                if(i == r)
-                    continue;
-                output.r<T>(r) = input.t<T>(indices[i]);
-                math::nd4j_swap<int>(indices[i], indices[r]);
-            }
-            rng.rewindH(firstDim-1);
-        }
+        NDArray* arr = &input;
+
+        if (!isInplace) {
+            output.assign(input);
+            arr = &output;
+        }
+
+        const Nd4jLong ews = arr->ews();
+
+        const Nd4jLong len = arr->lengthOf();
+        const Nd4jLong threshold = 1<<22;       // this number was deduced from diagram in article
+
+        int power = 0;
+        while ((len >> power) > threshold)
+            ++power;
+
+        const Nd4jLong numChunks = 1 << power;
+
+        auto funcFisherYates = PRAGMA_THREADS_FOR {
+            for (auto i = start; i < stop; ++i) {
+                Nd4jLong offset = (len * i) >> power;
+                Nd4jLong currLen = ((len * (i + 1)) >> power) - offset;
+                fisherYates<T>(rng, arr->bufferAsT<T>() + offset*ews, currLen, ews, offset);
+            }
+        };
+
+        auto funcMerge = PRAGMA_THREADS_FOR {
+            for (int64_t i = start, k = 1; i < stop; i += increment, ++k) {
+                Nd4jLong offset = len * i >> power;
+                Nd4jLong len1 = (len * (i + increment/2) >> power) - offset;
+                Nd4jLong totLen = (len * (i + increment) >> power) - offset;
+                mergeShuffle<T>(rng, arr->bufferAsT<T>() + offset*ews, len1, totLen, ews, len * k + offset);
+            }
+        };
+
+        samediff::Threads::parallel_for(funcFisherYates, 0, numChunks);
+
+        for (int j = 1; j < numChunks; j += j)
+            samediff::Threads::parallel_for(funcMerge, 0, numChunks, 2*j);
+
+        // #pragma omp parallel for
+        // for (uint i = 0; i < numChunks; ++i) {
+        //     Nd4jLong offset = (len * i) >> power;
+        //     Nd4jLong currLen = ((len * (i + 1)) >> power) - offset;
+        //     fisherYates<T>(rng, arr->bufferAsT<T>() + offset*ews, currLen, ews, offset);
+        // }
+
+        // for (uint j = 1; j < numChunks; j += j) {
+        //     #pragma omp parallel for
+        //     for (auto i = 0; i < numChunks; i += 2*j) {
+        //         Nd4jLong offset = len * i >> power;
+        //         Nd4jLong len1 = (len * (i + j) >> power) - offset;
+        //         Nd4jLong totLen = (len * (i + 2*j) >> power) - offset;
+        //         mergeShuffle(rng, arr->bufferAsT<T>() + offset*ews, len1, totLen, ews, len * j + offset);
+        //     }
+        // }
+
+        rng.rewindH((len + 1) * power);
     }
     else {
 
-        // evaluate sub-arrays list of input array through all dimensions excluding first one
-        std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input.rankOf(), {0});
-        auto subArrsListIn = input.allTensorsAlongDimension(dimensions);
+        auto dimsToExclude = ShapeUtils::evalDimsToExclude(input.rankOf(), {0});
 
-        // apply Fisher-Yates shuffle
         if(isInplace) {
-            //PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance().elementwiseThreshold())
-            for(int i = firstDim - 1; i > 0; --i) {
-                int r = rng.relativeInt(i) % i;
 
-                if(i == r)
-                    continue;
-                subArrsListIn.at(i)->swapUnsafe(*subArrsListIn.at(r));
+            auto subArrsList = input.allTensorsAlongDimension(dimsToExclude);
+
+            // Fisher-Yates shuffle
+            for(int i = firstDim - 1; i > 0; --i) {
+                const int j = rng.relativeInt(i) % (i + 1);
+                if(i != j)
+                    subArrsList.at(i)->swapUnsafe(*subArrsList.at(j));
             }
         }
         else {
-            // evaluate sub-arrays list of output array through all dimensions excluding first one
-            auto subArrsListOut = output.allTensorsAlongDimension(dimensions);
+            auto subArrsListIn = input.allTensorsAlongDimension(dimsToExclude);
+            auto subArrsListOut = output.allTensorsAlongDimension(dimsToExclude);
+
             std::vector<int> indices(firstDim);
-            std::iota(indices.begin(), indices.end(), 0);
-            bool isZeroShuffled = false;
-            //PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance().tadThreshold())
-            for(int i = firstDim - 1; i > 0; --i) {
-                int r = rng.relativeInt(i) % i;
-                subArrsListOut.at(i)->assign(subArrsListIn.at(indices[r]));
-                if(r == 0)
-                    isZeroShuffled = true;
-                if(i == r)
-                    continue;
-                subArrsListOut.at(r)->assign(subArrsListIn.at(indices[i]));
-                math::nd4j_swap<int>(indices[i], indices[r]);
-            }
-            if(!isZeroShuffled)
-                subArrsListOut.at(0)->assign(subArrsListIn.at(0));
+            std::iota(indices.begin(), indices.end(), 0);   // 0,1,2,3, ... firstDim-1
+
+            // shuffle indices
+            fisherYates<int>(rng, indices.data(), firstDim, 1, 0);
+
+            auto func = PRAGMA_THREADS_FOR {
+                for (auto i = start; i < stop; ++i)
+                    subArrsListOut.at(i)->assign(subArrsListIn.at(indices[i]));
+            };
+
+            samediff::Threads::parallel_for(func, 0, firstDim);
         }
 
         rng.rewindH(firstDim-1);
     }
 }
 
 void randomShuffle(sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {
     BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (input, output, rng, isInplace), LIBND4J_TYPES);
 }
 
 }
 }
 }
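
The rewritten helper follows the MergeShuffle scheme from the cited paper: split the buffer into 2^power chunks, Fisher-Yates each chunk independently (the parallelizable phase), then run log2(numChunks) rounds of pairwise merges of adjacent shuffled runs. A self-contained sequential sketch of the same two phases (fixed chunk count assumed for illustration; the commit drives both phases through samediff::Threads and a deterministic RandomGenerator):

    #include <cstdint>
    #include <numeric>
    #include <random>
    #include <utility>
    #include <vector>

    // Plain Fisher-Yates over a contiguous range.
    static void fisherYates(std::mt19937_64& rng, int64_t* a, int64_t len) {
        for (int64_t i = len - 1; i > 0; --i) {
            int64_t j = static_cast<int64_t>(rng() % static_cast<uint64_t>(i + 1));
            std::swap(a[i], a[j]);
        }
    }

    // Merge two adjacent, already-shuffled runs [0, len1) and [len1, totLen):
    // flip a coin to take from the left or right run, then finish the leftover
    // prefix with Fisher-Yates insertion, mirroring the commit's mergeShuffle.
    static void mergeShuffle(std::mt19937_64& rng, int64_t* a, int64_t len1, int64_t totLen) {
        int64_t beg = 0, mid = len1;
        while (true) {
            if (rng() % 2) {
                if (mid == totLen) break;
                std::swap(a[beg], a[mid++]);
            } else if (beg == mid) {
                break;
            }
            ++beg;
        }
        for (; beg < totLen; ++beg) {
            int64_t j = static_cast<int64_t>(rng() % static_cast<uint64_t>(beg + 1));
            std::swap(a[beg], a[j]);
        }
    }

    int main() {
        std::mt19937_64 rng(12345);
        std::vector<int64_t> v(1 << 20);
        std::iota(v.begin(), v.end(), 0);

        const int power = 3;                                  // 2^3 = 8 chunks
        const int64_t len = static_cast<int64_t>(v.size());
        const int64_t numChunks = int64_t(1) << power;

        // phase 1: shuffle every chunk independently (parallelizable)
        for (int64_t i = 0; i < numChunks; ++i) {
            int64_t off = (len * i) >> power;
            fisherYates(rng, v.data() + off, ((len * (i + 1)) >> power) - off);
        }
        // phase 2: log2(numChunks) rounds of pairwise merges of adjacent runs
        for (int64_t j = 1; j < numChunks; j += j)
            for (int64_t i = 0; i < numChunks; i += 2 * j) {
                int64_t off = (len * i) >> power;
                mergeShuffle(rng, v.data() + off,
                             ((len * (i + j)) >> power) - off,
                             ((len * (i + 2 * j)) >> power) - off);
            }
        return 0;
    }

Note how the chunk boundaries offset = (len * i) >> power partition the whole buffer exactly even when len is not divisible by the chunk count, which is what both the CPU and the CUDA versions below rely on.
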
@@ -53,7 +53,7 @@ __global__ static void concatCuda(void* pVx, void* pxShapeInfo, void* vz, const
 
     int coords[MAX_RANK];
 
-    for (uint64_t i = tid; i < zLen; i += totalThreads) {
+    for (Nd4jLong i = tid; i < zLen; i += totalThreads) {
         shape::index2coords(i, zShapeInfo, coords);
 
         const auto zOffset = shape::getOffset(zShapeInfo, coords);
@@ -162,9 +162,9 @@ void concat(sd::LaunchContext * context, const std::vector<const NDArray*>& inAr
     // }
     // else { // general (slower) case
 
-    const int threadsPerBlock = 256;
-    const int blocksPerGrid = 512;
-    const int sharedMem = 512;
+    const int threadsPerBlock = MAX_NUM_THREADS / 2;
+    const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
+    const int sharedMem = 256;
 
     // prepare arrays of pointers on buffers and shapes
     std::vector<const void*> hInBuffers(numOfInArrs);
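
The new launch configuration replaces the fixed 512-block grid with the standard one-thread-per-output-element sizing, rounding the block count up so every element is covered. A generic sketch of that idiom (illustrative thread count, not the libnd4j constants):

    #include <cstdio>

    // Ceiling division: the smallest grid giving at least one thread per element.
    static int blocksFor(long long numElements, int threadsPerBlock) {
        return static_cast<int>((numElements + threadsPerBlock - 1) / threadsPerBlock);
    }

    int main() {
        const int threadsPerBlock = 512;   // illustrative, e.g. MAX_NUM_THREADS / 2
        std::printf("%d\n", blocksFor(1000000, threadsPerBlock));   // 1954 blocks
        return 0;
    }
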
@@ -88,7 +88,7 @@ namespace helpers {
     void maxPoolingFunctor(sd::LaunchContext * context, sd::graph::Context& block, NDArray* input, NDArray* values, std::vector<int> const& params, NDArray* indices) {
         NDArray::prepareSpecialUse({values, indices}, {input});
         auto yType = indices == nullptr ? sd::DataType::INT64 : indices->dataType();
-        BUILD_DOUBLE_SELECTOR(input->dataType(), yType, maxPoolingFunctor_, (block, input, values, params, indices), FLOAT_TYPES, INDEXING_TYPES);
+        BUILD_DOUBLE_SELECTOR(input->dataType(), yType, maxPoolingFunctor_, (block, input, values, params, indices), LIBND4J_TYPES, INDEXING_TYPES);
         NDArray::registerSpecialUse({values, indices}, {input});
     }
 
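
Switching FLOAT_TYPES to LIBND4J_TYPES in the BUILD_DOUBLE_SELECTOR above widens the generated dispatch to every (value type, index type) combination instead of float inputs only. A minimal hand-rolled analogue of what such a double selector expands to (hypothetical types and kernel, not the real macro output):

    #include <cstdint>
    #include <cstdio>

    enum class DType { FLOAT32, INT32, INT64 };

    template <typename X, typename I>
    static void poolKernel() { std::printf("instantiated for one (value, index) pair\n"); }

    // Hand-rolled equivalent of a double selector: an outer switch over the
    // value type and an inner branch over the index type pick one instance.
    static void dispatch(DType x, DType i) {
        switch (x) {
            case DType::FLOAT32:
                if (i == DType::INT32) poolKernel<float, int32_t>();
                else                   poolKernel<float, int64_t>();
                break;
            case DType::INT32:
                if (i == DType::INT32) poolKernel<int32_t, int32_t>();
                else                   poolKernel<int32_t, int64_t>();
                break;
            case DType::INT64:
                if (i == DType::INT32) poolKernel<int64_t, int32_t>();
                else                   poolKernel<int64_t, int64_t>();
                break;
        }
    }

    int main() { dispatch(DType::FLOAT32, DType::INT64); return 0; }
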
@@ -0,0 +1,228 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author Yurii Shyrma (iuriish@yahoo.com)
+// implemented algorithm is GPU adaptation of algorithm described in following article:
+// "MergeShuffle: A Very Fast, Parallel Random Permutation Algorithm", https://arxiv.org/abs/1508.03167
+//
+
+#include<ops/declarable/helpers/transforms.h>
+#include <array/ResultSet.h>
+#include <numeric>
+#include <execution/Threads.h>
+#include <helpers/ShapeUtils.h>
+#include <helpers/PointersManager.h>
+
+namespace sd {
+namespace ops {
+namespace helpers {
+
+//////////////////////////////////////////////////////////////////////////
+template <typename T>
+static __global__ void fisherYatesCuda(sd::graph::RandomGenerator* rng, void* vx, const Nd4jLong ews, const Nd4jLong len, const int power) {
+
+    T* x = reinterpret_cast<T*>(vx);
+
+    __shared__ T* shmem, temp;
+    __shared__ Nd4jLong ind, blockOffset, lenPerBlock;
+
+    if (threadIdx.x == 0) {
+        extern __shared__ unsigned char sharedMemory[];
+        shmem = reinterpret_cast<T*>(sharedMemory);
+
+        blockOffset = (len * blockIdx.x) >> power;
+        lenPerBlock = ((len * (blockIdx.x + 1)) >> power) - blockOffset;
+        ind = blockOffset;
+    }
+    __syncthreads();
+
+    // copy from global memory to shared memory
+    if(threadIdx.x < lenPerBlock)
+        shmem[threadIdx.x] = x[(blockOffset + threadIdx.x) * ews];
+    __syncthreads();
+
+    // *** apply Fisher-Yates shuffle to lenPerBlock number of elements
+    if (threadIdx.x == 0) {
+        for(Nd4jLong i = lenPerBlock - 1; i > 0; --i) {
+            const Nd4jLong j = rng->relativeLong(ind++) % (i + 1);
+            if(i != j) {
+                temp = shmem[i];
+                shmem[i] = shmem[j];
+                shmem[j] = temp;
+            }
+        }
+    }
+    __syncthreads();
+
+    // copy from shared memory to global memory
+    if(threadIdx.x < lenPerBlock)
+        x[(blockOffset + threadIdx.x) * ews] = shmem[threadIdx.x];
+}
+
+template <typename T>
+static __global__ void mergeShuffleCuda(sd::graph::RandomGenerator* rng, void* vx, const Nd4jLong ews, const Nd4jLong len, const int power, const Nd4jLong iterNum) {
+
+    T* x = reinterpret_cast<T*>(vx);
+
+    __shared__ Nd4jLong ind, blockOffset, factor, beg, mid, totLen, iterExp;
+
+    // *** apply mergeShuffle algorithm
+    if(threadIdx.x == 0) {
+
+        factor = blockIdx.x << iterNum;
+        iterExp = 1 << (iterNum - 1);
+        blockOffset = (len * factor) >> power;
+        mid = ((len * (factor + iterExp)) >> power) - blockOffset;          // middle
+        totLen = ((len * (factor + 2*iterExp)) >> power) - blockOffset;
+        ind = iterNum * len + blockOffset;
+        beg = 0;                                                            // beginning
+
+        // printf("m %lld, blockIdx.x %lld, factor %lld, blockOffset %lld, mid %lld, totLen %lld \n", m,k,factor,blockOffset,mid,totLen);
+
+        while (true) {
+            if(rng->relativeLong(ind++) % 2) {
+                if(mid == totLen)
+                    break;
+                math::nd4j_swap<T>(x[(blockOffset + beg) * ews], x[(blockOffset + mid++) * ews]);
+            } else {
+                if(beg == mid)
+                    break;
+            }
+            ++beg;
+        }
+
+        // Fisher-Yates
+        while (beg < totLen) {
+            const Nd4jLong e = rng->relativeLong(ind++) % (beg + 1);
+            if(beg != e)
+                math::nd4j_swap<T>(x[(blockOffset + beg) * ews], x[(blockOffset + e) * ews]);
+            ++beg;
+        }
+    }
+}
+
+//////////////////////////////////////////////////////////////////////////
+// Fisher-Yates shuffle
+template <typename T>
+static void fisherYates(sd::graph::RandomGenerator& rng, T* buff, const Nd4jLong& len, const Nd4jLong& ews, Nd4jLong ind) {
+
+    for(Nd4jLong i = len-1; i > 0; --i) {
+        const Nd4jLong j = rng.relativeLong(ind++) % (i + 1);
+        if(i != j)
+            math::nd4j_swap<T>(buff[i*ews], buff[j*ews]);
+    }
+}
+
+//////////////////////////////////////////////////////////////////////////
+template <typename T>
+static void randomShuffle_(sd::LaunchContext* context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {
+
+    const int firstDim = input.sizeAt(0);
+    int temp;
+
+    if(input.lengthOf() == 1 || firstDim == 1) {
+
+        if(!isInplace)
+            output.assign(input);
+    }
+    else if (shape::isCommonVector(input.shapeInfo(), temp)) {
+
+        NDArray* arr = &input;
+
+        if (!isInplace) {
+            output.assign(input);
+            arr = &output;
+        }
+
+        const Nd4jLong len = arr->lengthOf();
+
+        const int threadsPerBlock = MAX_NUM_THREADS;
+
+        int power = 0;
+        while ((len >> power) > threadsPerBlock)
+            ++power;
+
+        const int blocksPerGrid = 1 << power;
+        const int sharedMem = threadsPerBlock * input.sizeOfT() + 256;
+
+        PointersManager manager(context, "NDArray::randomShuffle cuda");
+
+        sd::graph::RandomGenerator* pRng = reinterpret_cast<sd::graph::RandomGenerator*>(manager.replicatePointer(&rng, sizeof(sd::graph::RandomGenerator)));
+
+        NDArray::prepareSpecialUse({arr}, {arr});
+        fisherYatesCuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *context->getCudaStream()>>>(pRng, arr->specialBuffer(), arr->ews(), len, power);
+        for (Nd4jLong j = 1, i = 1; j < blocksPerGrid; j += j, ++i)
+            mergeShuffleCuda<T><<<blocksPerGrid/(2*j), threadsPerBlock, 256, *context->getCudaStream()>>>(pRng, arr->specialBuffer(), arr->ews(), len, power, i);
+        NDArray::registerSpecialUse({arr}, {arr});
+
+        manager.synchronize();
+
+        rng.rewindH((len + 1) * power);
+    }
+    else {
+
+        auto dimsToExclude = ShapeUtils::evalDimsToExclude(input.rankOf(), {0});
+
+        if(isInplace) {
+
+            auto subArrsList = input.allTensorsAlongDimension(dimsToExclude);
+
+            // Fisher-Yates shuffle
+            for(int i = firstDim - 1; i > 0; --i) {
+                const int j = rng.relativeInt(i) % (i + 1);
+                if(i != j)
+                    subArrsList.at(i)->swapUnsafe(*subArrsList.at(j));
+            }
+        }
+        else {
+
+            auto subArrsListIn = input.allTensorsAlongDimension(dimsToExclude);
+            auto subArrsListOut = output.allTensorsAlongDimension(dimsToExclude);
+
+            std::vector<int> indices(firstDim);
+            std::iota(indices.begin(), indices.end(), 0);   // 0,1,2,3, ... firstDim-1
+
+            // shuffle indices
+            fisherYates<int>(rng, indices.data(), firstDim, 1, 0);
+
+            auto func = PRAGMA_THREADS_FOR {
+                for (auto i = start; i < stop; ++i)
+                    subArrsListOut.at(i)->assign(subArrsListIn.at(indices[i]));
+            };
+
+            samediff::Threads::parallel_for(func, 0, firstDim);
+        }
+
+        rng.rewindH(firstDim-1);
+    }
+}
+
+/////////////////////////////////////////////////////////////////////////
+void randomShuffle(sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {
+    BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (context, input, output, rng, isInplace), LIBND4J_TYPES);
+}
+
+// BUILD_SINGLE_TEMPLATE(template void randomShuffle_, (sd::LaunchContext* context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace), LIBND4J_TYPES);
+
+}
+}
+}
@@ -300,129 +300,6 @@ void tileBP(sd::LaunchContext * context, const NDArray& gradO /*input*/, NDArray
     manager.synchronize();
 }
 
-template <typename T>
-static __global__ void swapShuffleKernel(T* input, Nd4jLong const* shape, Nd4jLong firstDim, sd::graph::RandomGenerator* rng) {
-    auto tid = blockIdx.x * blockDim.x;
-    auto step = blockDim.x * gridDim.x;
-
-    for (int i = firstDim - 1 - tid - threadIdx.x; i > 0; i -= step) {
-        int r = rng->relativeInt(i) % i;
-        if (i != r) {
-            const auto iOffset = shape::getIndexOffset(i, shape);
-            const auto rOffset = shape::getIndexOffset(r, shape);
-            T e0 = input[iOffset];
-            T e1 = input[rOffset];
-            //math::nd4j_swap<T>(input(i), input(r));
-            input[iOffset] = e1;
-            input[rOffset] = e0;
-        }
-    }
-}
-template <typename T>
-static __global__ void fillShuffleKernel(T* input, Nd4jLong const* inputShape, T* output, Nd4jLong const* outputShape, Nd4jLong firstDim, int* indices, sd::graph::RandomGenerator* rng) {
-
-    // PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance().tadThreshold())
-    auto tid = blockIdx.x * blockDim.x;
-    auto step = blockDim.x * gridDim.x;
-
-    for(int i = firstDim - 1 - tid - threadIdx.x; i > 0; i -= step) {
-        int r = rng->relativeInt(i) % i;
-        output[shape::getIndexOffset(i, outputShape)] = input[shape::getIndexOffset(indices[r], inputShape)];
-        if(i != r) {
-            output[shape::getIndexOffset(r, outputShape)] = input[shape::getIndexOffset(indices[i], inputShape)];
-            // output.p(r, input.e<T>(indices[i]));
-            // math::nd4j_swap<int>(indices[i], indices[r]);
-            atomicExch(&indices[i], indices[r]);
-        }
-    }
-
-}
-//////////////////////////////////////////////////////////////////////////
-template <typename T>
-void randomShuffle_(sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {
-
-    // check edge cases first
-    int temp;
-    const int firstDim = input.sizeAt(0);
-    auto stream = context->getCudaStream();
-    NDArray::prepareSpecialUse({&output}, {&input});
-    if(input.lengthOf() == 1 || firstDim == 1) {
-        if(!isInplace)
-            output.assign(input);
-    }
-    else if (input.isVector() || shape::isLikeVector(input.shapeInfo(), temp)) {
-
-        // apply Fisher-Yates shuffle
-        sd::graph::RandomGenerator* dRandom = nullptr;
-        cudaMalloc(&dRandom, sizeof(sd::graph::RandomGenerator));
-        cudaMemcpy(dRandom, &rng, sizeof(sd::graph::RandomGenerator), cudaMemcpyHostToDevice);
-        T* inputBuf = reinterpret_cast<T*>(input.specialBuffer());
-        if(isInplace) {
-            swapShuffleKernel<T><<<128, 256, 1024, *stream>>>(inputBuf, input.specialShapeInfo(), firstDim, dRandom);
-        }
-        else {
-            std::vector<int> indices(firstDim);
-            std::iota(indices.begin(), indices.end(), 0);
-            cudaMemcpy(output.specialBuffer(), input.specialBuffer(), sizeof(T), cudaMemcpyDeviceToDevice);
-            //output.p<T>(Nd4jLong(0), input.e<T>(0));
-            PointersManager pointersManager(context, "helper::randomShuffle_");
-            int* indicesDev = reinterpret_cast<int*>(pointersManager.replicatePointer(indices.data(), indices.size() * sizeof(int)));
-            T* outputBuf = reinterpret_cast<T*>(output.specialBuffer());
-            fillShuffleKernel<T><<<128, 256, 1024, *stream>>>(inputBuf, input.specialShapeInfo(), outputBuf, output.specialShapeInfo(), firstDim, indicesDev, dRandom);
-            pointersManager.synchronize();
-        }
-        // rng.rewindH(firstDim - 1);
-        cudaFree(dRandom);
-    }
-    else {
-
-        // evaluate sub-arrays list of input array through all dimensions excluding first one
-        std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input.rankOf(), {0});
-        auto subArrsListIn = input.allTensorsAlongDimension(dimensions);
-
-        // apply Fisher-Yates shuffle
-        if(isInplace) {
-            for(int i = firstDim - 1; i > 0; --i) {
-                int r = rng.relativeInt(i) % i;
-
-                if(i != r)
-                    subArrsListIn.at(i)->swapUnsafe(*subArrsListIn.at(r));
-            }
-        }
-        else {
-            // evaluate sub-arrays list of output array through all dimensions excluding first one
-            auto subArrsListOut = output.allTensorsAlongDimension(dimensions);
-            std::vector<int> indices(firstDim);
-            std::iota(indices.begin(), indices.end(), 0);
-            bool isZeroShuffled = false;
-
-            for(int i = firstDim - 1; i > 0; --i) {
-                int r = rng.relativeInt(i) % i;
-                subArrsListOut.at(i)->assign(subArrsListIn.at(indices[r]));
-                if(r == 0)
-                    isZeroShuffled = true;
-
-                if(i != r) {
-                    subArrsListOut.at(r)->assign(subArrsListIn.at(indices[i]));
-                    math::nd4j_swap<int>(indices[i], indices[r]);
-                }
-            }
-            if(!isZeroShuffled)
-                subArrsListOut.at(0)->assign(subArrsListIn.at(0));
-        }
-        rng.rewindH(firstDim-1);
-    }
-    NDArray::registerSpecialUse({&output}, {&input});
-
-}
-
-void randomShuffle(sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {
-    BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (context, input, output, rng, isInplace), LIBND4J_TYPES);
-}
-
-BUILD_SINGLE_TEMPLATE(template void randomShuffle_, (sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace), LIBND4J_TYPES);
-
-
 //////////////////////////////////////////////////////////////////////////
 void eye(sd::LaunchContext * context, NDArray& output) {
 
@ -0,0 +1,278 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
// Created by Abdelrauf 2020
|
||||||
|
|
||||||
|
|
||||||
|
#include <ops/declarable/PlatformHelper.h>
|
||||||
|
#include <ops/declarable/OpRegistrator.h>
|
||||||
|
#include <system/platform_boilerplate.h>
|
||||||
|
#include <ops/declarable/helpers/convolutions.h>
|
||||||
|
#include <cstdint>
|
||||||
|
#include <helpers/LoopsCoordsHelper.h>
|
||||||
|
|
||||||
|
#include "armcomputeUtils.h"
|
||||||
|
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
namespace platforms {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Arm_DataType getArmType ( const DataType &dType){
|
||||||
|
Arm_DataType ret;
|
||||||
|
switch (dType){
|
||||||
|
case HALF :
|
||||||
|
ret = Arm_DataType::F16;
|
||||||
|
break;
|
||||||
|
case FLOAT32 :
|
||||||
|
ret = Arm_DataType::F32;
|
||||||
|
break;
|
||||||
|
case DOUBLE :
|
||||||
|
ret = Arm_DataType::F64;
|
||||||
|
break;
|
||||||
|
case INT8 :
|
||||||
|
ret = Arm_DataType::S8;
|
||||||
|
break;
|
||||||
|
case INT16 :
|
||||||
|
ret = Arm_DataType::S16;
|
||||||
|
break;
|
||||||
|
case INT32 :
|
||||||
|
ret = Arm_DataType::S32;
|
||||||
|
break;
|
||||||
|
case INT64 :
|
||||||
|
ret = Arm_DataType::S64;
|
||||||
|
break;
|
||||||
|
case UINT8 :
|
||||||
|
ret = Arm_DataType::U8;
|
||||||
|
break;
|
||||||
|
case UINT16 :
|
||||||
|
ret = Arm_DataType::U16;
|
||||||
|
break;
|
||||||
|
case UINT32 :
|
||||||
|
ret = Arm_DataType::U32;
|
||||||
|
break;
|
||||||
|
case UINT64 :
|
||||||
|
ret = Arm_DataType::U64;
|
||||||
|
break;
|
||||||
|
case BFLOAT16 :
|
||||||
|
ret = Arm_DataType::BFLOAT16;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
ret = Arm_DataType::UNKNOWN;
|
||||||
|
};
|
||||||
|
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
bool isArmcomputeFriendly(const NDArray& arr) {
|
||||||
|
auto dType = getArmType(arr.dataType());
|
||||||
|
int rank = (int)(arr.rankOf());
|
||||||
|
return dType != Arm_DataType::UNKNOWN &&
|
||||||
|
rank<=arm_compute::MAX_DIMS &&
|
||||||
|
arr.ordering() == 'c' &&
|
||||||
|
arr.ews()==1 &&
|
||||||
|
shape::strideDescendingCAscendingF(arr.shapeInfo()) == true;
|
||||||
|
}
|
||||||
|
|
||||||
|
Arm_TensorInfo getArmTensorInfo(int rank, Nd4jLong* bases,sd::DataType ndArrayType, arm_compute::DataLayout layout) {
|
||||||
|
constexpr int numChannels = 1;
|
||||||
|
auto dType = getArmType(ndArrayType);
|
||||||
|
|
||||||
|
Arm_TensorShape shape;
|
||||||
|
shape.set_num_dimensions(rank);
|
||||||
|
for (int i = 0, j = rank - 1; i < rank; i++, j--) {
|
||||||
|
shape[i] = static_cast<uint32_t>(bases[j]);
|
||||||
|
}
|
||||||
|
// fill the rest unused with 1
|
||||||
|
for (int i = rank; i < arm_compute::MAX_DIMS; i++) {
|
||||||
|
shape[i] = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
return Arm_TensorInfo(shape, numChannels, dType, layout);
|
||||||
|
}
|
||||||
|
|
||||||
|
Arm_TensorInfo getArmTensorInfo(const NDArray& arr,
|
||||||
|
arm_compute::DataLayout layout) {
|
||||||
|
auto dType = getArmType(arr.dataType());
|
||||||
|
|
||||||
|
//
|
||||||
|
constexpr int numChannels = 1;
|
||||||
|
int rank = (int)(arr.rankOf());
|
||||||
|
auto bases = arr.shapeOf();
|
||||||
|
auto arrStrides = arr.stridesOf();
|
||||||
|
|
||||||
|
// https://arm-software.github.io/ComputeLibrary/v20.05/_dimensions_8h_source.xhtml
|
||||||
|
// note: underhood it is stored as std::array<T, num_max_dimensions> _id;
|
||||||
|
// TensorShape is derived from Dimensions<uint32_t>
|
||||||
|
// as well as Strides : public Dimensions<uint32_t>
|
||||||
|
Arm_TensorShape shape;
|
||||||
|
Arm_Strides strides;
|
||||||
|
shape.set_num_dimensions(rank);
|
||||||
|
strides.set_num_dimensions(rank);
|
||||||
|
size_t element_size = arm_compute::data_size_from_type(dType);
|
||||||
|
for (int i = 0, j = rank - 1; i < rank; i++, j--) {
|
||||||
|
shape[i] = static_cast<uint32_t>(bases[j]);
|
||||||
|
strides[i] = static_cast<uint32_t>(arrStrides[j]) * element_size;
|
||||||
|
}
|
||||||
|
// fill the rest unused with 1
|
||||||
|
for (int i = rank; i < arm_compute::MAX_DIMS; i++) {
|
||||||
|
shape[i] = 1;
|
||||||
|
}
|
||||||
|
size_t total_size;
|
||||||
|
size_t size_ind = rank - 1;
|
||||||
|
total_size = shape[size_ind] * strides[size_ind];
|
||||||
|
|
||||||
|
Arm_TensorInfo info;
|
||||||
|
info.init(shape, numChannels, dType, strides, 0, total_size);
|
||||||
|
info.set_data_layout(layout);
|
||||||
|
|
||||||
|
return info;
|
||||||
|
}
|
||||||
|
|
||||||
|
Arm_Tensor getArmTensor(const NDArray& arr, arm_compute::DataLayout layout) {
|
||||||
|
// - Ownership of the backing memory is not transferred to the tensor itself.
|
||||||
|
// - The tensor mustn't be memory managed.
|
||||||
|
// - Padding requirements should be accounted by the client code.
|
||||||
|
// In other words, if padding is required by the tensor after the function
|
||||||
|
// configuration step, then the imported backing memory should account for it.
|
||||||
|
// Padding can be checked through the TensorInfo::padding() interface.
|
||||||
|
|
||||||
|
// Import existing pointer as backing memory
|
||||||
|
auto info = getArmTensorInfo(arr, layout);
|
||||||
|
Arm_Tensor tensor;
|
||||||
|
tensor.allocator()->init(info);
|
||||||
|
void* buff = (void*)arr.buffer();
|
||||||
|
tensor.allocator()->import_memory(buff);
|
||||||
|
return tensor;
|
||||||
|
}
|
||||||
|
|
||||||
|
void copyFromTensor(const Arm_Tensor& inTensor, NDArray& output) {
|
||||||
|
//only for C order
|
||||||
|
//only for C order
|
||||||
|
if (output.ordering() != 'c') return;
|
||||||
|
auto shapeInfo = output.shapeInfo();
|
||||||
|
auto bases = &(shapeInfo[1]);
|
||||||
|
Nd4jLong rank = shapeInfo[0];
|
||||||
|
auto strides = output.stridesOf();
|
||||||
|
int width = bases[rank - 1];
|
||||||
|
uint8_t* outputBuffer = (uint8_t*)output.buffer();
|
||||||
|
size_t offset = 0;
|
||||||
|
arm_compute::Window window;
|
||||||
|
arm_compute::Iterator tensor_it(&inTensor, window);
|
||||||
|
|
||||||
|
int element_size = inTensor.info()->element_size();
|
||||||
|
window.use_tensor_dimensions(inTensor.info()->tensor_shape(), /* first_dimension =*/arm_compute::Window::DimY);
|
||||||
|
|
||||||
|
// if (output.ews() == 1) {
|
||||||
|
auto copySize = width * element_size;
|
||||||
|
auto dest = outputBuffer;
|
||||||
|
arm_compute::execute_window_loop(window, [&](const arm_compute::Coordinates& id)
|
||||||
|
{
|
||||||
|
auto src = tensor_it.ptr();
|
||||||
|
memcpy(dest, src, copySize);
|
||||||
|
dest += copySize;
|
||||||
|
},
|
||||||
|
tensor_it);
|
||||||
|
// }
|
||||||
|
// else {
|
||||||
|
// Nd4jLong coords[MAX_RANK] = {};
|
||||||
|
// if(strides[rank-1]!=1){
|
||||||
|
// throw std::runtime_error( "not implemented for subarrays whose last stride is not 1");
|
||||||
|
// //TODO: implement to work with all subarrays properly
|
||||||
|
// }
|
||||||
|
// arm_compute::execute_window_loop(window, [&](const arm_compute::Coordinates& id)
|
||||||
|
// {
|
||||||
|
// auto src = tensor_it.ptr();
|
||||||
|
// auto dest = outputBuffer + offset * element_size;
|
||||||
|
// memcpy(dest, src, width * element_size);
|
||||||
|
// offset = sd::inc_coords(bases, strides, coords, offset, rank, 1);
|
||||||
|
// },
|
||||||
|
// tensor_it);
|
||||||
|
// }
|
||||||
|
}

void copyToTensor(const NDArray& input, Arm_Tensor& outTensor) {
    // only for C order
    if (input.ordering() != 'c') return;
    auto shapeInfo = input.shapeInfo();
    auto bases = &(shapeInfo[1]);
    Nd4jLong rank = shapeInfo[0];
    auto strides = input.stridesOf();
    uint8_t* inputBuffer = (uint8_t*)input.buffer();
    int width = bases[rank - 1];
    size_t offset = 0;
    arm_compute::Window window;
    arm_compute::Iterator tensor_it(&outTensor, window);
    int element_size = outTensor.info()->element_size();

    window.use_tensor_dimensions(outTensor.info()->tensor_shape(), /* first_dimension =*/arm_compute::Window::DimY);

    // if (input.ews() == 1) {
    auto copySize = width * element_size;
    auto src = inputBuffer;
    arm_compute::execute_window_loop(window, [&](const arm_compute::Coordinates& id)
        {
            auto dest = tensor_it.ptr();
            memcpy(dest, src, copySize);
            src += copySize;
        },
        tensor_it);
    // }
    // else {
    //     Nd4jLong coords[MAX_RANK] = {};
    //     if(strides[rank-1]!=1){
    //         throw std::runtime_error("not implemented for subarrays whose last stride is not 1");
    //         //TODO: implement to work with all subarrays properly
    //     }
    //     arm_compute::execute_window_loop(window, [&](const arm_compute::Coordinates& id)
    //     {
    //         auto dest = tensor_it.ptr();
    //         auto src = inputBuffer + offset * element_size;
    //         memcpy(dest, src, width * element_size);
    //         offset = sd::inc_coords(bases, strides, coords, offset, rank, 1);
    //     },
    //     tensor_it);
    // }
}


// armcompute should be built with the debug option
void print_tensor(Arm_ITensor& tensor, const char* msg) {
    auto info = tensor.info();
    auto padding = info->padding();
    std::cout << msg << "\ntotal: " << info->total_size() << "\n";

    for (int i = 0; i < arm_compute::MAX_DIMS; i++) {
        std::cout << info->dimension(i) << ",";
    }
    std::cout << std::endl;
    for (int i = 0; i < arm_compute::MAX_DIMS; i++) {
        std::cout << info->strides_in_bytes()[i] << ",";
    }
    std::cout << "\npadding: l " << padding.left << ", r " << padding.right
              << ", t " << padding.top << ", b " << padding.bottom << std::endl;

#ifdef ARM_COMPUTE_ASSERTS_ENABLED
    // note: it did not print correctly for NHWC
    std::cout << msg << ":\n";
    tensor.print(std::cout);
    std::cout << std::endl;
#endif
}

}
}
}

@@ -0,0 +1,133 @@
/*******************************************************************************
 * Copyright (c) 2019 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/


#ifndef DEV_TESTSARMCOMPUTEUTILS_H
#define DEV_TESTSARMCOMPUTEUTILS_H


#include <legacy/NativeOps.h>
#include <array/NDArray.h>
#include <graph/Context.h>
#include <ops/declarable/PlatformHelper.h>
#include <system/platform_boilerplate.h>
#include <arm_compute/runtime/NEON/NEFunctions.h>
#include <arm_compute/core/Types.h>
#include <arm_compute/core/TensorInfo.h>
#include <arm_compute/core/TensorShape.h>
#include <arm_compute/core/Strides.h>
#include <arm_compute/core/Helpers.h>
#include <arm_compute/core/ITensor.h>
#include <arm_compute/core/Validate.h>
#include <arm_compute/core/Window.h>
#include <arm_compute/runtime/Tensor.h>
#include <arm_compute/runtime/TensorAllocator.h>
#include <iostream>

using namespace samediff;


namespace sd {
namespace ops {
namespace platforms {

using Arm_DataType = arm_compute::DataType;
using Arm_Tensor = arm_compute::Tensor;
using Arm_ITensor = arm_compute::ITensor;
using Arm_TensorInfo = arm_compute::TensorInfo;
using Arm_TensorShape = arm_compute::TensorShape;
using Arm_Strides = arm_compute::Strides;

/**
 * Here we actually declare our platform helpers
 */
DECLARE_PLATFORM(maxpool2d, ENGINE_CPU);

DECLARE_PLATFORM(avgpool2d, ENGINE_CPU);

// utils
Arm_DataType getArmType(const sd::DataType& dType);

Arm_TensorInfo getArmTensorInfo(int rank, Nd4jLong* bases, sd::DataType ndArrayType, arm_compute::DataLayout layout = arm_compute::DataLayout::UNKNOWN);

Arm_TensorInfo getArmTensorInfo(const NDArray& arr, arm_compute::DataLayout layout = arm_compute::DataLayout::UNKNOWN);

Arm_Tensor getArmTensor(const NDArray& arr, arm_compute::DataLayout layout = arm_compute::DataLayout::UNKNOWN);

void copyFromTensor(const Arm_Tensor& inTensor, NDArray& output);
void copyToTensor(const NDArray& input, Arm_Tensor& outTensor);
void print_tensor(Arm_ITensor& tensor, const char* msg);
bool isArmcomputeFriendly(const NDArray& arr);


template<typename F>
class ArmFunction {
public:

    template<typename ...Args>
    void configure(NDArray* input, NDArray* output, arm_compute::DataLayout layout, Args&&... args) {

        auto inInfo = getArmTensorInfo(*input, layout);
        auto outInfo = getArmTensorInfo(*output, layout);
        in.allocator()->init(inInfo);
        out.allocator()->init(outInfo);
        armFunction.configure(&in, &out, std::forward<Args>(args)...);
        if (in.info()->has_padding()) {
            // allocate and copy
            in.allocator()->allocate();
            copyToTensor(*input, in);
        }
        else {
            // import the buffer
            void* buff = input->buffer();
            in.allocator()->import_memory(buff);
        }
        if (out.info()->has_padding()) {
            // store a pointer to our array so we can copy back after run
            out.allocator()->allocate();
            outNd = output;
        }
        else {
            // import
            void* buff = output->buffer();
            out.allocator()->import_memory(buff);
        }
    }

    void run() {
        armFunction.run();
        if (outNd) {
            copyFromTensor(out, *outNd);
        }
    }

private:
    Arm_Tensor in;
    Arm_Tensor out;
    NDArray* outNd = nullptr;
    F armFunction{};
};
}
}
}


#endif //DEV_TESTSARMCOMPUTEUTILS_H
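
// A usage sketch for ArmFunction, mirroring the pooling implementations below
// (poolInfo and dataLayout stand for whatever the concrete op computes):
//
//     ArmFunction<arm_compute::NEPoolingLayer> pool;
//     pool.configure(input, output, dataLayout, poolInfo); // copies or imports buffers as needed
//     pool.run();                                          // copies back only if padding forced a staging tensor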

@@ -0,0 +1,106 @@
/*******************************************************************************
 * Copyright (c) 2019 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

// Created by Abdelrauf (rauf@konduit.ai) 2020

#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <system/platform_boilerplate.h>
#include <ops/declarable/helpers/convolutions.h>

#include "armcomputeUtils.h"


namespace sd {
namespace ops {
namespace platforms {


//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(avgpool2d, ENGINE_CPU) {

    auto input = INPUT_VARIABLE(0);
    auto output = OUTPUT_VARIABLE(0);

    // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
    const auto kH = INT_ARG(0);
    const auto kW = INT_ARG(1);
    const auto sH = INT_ARG(2);
    const auto sW = INT_ARG(3);
    auto pH = INT_ARG(4);
    auto pW = INT_ARG(5);
    const auto dH = INT_ARG(6);
    const auto dW = INT_ARG(7);
    const auto paddingMode = INT_ARG(8);
    const auto extraParam0 = INT_ARG(9);
    const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1;       // INT_ARG(10): 0-NCHW, 1-NHWC

    REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D ARMCOMPUTE op: input should have rank of 4, but got %i instead", input->rankOf());
    REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D ARMCOMPUTE op: dilation must not be zero, but got instead {%i, %i}", dH, dW);

    bool exclude_padding = (extraParam0 == 0);

    auto dataLayout = isNCHW ? arm_compute::DataLayout::NCHW : arm_compute::DataLayout::NHWC;

    // Calculate individual paddings
    unsigned int pad_left, pad_top, pad_right, pad_bottom;
    int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
    int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

    if (paddingMode) {
        ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
    }
    pad_left   = pW;
    pad_top    = pH;
    pad_right  = (oW - 1) * sW - iW + kW - pW;
    pad_bottom = (oH - 1) * sH - iH + kH - pH;
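    // For reference, these follow from the pooling output-size relation
    //     oW = (iW + pad_left + pad_right - kW) / sW + 1
    // solved for pad_right:
    //     pad_right = (oW - 1) * sW + kW - iW - pad_left
    // and likewise for pad_bottom along the height axis.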
#if 0
    nd4j_printf("avgpool kH = %d, kW = %d, sH = %d, sW = %d , pH = %d , pW = %d, dH = %d, dW = %d, paddingMode = %d , isNCHW %d exclude pad %d \n", kH, kW, sH, sW, pH,
                pW, dH, dW, paddingMode, isNCHW ? 1 : 0, exclude_padding ? 1 : 0);
#endif
    auto poolPad = arm_compute::PadStrideInfo(sW, sH, pad_left, pad_right, pad_top, pad_bottom, arm_compute::DimensionRoundingType::FLOOR);
    auto poolInfo = arm_compute::PoolingLayerInfo(arm_compute::PoolingType::AVG, arm_compute::Size2D(kW, kH), dataLayout, poolPad, exclude_padding);
    ArmFunction<arm_compute::NEPoolingLayer> pool;
    pool.configure(input, output, dataLayout, poolInfo);

    pool.run(); // run function

    return Status::OK();
}

//////////////////////////////////////////////////////////////////////////
PLATFORM_CHECK(avgpool2d, ENGINE_CPU) {
    auto input = INPUT_VARIABLE(0);
    auto output = OUTPUT_VARIABLE(0);
    const int dH = INT_ARG(6);
    const int dW = INT_ARG(7);
    // Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32
    auto dTypeInput = getArmType(input->dataType());
    auto dTypeOutput = getArmType(output->dataType());
    bool is_supported = dH == 1 && dW == 1 && isArmcomputeFriendly(*input) && isArmcomputeFriendly(*output)
                        && (dTypeInput == Arm_DataType::F32)
                        && (dTypeOutput == Arm_DataType::F32);
    return is_supported;
}


}
}
}

@@ -0,0 +1,106 @@
/*******************************************************************************
 * Copyright (c) 2019 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

// Created by Abdelrauf 2020

#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <system/platform_boilerplate.h>
#include <ops/declarable/helpers/convolutions.h>

#include "armcomputeUtils.h"


namespace sd {
namespace ops {
namespace platforms {


//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(maxpool2d, ENGINE_CPU) {

    auto input = INPUT_VARIABLE(0);
    auto output = OUTPUT_VARIABLE(0);

    REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D ARMCOMPUTE OP: input array should have rank of 4, but got %i instead", input->rankOf());

    // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
    const int kH = INT_ARG(0);
    const int kW = INT_ARG(1);
    const int sH = INT_ARG(2);
    const int sW = INT_ARG(3);
    int pH = INT_ARG(4);
    int pW = INT_ARG(5);
    const int dH = INT_ARG(6);
    const int dW = INT_ARG(7);
    const int paddingMode = INT_ARG(8);
    // const int extraParam0 = INT_ARG(9);
    const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1;       // INT_ARG(10): 1-NHWC, 0-NCHW

    REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D ARMCOMPUTE op: dilation must not be zero, but got instead {%i, %i}", dH, dW);

    auto dataLayout = isNCHW ? arm_compute::DataLayout::NCHW : arm_compute::DataLayout::NHWC;

    // Calculate individual paddings
    unsigned int pad_left, pad_top, pad_right, pad_bottom;
    int bS, iC, iH, iW, oC, oH, oW;                             // batch size, input channels, input height/width, output channels, output height/width;
    int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;       // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

    if (paddingMode) {
        ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
    }
    pad_left   = pW;
    pad_top    = pH;
    pad_right  = (oW - 1) * sW - iW + kW - pW;
    pad_bottom = (oH - 1) * sH - iH + kH - pH;
#if 0
    nd4j_printf("maxpool kH = %d, kW = %d, sH = %d, sW = %d , pH = %d , pW = %d, dH = %d, dW = %d, paddingMode = %d , isNCHW %d \n", kH, kW, sH, sW, pH,
                pW, dH, dW, paddingMode, isNCHW ? 1 : 0);
#endif

    auto poolPad = arm_compute::PadStrideInfo(sW, sH, pad_left, pad_right, pad_top, pad_bottom, arm_compute::DimensionRoundingType::FLOOR);
    auto poolInfo = arm_compute::PoolingLayerInfo(arm_compute::PoolingType::MAX, arm_compute::Size2D(kW, kH), dataLayout, poolPad);
    ArmFunction<arm_compute::NEPoolingLayer> pool;

    pool.configure(input, output, dataLayout, poolInfo);

    pool.run(); // run function

    return Status::OK();
}

//////////////////////////////////////////////////////////////////////////
PLATFORM_CHECK(maxpool2d, ENGINE_CPU) {
    auto input = INPUT_VARIABLE(0);
    auto output = OUTPUT_VARIABLE(0);
    const int dH = INT_ARG(6);
    const int dW = INT_ARG(7);
    // Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32
    auto dTypeInput = getArmType(input->dataType());
    auto dTypeOutput = getArmType(output->dataType());
    bool is_supported = dH == 1 && dW == 1 && isArmcomputeFriendly(*input) && isArmcomputeFriendly(*output)
                        && (dTypeInput == Arm_DataType::F32)
                        && (dTypeOutput == Arm_DataType::F32);
    return is_supported;
}


}
}
}

@@ -3963,9 +3963,6 @@ namespace simdOps {
         }
 #endif

-#ifndef __clang__
-#pragma omp declare simd uniform(extraParamsRef)
-#endif
 op_def static Y merge(X old, X opOutput, X *extraParamsRef) {
     return update(old, opOutput, extraParamsRef);
 }

@@ -0,0 +1,185 @@
#!/bin/bash
TARGET=armv7-a
BLAS_TARGET_NAME=ARMV7
ARMCOMPUTE_TARGET=armv7a
#BASE_DIR=${HOME}/pi
#https://stackoverflow.com/questions/59895/how-to-get-the-source-directory-of-a-bash-script-from-within-the-script-itself
SOURCE="${BASH_SOURCE[0]}"
ARMCOMPUTE_DEBUG=1
LIBND4J_BUILD_MODE=Release
while [ -h "$SOURCE" ]; do # resolve $SOURCE until the file is no longer a symlink
  DIR="$( cd -P "$( dirname "$SOURCE" )" >/dev/null 2>&1 && pwd )"
  SOURCE="$(readlink "$SOURCE")"
  [[ $SOURCE != /* ]] && SOURCE="$DIR/$SOURCE" # if $SOURCE was a relative symlink, we need to resolve it relative to the path where the symlink file was located
done
BASE_DIR="$( cd -P "$( dirname "$SOURCE" )" >/dev/null 2>&1 && pwd )"
CMAKE=cmake #/snap/bin/cmake

mkdir -p ${BASE_DIR}/helper_bin/

CROSS_COMPILER_URL=https://sourceforge.net/projects/raspberry-pi-cross-compilers/files/Raspberry%20Pi%20GCC%20Cross-Compiler%20Toolchains/Buster/GCC%208.3.0/Raspberry%20Pi%203A%2B%2C%203B%2B%2C%204/cross-gcc-8.3.0-pi_3%2B.tar.gz/download
CROSS_COMPILER_DIR=${BASE_DIR}/helper_bin/cross_compiler

SCONS_LOCAL_URL=http://prdownloads.sourceforge.net/scons/scons-local-3.1.1.tar.gz
SCONS_LOCAL_DIR=${BASE_DIR}/helper_bin/scons_local

THIRD_PARTY=${BASE_DIR}/third_party_libs

ARMCOMPUTE_GIT_URL=https://github.com/ARM-software/ComputeLibrary.git
ARMCOMPUTE_TAG=v20.05
ARMCOMPUTE_DIR=${THIRD_PARTY}/arm_compute_dir

OPENBLAS_GIT_URL="https://github.com/xianyi/OpenBLAS.git"
OPENBLAS_DIR=${THIRD_PARTY}/OpenBLAS


LIBND4J_SRC_DIR=${BASE_DIR}

LIBND4J_BUILD_DIR=${BASE_DIR}/build_pi

# for some downloads
XRTACT_STRIP="--strip-components=1"

HAS_ARMCOMPUTE=1
mkdir -p ${BASE_DIR}
mkdir -p ${THIRD_PARTY}

# change directory to base
cd $BASE_DIR

function message {
    echo "BUILDER:::: ${@}"
}


function check_requirements {
    for i in "${@}"
    do
        if [ ! -e "$i" ]; then
            message "missing: ${i}"
            exit -2
        fi
    done
}

function download_extract {
    # $1 is the url, $2 is the target dir, $3 is an optional extract argument
    if [ ! -f ${2}_file ]; then
        message "download"
        wget --quiet --show-progress -O ${2}_file ${1}
    fi

    message "extract"
    # extract
    mkdir -p ${2}
    command="tar -xzf ${2}_file --directory=${2} ${3} "
    message $command
    $command

    check_requirements "${2}"
}

function git_check {
    # $1 is the url, $2 is the target dir, $3 is an optional tag or branch
    command="git clone --quiet ${1} ${2}"
    message "$command"
    $command
    if [ -n "$3" ]; then
        cd ${2}
        command="git checkout ${3}"
        message "$command"
        $command
        cd ${BASE_DIR}
    fi
    check_requirements "${2}"
}


if [ ! -d ${CROSS_COMPILER_DIR} ]; then
    # out file
    message "download CROSS_COMPILER"
    download_extract ${CROSS_COMPILER_URL} ${CROSS_COMPILER_DIR} ${XRTACT_STRIP}
fi

# useful exports
export PI_FOLDER=${CROSS_COMPILER_DIR}
export RPI_BIN=${PI_FOLDER}/bin/arm-linux-gnueabihf
export PI_SYS_ROOT=${PI_FOLDER}/arm-linux-gnueabihf/libc
export LD_LIBRARY_PATH=${PI_FOLDER}/lib:$LD_LIBRARY_PATH
export CC=${RPI_BIN}-gcc
export FC=${RPI_BIN}-gfortran
export CXX=${RPI_BIN}-g++
export CPP=${RPI_BIN}-cpp
export RANLIB=${RPI_BIN}-gcc-ranlib
export LD="${RPI_BIN}-ld"
export AR="${RPI_BIN}-ar"


# let's build OpenBLAS
if [ ! -d "${OPENBLAS_DIR}" ]; then
    message "download OpenBLAS"
    git_check "${OPENBLAS_GIT_URL}" "${OPENBLAS_DIR}"
fi

if [ ! -f "${THIRD_PARTY}/lib/libopenblas.so" ]; then
    message "build and install OpenBLAS"
    cd ${OPENBLAS_DIR}

    command="make TARGET=${BLAS_TARGET_NAME} HOSTCC=gcc CC=${CC} USE_THREAD=0 NOFORTRAN=1 CFLAGS=--sysroot=${PI_SYS_ROOT} LDFLAGS=\"-L${PI_SYS_ROOT}/../lib/ -lm\" &>/dev/null"
    message $command
    eval $command
    message "install it"
    command="make PREFIX=${THIRD_PARTY} install"
    message $command
    $command
    cd $BASE_DIR
fi
check_requirements ${THIRD_PARTY}/lib/libopenblas.so


if [ ! -d ${SCONS_LOCAL_DIR} ]; then
    # out file
    message "download Scons local"
    download_extract ${SCONS_LOCAL_URL} ${SCONS_LOCAL_DIR}
fi
check_requirements ${SCONS_LOCAL_DIR}/scons.py


if [ ! -d "${ARMCOMPUTE_DIR}" ]; then
    message "download ArmCompute Source"
    git_check ${ARMCOMPUTE_GIT_URL} "${ARMCOMPUTE_DIR}" "tags/${ARMCOMPUTE_TAG}"
fi

# build armcompute
if [ ! -f "${ARMCOMPUTE_DIR}/build/libarm_compute-static.a" ]; then
    message "build arm compute"
    cd ${ARMCOMPUTE_DIR}
    command="CC=gcc CXX=g++ python3 ${SCONS_LOCAL_DIR}/scons.py Werror=1 -j$(nproc) toolchain_prefix=${RPI_BIN}- debug=${ARMCOMPUTE_DEBUG} neon=1 opencl=0 extra_cxx_flags=-fPIC os=linux build=cross_compile arch=${ARMCOMPUTE_TARGET} &>/dev/null"
    message $command
    eval $command
    cd ${BASE_DIR}
fi
check_requirements "${ARMCOMPUTE_DIR}/build/libarm_compute-static.a" "${ARMCOMPUTE_DIR}/build/libarm_compute_core-static.a"


message "build cmake for LIBND4J. output: ${LIBND4J_BUILD_DIR}"

TOOLCHAIN=${LIBND4J_SRC_DIR}/cmake/rpi.cmake
cmake_cmd="${CMAKE} -G \"Unix Makefiles\" -B${LIBND4J_BUILD_DIR} -S${LIBND4J_SRC_DIR} -DCMAKE_BUILD_TYPE=${LIBND4J_BUILD_MODE} -DCMAKE_TOOLCHAIN_FILE=${TOOLCHAIN} -DCMAKE_VERBOSE_MAKEFILE:BOOL=ON -DSD_ALL_OPS=true -DSD_CPU=true -DSD_LIBRARY_NAME=nd4jcpu -DSD_BUILD_TESTS=ON -DSD_ARM_BUILD=true -DOPENBLAS_PATH=${THIRD_PARTY} -DSD_ARCH=${TARGET} -DARMCOMPUTE_ROOT=${ARMCOMPUTE_DIR} -DHELPERS_armcompute=${HAS_ARMCOMPUTE}"
message $cmake_cmd
eval $cmake_cmd

# build
message "let's build"

cd ${LIBND4J_BUILD_DIR}
make -j $(nproc)
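
# Example invocation (the file name is hypothetical; run from the libnd4j source root):
#   bash pi_build.sh
# Cross-compiled artifacts land in ${LIBND4J_BUILD_DIR} (build_pi by default).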

@@ -52,14 +52,19 @@ elseif(WIN32)
         set(CMAKE_CXX_FLAGS " -fPIC")
     endif()
 else()
-    set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3")
     set(CMAKE_CXX_FLAGS " -fPIC")
+    set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3")
+    IF(${SD_ARCH} MATCHES "arm*")
+        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=${SD_ARCH}")
+    else()
+        set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3")
     if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64*")
         set(CMAKE_CXX_FLAGS " ${CMAKE_CXX_FLAGS} -mcpu=native")
     else()
         set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native -mtune=native")
     endif()
+    endif()
 if (SD_CPU AND SD_SANITIZE)
     set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -fsanitize=address")
 else()
@@ -130,7 +135,7 @@ if (SD_CPU)
     endif()

     add_executable(runtests ${TEST_SOURCES})
-    target_link_libraries(runtests samediff_obj ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES} ${CPU_FEATURES} gtest gtest_main)
+    target_link_libraries(runtests samediff_obj ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES} ${CPU_FEATURES} ${ARMCOMPUTE_LIBRARIES} gtest gtest_main)
 elseif(SD_CUDA)

     add_executable(runtests ${TEST_SOURCES})

@@ -1113,7 +1113,10 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_6) {
     ASSERT_EQ(ND4J_STATUS_OK, result.status());

     auto z = result.at(0);
+#if 0
+    exp.printIndexedBuffer("Expected");
+    z->printIndexedBuffer("Z");
+#endif
     ASSERT_TRUE(exp.isSameShape(z));
     ASSERT_TRUE(exp.equalsTo(z));

@@ -1132,7 +1135,10 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_7) {
     ASSERT_EQ(ND4J_STATUS_OK, result.status());

     auto z = result.at(0);
+#if 0
+    exp.printIndexedBuffer("Expected");
+    z->printIndexedBuffer("Z");
+#endif
     ASSERT_TRUE(exp.isSameShape(z));
     ASSERT_TRUE(exp.equalsTo(z));

@@ -1151,7 +1157,10 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_8) {
     ASSERT_EQ(ND4J_STATUS_OK, result.status());

     auto z = result.at(0);
+#if 0
+    exp.printIndexedBuffer("Expected");
+    z->printIndexedBuffer("Z");
+#endif
     ASSERT_TRUE(exp.isSameShape(z));
     ASSERT_TRUE(exp.equalsTo(z));
 }

@@ -1204,7 +1213,10 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_10) {
     auto* output = results.at(0);

     ASSERT_EQ(Status::OK(), results.status());
+#if 0
+    expOutput.printIndexedBuffer("expOutput");
+    output->printIndexedBuffer("output");
+#endif
     ASSERT_TRUE(expOutput.isSameShape(output));
     ASSERT_TRUE(expOutput.equalsTo(output));
 }

@@ -244,7 +244,8 @@ TEST_F(DeclarableOpsTests19, test_threshold_encode_decode) {
 #ifdef _RELEASE
 TEST_F(DeclarableOpsTests19, test_threshold_encode_decode_2) {
     // [2,1,135079944,1,1,8192,1,99]
-    auto initial = NDArrayFactory::create<float>('c', {1, 135079944});
+    constexpr int sizeX = 10*1000*1000;
+    auto initial = NDArrayFactory::create<float>('c', {1, sizeX});
     initial = 1.0f;
     auto exp = initial.dup();
     auto neg = initial.like();

|
@ -254,7 +255,7 @@ TEST_F(DeclarableOpsTests19, test_threshold_encode_decode_2) {
|
||||||
auto enc_result = enc.evaluate({&initial}, {0.5f});
|
auto enc_result = enc.evaluate({&initial}, {0.5f});
|
||||||
auto encoded = enc_result.at(1);
|
auto encoded = enc_result.at(1);
|
||||||
|
|
||||||
ASSERT_EQ(135079944 + 4, encoded->lengthOf());
|
ASSERT_EQ(sizeX + 4, encoded->lengthOf());
|
||||||
ASSERT_NE(exp, initial);
|
ASSERT_NE(exp, initial);
|
||||||
/*
|
/*
|
||||||
for (int e = 0; e < initial.lengthOf(); e++) {
|
for (int e = 0; e < initial.lengthOf(); e++) {
|
||||||
|
@@ -419,3 +420,4 @@ TEST_F(DeclarableOpsTests19, test_squeeze_1) {
     auto status = op.execute({&x}, {&e}, {axis});
     ASSERT_EQ(Status::OK(), status);
 }

@@ -1557,8 +1557,6 @@ TEST_F(DeclarableOpsTests5, trace_test1) {
     // exp.printIndexedBuffer("EXP TRACE");
     // output->printIndexedBuffer("OUT TRACE");
     ASSERT_TRUE(exp.equalsTo(output));
-
-
 }

 //////////////////////////////////////////////////////////////////////

@@ -1575,8 +1573,6 @@ TEST_F(DeclarableOpsTests5, trace_test2) {
     ASSERT_EQ(Status::OK(), results.status());
     ASSERT_TRUE(exp.isSameShape(output));
     ASSERT_TRUE(exp.equalsTo(output));
-
-
 }

 //////////////////////////////////////////////////////////////////////

@@ -1593,8 +1589,6 @@ TEST_F(DeclarableOpsTests5, trace_test3) {
     ASSERT_EQ(Status::OK(), results.status());
     ASSERT_TRUE(exp.isSameShape(output));
     ASSERT_TRUE(exp.equalsTo(output));
-
-
 }

 //////////////////////////////////////////////////////////////////////

@@ -1611,8 +1605,6 @@ TEST_F(DeclarableOpsTests5, trace_test4) {
     ASSERT_EQ(Status::OK(), results.status());
     ASSERT_TRUE(exp.isSameShape(output));
     ASSERT_TRUE(exp.equalsTo(output));
-
-
 }

 //////////////////////////////////////////////////////////////////////

@@ -1629,8 +1621,6 @@ TEST_F(DeclarableOpsTests5, trace_test5) {
     ASSERT_EQ(Status::OK(), results.status());
     ASSERT_TRUE(exp.isSameShape(output));
     ASSERT_TRUE(exp.equalsTo(output));
-
-
 }

 //////////////////////////////////////////////////////////////////////

@@ -1638,22 +1628,15 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test1) {

     auto input = NDArrayFactory::create<double>('c', {2, 2, 2});
     input.linspace(1);
+    NDArray exp1 = input.dup();
+    NDArray exp2('c', {2,2,2}, {5,6,7,8, 1,2,3,4}, sd::DataType::DOUBLE);

     sd::ops::random_shuffle op;
     auto results = op.evaluate({&input});
     auto output = results.at(0);

-    bool haveZeros = false;
-    for(int i = 0; i < output->lengthOf(); ++i)
-        if(output->e<float>(i) == (float)0.)
-            haveZeros = true;
-
     ASSERT_EQ(Status::OK(), results.status());
-    ASSERT_TRUE(input.isSameShape(output));
-    ASSERT_TRUE(!input.equalsTo(output));
-    ASSERT_TRUE(!haveZeros);
+    ASSERT_TRUE(output->equalsTo(exp1) || output->equalsTo(exp2));

 }

 //////////////////////////////////////////////////////////////////////

@@ -1661,16 +1644,14 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test2) {

     auto input = NDArrayFactory::create<double>('c', {1, 3, 2});
     input.linspace(1);
+    NDArray exp1 = input.dup();

     sd::ops::random_shuffle op;
     auto results = op.evaluate({&input});
     auto output = results.at(0);

     ASSERT_EQ(Status::OK(), results.status());
-    ASSERT_TRUE(input.isSameShape(output));
-    ASSERT_TRUE(input.equalsTo(output));
+    ASSERT_TRUE(output->equalsTo(exp1));

 }

 //////////////////////////////////////////////////////////////////////

@@ -1678,129 +1659,132 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test3) {

     auto input = NDArrayFactory::create<double>('c', {3, 2, 1});
     input.linspace(1);
+    NDArray exp1 = input.dup();
+    NDArray exp2('c', {3,2,1}, {1,2, 5,6, 3,4}, sd::DataType::DOUBLE);
+    NDArray exp3('c', {3,2,1}, {3,4, 1,2, 5,6}, sd::DataType::DOUBLE);
+    NDArray exp4('c', {3,2,1}, {3,4, 5,6, 1,2}, sd::DataType::DOUBLE);
+    NDArray exp5('c', {3,2,1}, {5,6, 1,2, 3,4}, sd::DataType::DOUBLE);
+    NDArray exp6('c', {3,2,1}, {5,6, 3,4, 1,2}, sd::DataType::DOUBLE);

     sd::ops::random_shuffle op;
-    auto results = op.evaluate({&input});
-    auto output = results.at(0);
-
-    bool haveZeros = false;
-    for(int i = 0; i < output->lengthOf(); ++i)
-        if(output->e<float>(i) == (float)0.)
-            haveZeros = true;
-
-    ASSERT_EQ(Status::OK(), results.status());
-    ASSERT_TRUE(input.isSameShape(output));
-    ASSERT_TRUE(!input.equalsTo(output));
-    ASSERT_TRUE(!haveZeros);
-}
-
-//////////////////////////////////////////////////////////////////////
-TEST_F(DeclarableOpsTests5, random_shuffle_test04) {
-    auto input = NDArrayFactory::create<double>('c', {4});
-    input.linspace(1);
-
-    sd::ops::random_shuffle op;
-    //NDArray* output;
     auto results = op.evaluate({&input}, {}, {}, {}, {}, true);

     ASSERT_EQ(Status::OK(), results.status());
-    auto output = &input; //results.at(0);
-    bool haveZeros = false;
-    for(int i = 0; i < output->lengthOf(); ++i)
-        if(output->e<float>(i) == (float)0.)
-            haveZeros = true;
-
-    ASSERT_TRUE(input.isSameShape(output));
-    //ASSERT_TRUE(!input.equalsTo(output));
-    ASSERT_TRUE(!haveZeros);
+    ASSERT_TRUE(input.equalsTo(exp1) || input.equalsTo(exp2) || input.equalsTo(exp3)
+             || input.equalsTo(exp4) || input.equalsTo(exp5) || input.equalsTo(exp6));

 }

 //////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests5, random_shuffle_test4) {
-    auto input = NDArrayFactory::create<double>('c', {4});
+    auto input = NDArrayFactory::create<double>('c', {3, 2, 1});
     input.linspace(1);
+    NDArray exp1 = input.dup();
+    NDArray exp2('c', {3,2,1}, {1,2, 5,6, 3,4}, sd::DataType::DOUBLE);
+    NDArray exp3('c', {3,2,1}, {3,4, 1,2, 5,6}, sd::DataType::DOUBLE);
+    NDArray exp4('c', {3,2,1}, {3,4, 5,6, 1,2}, sd::DataType::DOUBLE);
+    NDArray exp5('c', {3,2,1}, {5,6, 1,2, 3,4}, sd::DataType::DOUBLE);
+    NDArray exp6('c', {3,2,1}, {5,6, 3,4, 1,2}, sd::DataType::DOUBLE);

     sd::ops::random_shuffle op;
-    //NDArray* output;
     auto results = op.evaluate({&input});
-    ASSERT_EQ(Status::OK(), results.status());
     auto output = results.at(0);
-    bool haveZeros = false;
-    for(int i = 0; i < output->lengthOf(); ++i)
-        if(output->e<float>(i) == (float)0.)
-            haveZeros = true;
-
-    ASSERT_TRUE(input.isSameShape(output));
-    //ASSERT_TRUE(!input.equalsTo(output));
-    ASSERT_TRUE(!haveZeros);
+
+    ASSERT_EQ(Status::OK(), results.status());
+    ASSERT_TRUE(output->equalsTo(exp1) || output->equalsTo(exp2) || output->equalsTo(exp3)
+             || output->equalsTo(exp4) || output->equalsTo(exp5) || output->equalsTo(exp6));
 }

 //////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests5, random_shuffle_test5) {
-    auto input = NDArrayFactory::create<double>('c', {4,1});
+    auto input = NDArrayFactory::create<int>('c', {4});
     input.linspace(1);

     sd::ops::random_shuffle op;
-    auto results = op.evaluate({&input});
+    auto results = op.evaluate({&input}, {}, {}, {}, {}, false);
     auto output = results.at(0);
-
-    bool haveZeros = false;
-    for(int i = 0; i < output->lengthOf(); ++i)
-        if(output->e<float>(i) == (float)0.)
-            haveZeros = true;
+    // output->printBuffer();

     ASSERT_EQ(Status::OK(), results.status());
-    ASSERT_TRUE(input.isSameShape(output));
-    ASSERT_TRUE(!input.equalsTo(output));
-    ASSERT_TRUE(!haveZeros);
+    // ASSERT_TRUE(!output->equalsTo(input));
+
+    bool hasDublicates = false;
+    for(int i = 0; i < output->lengthOf() - 1; ++i)
+        for(int j = i+1; j < output->lengthOf(); ++j)
+            if(output->t<int>(i) == output->t<int>(j)) {
+                hasDublicates = true;
+                i = output->lengthOf();
+                break;
+            }
+    ASSERT_TRUE(!hasDublicates);
 }

 //////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests5, random_shuffle_test6) {
-    auto input = NDArrayFactory::create<double>('c', {4,1,1});
+    auto input = NDArrayFactory::create<int>('c', {4,1,1});
     input.linspace(1);

     sd::ops::random_shuffle op;
-    auto results = op.evaluate({&input});
+    auto results = op.evaluate({&input}, {}, {}, {}, {}, false);
     auto output = results.at(0);

-    bool haveZeros = false;
-    for(int i = 0; i < output->lengthOf(); ++i)
-        if(output->e<float>(i) == (float)0.)
-            haveZeros = true;
-
     ASSERT_EQ(Status::OK(), results.status());
-    ASSERT_TRUE(input.isSameShape(output));
-    ASSERT_TRUE(!input.equalsTo(output));
-    ASSERT_TRUE(!haveZeros);
+    // ASSERT_TRUE(!output->equalsTo(input));
+
+    bool hasDublicates = false;
+    for(int i = 0; i < output->lengthOf() - 1; ++i)
+        for(int j = i+1; j < output->lengthOf(); ++j)
+            if(output->t<int>(i) == output->t<int>(j)) {
+                hasDublicates = true;
+                i = output->lengthOf();
+                break;
+            }
+    ASSERT_TRUE(!hasDublicates);
 }

 //////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests5, random_shuffle_test7) {
-    auto input = NDArrayFactory::create<double>('c', {1,4});
+    auto input = NDArrayFactory::create<int>('c', {16010});
     input.linspace(1);
-    auto exp = NDArrayFactory::create<double>('c', {1,4}, {1, 2, 3, 4});

     sd::ops::random_shuffle op;
-    auto results = op.evaluate({&input});
+    auto results = op.evaluate({&input}, {}, {}, {}, {}, false);
     auto output = results.at(0);
+    // output->printBuffer();

     ASSERT_EQ(Status::OK(), results.status());
-    ASSERT_TRUE(input.isSameShape(output));
-    ASSERT_TRUE(input.equalsTo(output));
+    ASSERT_TRUE(!output->equalsTo(input));
+
+    auto vec1 = input.getBufferAsVector<int>();
+    auto vec2 = output->getBufferAsVector<int>();
+    std::sort(vec2.begin(), vec2.end());
+    ASSERT_TRUE(std::equal(vec1.begin(), vec1.end(), vec2.begin()));
+}
+
+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests5, random_shuffle_test8) {
+    auto input = NDArrayFactory::create<int>('c', {1,4,1});
+    input.linspace(1);
+    NDArray inCopy = input.dup();
+
+    sd::ops::random_shuffle op;
+    auto results = op.evaluate({&input}, {}, {}, {}, {}, false);
+    ASSERT_EQ(Status::OK(), results.status());
+    ASSERT_TRUE(input.equalsTo(inCopy));
+
+}
+
+TEST_F(DeclarableOpsTests5, random_shuffle_test9) {
+
+    auto x = NDArrayFactory::create<int>('c', {4}, {1, 2, 3, 4});
+    auto z = x.ulike();
+
+    sd::ops::random_shuffle op;
+    auto status = op.execute({&x}, {&z});
+    ASSERT_EQ(Status::OK(), status);
+
+    auto vec = z.getBufferAsVector<int>();
+    std::sort(vec.begin(), vec.end());
+    ASSERT_EQ(std::vector<int>({1, 2, 3, 4}), vec);
 }

 ////////////////////////////////////////////////////////////////////////////////////////

@@ -251,11 +251,10 @@ TEST_F(DeclarableOpsTests9, concat_test1) {
     auto result = op.evaluate({&x0, &x1, &x2}, {}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, result.status());
     auto output = result.at(0);
-    // output->printCurrentBuffer<float>(false);

     ASSERT_TRUE(exp.isSameShape(output));
     ASSERT_TRUE(exp.equalsTo(output));

 }

 ////////////////////////////////////////////////////////////////////////////////

@@ -317,7 +317,7 @@ void fill_random(sd::NDArray& arr) {
     }
 }

 void testLegacy(bool random) {
 #if 0
 int bases[] = { 3, 2, 4, 5, 7 };

@@ -364,7 +364,7 @@ int k = 4;
 #endif
 auto dim = NDArrayFactory::create<int>(dimension);

 #if 1
 nd4j_printf("C(N:%d K:%d) \n", N, k);
 dim.printIndexedBuffer("Dimension");
 for (int xind : dimension) {

@@ -385,7 +385,7 @@ for (int e = 0; e < Loop; e++) {
     auto outerTime = std::chrono::duration_cast<std::chrono::microseconds>(timeEnd - timeStart).count();
     values.emplace_back(outerTime);
 }

 std::sort(values.begin(), values.end());

 nd4j_printf("Time: %lld us;\n", values[values.size() / 2]);

@@ -411,7 +411,7 @@ void testNewReduction(bool random, bool checkCorrectness = false , char order ='
 constexpr int N = 5;

 #endif

 for (int i = 0; i < N; i++) {
     arr_dimensions.push_back(bases[i]);
 }

@@ -451,7 +451,7 @@ void testNewReduction(bool random, bool checkCorrectness = false , char order ='
 #endif
 auto dim = NDArrayFactory::create<int>(dimension);

 #if 1
 nd4j_printf("C(N:%d K:%d) \n", N, k);
 dim.printIndexedBuffer("Dimension");
 for (int xind : dimension) {

@@ -477,14 +477,14 @@ void testNewReduction(bool random, bool checkCorrectness = false , char order ='
     //check for the correctness
     NDArray exp = output_bases.size() > 0 ? NDArrayFactory::create<Nd4jLong>('c', output_bases) : NDArrayFactory::create<Nd4jLong>(0);
     original_argmax(x, dimension, exp);


 #if 0// defined(DEBUG)
     x.printIndexedBuffer("X");
     exp.printIndexedBuffer("Expected");
     z->printIndexedBuffer("Z");
 #endif

     ASSERT_TRUE(exp.isSameShape(z));
     ASSERT_TRUE(exp.equalsTo(z));
 }

@@ -505,7 +505,7 @@ TEST_F(PlaygroundTests, ArgMaxPerfLinspace) {
     testNewReduction(false, test_corr);
 }
 #endif

 TEST_F(PlaygroundTests, ArgMaxPerfRandom) {
     testNewReduction(true, test_corr);
 }

@@ -513,7 +513,7 @@ TEST_F(PlaygroundTests, ArgMaxPerfRandom) {
 TEST_F(PlaygroundTests, ArgMaxPerfRandomOrderF) {
     testNewReduction(true, test_corr, 'f');
 }

 #if !defined(DEBUG)
 TEST_F(PlaygroundTests, ArgMaxPerfLegacyLinspace) {
     testLegacy(false);

@@ -1062,39 +1062,6 @@ TEST_F(PlaygroundTests, my) {
     delete variableSpace;
 }

-TEST_F(PlaygroundTests, my) {
-
-    int N = 100;
-    int bS=16, iH=128,iW=128, iC=32,oC=64, kH=4,kW=4, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
-    int oH=128,oW=128;
-
-    int paddingMode = 1; // 1-SAME, 0-VALID;
-    int dataFormat = 1; // 1-NHWC, 0-NCHW
-
-    // NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32);
-    // NDArray output('c', {bS, oC, oH, oW}, sd::DataType::FLOAT32);
-    NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32);
-    NDArray output('c', {bS, oH, oW, oC}, sd::DataType::FLOAT32);
-    // NDArray weights('c', {kH, kW, iC, oC}, sd::DataType::FLOAT32); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW]
-    NDArray weights('c', {oC, iC, kH, kW}, sd::DataType::FLOAT32);
-    NDArray bias('c', {oC}, sd::DataType::FLOAT32);
-
-    input = 5.;
-    weights = 3.;
-    bias = 1.;
-
-    sd::ops::conv2d op;
-    auto err = op.execute({&input, &weights, &bias}, {&output}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
-
-    auto timeStart = std::chrono::system_clock::now();
-    for (int i = 0; i < N; ++i)
-        err = op.execute({&input, &weights, &bias}, {&output}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
-    auto timeEnd = std::chrono::system_clock::now();
-    auto time = std::chrono::duration_cast<std::chrono::microseconds> ((timeEnd - timeStart) / N).count();
-
-    printf("time: %i \n", time);
-}
-
 ///////////////////////////////////////////////////////////////////
 TEST_F(PlaygroundTests, lstmLayerCellBp_1) {

@@ -1690,6 +1657,52 @@ TEST_F(DeclarableOpsTests15, gru_bp_1) {
     const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
 }
+
+#include<ops/declarable/helpers/transforms.h>
+//////////////////////////////////////////////////////////////////////
+TEST_F(PlaygroundTests, my) {
+
+    const int N = 10;
+
+    NDArray input('c', {8000000}, sd::DataType::INT32);
+    input.linspace(1);
+    NDArray output = input.dup();
+
+    sd::graph::RandomGenerator rng;
+
+    sd::ops::helpers::randomShuffle(input.getContext(), input, output, rng, true);
+
+    // auto timeStart = std::chrono::system_clock::now();
+    // for (int i = 0; i < N; ++i)
+    //     sd::ops::helpers::randomShuffle(input.getContext(), input, output, rng, true);
+    // auto timeEnd = std::chrono::system_clock::now();
+    // auto time = std::chrono::duration_cast<std::chrono::microseconds> ((timeEnd - timeStart) / N).count();
+    // printf("time: %i \n", time);
+
+    // bool hasDublicates = false;
+    // for(int i = 0; i < output.lengthOf() - 1; ++i)
+    //     for(int j = i+1; j < output.lengthOf(); ++j)
+    //         if(output.t<int>(i) == output.t<int>(j)) {
+    //             hasDublicates = true;
+    //             i = output.lengthOf();
+    //             break;
+    //         }
+
+    ASSERT_TRUE(!input.equalsTo(output));
+
+    bool hasDublicates = false;
+    for(int i = 0; i < input.lengthOf() - 1; ++i)
+        for(int j = i+1; j < input.lengthOf(); ++j)
+            if(input.t<int>(i) == input.t<int>(j)) {
+                hasDublicates = true;
+                i = input.lengthOf();
+                break;
+            }
+    ASSERT_TRUE(!hasDublicates);
+}
+
+}
 */

@@ -1,93 +0,0 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

#ifndef LIBND4J_SESSIONLOCALTESTS_H
#define LIBND4J_SESSIONLOCALTESTS_H

#include "testlayers.h"
#include <array/NDArrayFactory.h>
#include <graph/SessionLocalStorage.h>

using namespace sd::graph;

class SessionLocalTests : public testing::Test {
public:
};

TEST_F(SessionLocalTests, BasicTests_1) {
    VariableSpace variableSpace;
    SessionLocalStorage storage(&variableSpace, nullptr);

    if (omp_get_max_threads() <= 1)
        return;

    PRAGMA_OMP_PARALLEL_FOR_THREADS(4)
    for (int e = 0; e < 4; e++) {
        storage.startSession();
    }

    ASSERT_EQ(4, storage.numberOfSessions());

    PRAGMA_OMP_PARALLEL_FOR_THREADS(4)
    for (int e = 0; e < 4; e++) {
        storage.endSession();
    }

    ASSERT_EQ(0, storage.numberOfSessions());
}


TEST_F(SessionLocalTests, BasicTests_2) {
    VariableSpace variableSpace;
    SessionLocalStorage storage(&variableSpace, nullptr);

    if (omp_get_max_threads() <= 1)
        return;

    auto alpha = sd::NDArrayFactory::create_<float>('c', {5,5});
    alpha->assign(0.0);

    variableSpace.putVariable(-1, alpha);

    PRAGMA_OMP_PARALLEL_FOR_THREADS(4)
    for (int e = 0; e < 4; e++) {
        storage.startSession();

        auto varSpace = storage.localVariableSpace();

        auto arr = varSpace->getVariable(-1)->getNDArray();
        arr->applyScalar(sd::scalar::Add, (float) e+1, *arr);
    }

    float lastValue = 0.0f;
    for (int e = 1; e <= 4; e++) {
        auto varSpace = storage.localVariableSpace((Nd4jLong) e);

        auto arr = varSpace->getVariable(-1)->getNDArray();

        //nd4j_printf("Last value: %f; Current value: %f\n", lastValue, arr->e(0));

        ASSERT_NE(lastValue, arr->e<float>(0));
        lastValue = arr->e<float>(0);
    }
}

#endif //LIBND4J_SESSIONLOCALTESTS_H

|
@@ -45,6 +45,21 @@ if ("${BUILD_MKLDNN}")
     set(MKLDNN dnnl)
 endif()

+if (${HELPERS_armcompute})
+    find_package(ARMCOMPUTE REQUIRED)
+
+    if(ARMCOMPUTE_FOUND)
+        message("Found ARMCOMPUTE: ${ARMCOMPUTE_LIBRARIES}")
+        set(HAVE_ARMCOMPUTE 1)
+        # Add preprocessor definition for ARM Compute NEON
+        add_definitions(-DARMCOMPUTENEON_ENABLED)
+        # build our library with neon support
+        set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mfpu=neon")
+        include_directories(${ARMCOMPUTE_INCLUDE})
+    endif()
+endif()
+
 # Download and unpack flatbuffers at configure time
 configure_file(../../CMakeLists.txt.in flatbuffers-download/CMakeLists.txt)
 execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" .

@@ -217,6 +232,10 @@ if ("${BUILD_MKLDNN}")
     file(GLOB_RECURSE CUSTOMOPS_PLATFORM_SOURCES false ../../include/ops/declarable/platform/mkldnn/*.cpp)
 endif()

+if(HAVE_ARMCOMPUTE)
+    file(GLOB_RECURSE CUSTOMOPS_ARMCOMPUTE_SOURCES false ../include/ops/declarable/platform/armcompute/*.cpp ../include/ops/declarable/platform/armcompute/armcomputeUtils.h)
+endif()
+
 message("CPU backend")
 add_definitions(-D__CPUBLAS__=true)

@@ -276,8 +295,9 @@ endforeach(TMP_PATH)

 add_executable(runtests ${LOOPS_SOURCES} ${LEGACY_SOURCES} ${EXEC_SOURCES} ${HELPERS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
-        ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_PLATFORM_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES}
+        ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_PLATFORM_SOURCES}
+        ${CUSTOMOPS_ARMCOMPUTE_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES}
         ${OPS_SOURCES} ${TEST_SOURCES} ${PERF_SOURCES})

-target_link_libraries(runtests gtest ${MKLDNN} gtest_main ${BLAS_LIBRARIES})
+target_link_libraries(runtests gtest ${MKLDNN} ${ARMCOMPUTE_LIBRARIES} gtest_main ${BLAS_LIBRARIES})
@@ -25,7 +25,7 @@
     </parent>
     <modelVersion>4.0.0</modelVersion>

-    <groupId>org.eclipse</groupId>
+    <groupId>org.nd4j</groupId>
     <artifactId>python4j-parent</artifactId>
     <packaging>pom</packaging>
     <modules>

@@ -41,10 +41,14 @@
             <scope>provided</scope>
         </dependency>
         <dependency>
+            <groupId>org.slf4j</groupId>
+            <artifactId>slf4j-api</artifactId>
+            <version>1.6.6</version>
+        </dependency> <dependency>
             <groupId>ch.qos.logback</groupId>
             <artifactId>logback-classic</artifactId>
             <version>${logback.version}</version>
             <scope>test</scope>
         </dependency>
         <dependency>
             <groupId>junit</groupId>

@@ -62,5 +66,10 @@
             <artifactId>jsr305</artifactId>
             <version>3.0.2</version>
         </dependency>
+        <dependency>
+            <groupId>org.slf4j</groupId>
+            <artifactId>slf4j-api</artifactId>
+            <version>1.6.6</version>
+        </dependency>
     </dependencies>
 </project>
@@ -21,7 +21,7 @@
     xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
     <parent>
         <artifactId>python4j-parent</artifactId>
-        <groupId>org.eclipse</groupId>
+        <groupId>org.nd4j</groupId>
         <version>1.0.0-SNAPSHOT</version>
     </parent>
     <packaging>jar</packaging>

@@ -39,6 +39,5 @@
             <artifactId>cpython-platform</artifactId>
             <version>${cpython-platform.version}</version>
         </dependency>
-
     </dependencies>
 </project>
@@ -15,7 +15,7 @@
  ******************************************************************************/


-package org.eclipse.python4j;
+package org.nd4j.python4j;

 import org.bytedeco.cpython.PyObject;
@@ -14,13 +14,15 @@
  * SPDX-License-Identifier: Apache-2.0
  ******************************************************************************/

-package org.eclipse.python4j;
+package org.nd4j.python4j;

 import javax.lang.model.SourceVersion;

+import java.io.Closeable;
 import java.util.HashSet;
 import java.util.Set;
+import java.util.UUID;
 import java.util.concurrent.atomic.AtomicBoolean;

 /**

@@ -46,6 +48,31 @@ public class PythonContextManager {
         init();
     }

+    public static class Context implements Closeable{
+        private final String name;
+        private final String previous;
+        private final boolean temp;
+        public Context(){
+            name = "temp_" + UUID.randomUUID().toString().replace("-", "_");
+            temp = true;
+            previous = getCurrentContext();
+            setContext(name);
+        }
+        public Context(String name){
+            this.name = name;
+            temp = false;
+            previous = getCurrentContext();
+            setContext(name);
+        }
+
+        @Override
+        public void close(){
+            setContext(previous);
+            if (temp) deleteContext(name);
+        }
+    }
+
     private static void init() {
         if (init.get()) return;
         new PythonExecutioner();

@@ -76,7 +103,18 @@ public class PythonContextManager {
     }

     private static boolean validateContextName(String s) {
-        return SourceVersion.isIdentifier(s) && !s.startsWith(COLLAPSED_KEY);
+        for (int i=0; i<s.length(); i++){
+            char c = s.toLowerCase().charAt(i);
+            if (i == 0){
+                if (c >= '0' && c <= '9'){
+                    return false;
+                }
+            }
+            if (!(c=='_' || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9'))){
+                return false;
+            }
+        }
+        return true;
     }

     private static String getContextPrefix(String contextName) {

@@ -190,6 +228,7 @@ public class PythonContextManager {
         setContext(tempContext);
         deleteContext(currContext);
         setContext(currContext);
+        deleteContext(tempContext);
     }

     /**
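The new Context class implements Closeable, so a caller can scope work to a throwaway interpreter context with try-with-resources. A minimal sketch of the intended usage (illustrative only, not part of this commit; it assumes the no-arg constructor's temporary-context behavior shown in the diff above):

    // Runs code in a temporary context named "temp_<uuid>"; close() restores
    // the previous context and, because temp == true, deletes the scratch one.
    try (PythonContextManager.Context ctx = new PythonContextManager.Context()) {
        PythonExecutioner.exec("x = 1 + 1");
    }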
@@ -14,7 +14,7 @@
  * SPDX-License-Identifier: Apache-2.0
  ******************************************************************************/

-package org.eclipse.python4j;
+package org.nd4j.python4j;


 /**
@@ -15,7 +15,7 @@
  ******************************************************************************/


-package org.eclipse.python4j;
+package org.nd4j.python4j;

 import org.bytedeco.cpython.PyObject;
@@ -42,7 +42,6 @@ public class PythonExecutioner {
     private final static String DEFAULT_PYTHON_PATH_PROPERTY = "org.eclipse.python4j.path";
     private final static String JAVACPP_PYTHON_APPEND_TYPE = "org.eclipse.python4j.path.append";
     private final static String DEFAULT_APPEND_TYPE = "before";

     static {
         init();
     }
@@ -55,6 +54,11 @@ public class PythonExecutioner {
         initPythonPath();
         PyEval_InitThreads();
         Py_InitializeEx(0);
+        for (PythonType type: PythonTypes.get()){
+            type.init();
+        }
+        // Constructors of custom types may contain initialization code that should
+        // run on the main the thread.
     }

     /**
@@ -110,6 +114,8 @@ public class PythonExecutioner {
         getVariables(Arrays.asList(pyVars));
     }

+
+
     /**
      * Gets the variable with the given name from the interpreter.
      *
@@ -205,9 +211,9 @@ public class PythonExecutioner {
      *
      * @return
      */
-    public static List<PythonVariable> getAllVariables() {
+    public static PythonVariables getAllVariables() {
         PythonGIL.assertThreadSafe();
-        List<PythonVariable> ret = new ArrayList<>();
+        PythonVariables ret = new PythonVariables();
         PyObject main = PyImport_ImportModule("__main__");
         PyObject globals = PyModule_GetDict(main);
         PyObject keys = PyDict_Keys(globals);
@@ -259,7 +265,7 @@ public class PythonExecutioner {
      * @param inputs
      * @return
      */
-    public static List<PythonVariable> execAndReturnAllVariables(String code, List<PythonVariable> inputs) {
+    public static PythonVariables execAndReturnAllVariables(String code, List<PythonVariable> inputs) {
         setVariables(inputs);
         simpleExec(getWrappedCode(code));
         return getAllVariables();

@@ -271,7 +277,7 @@ public class PythonExecutioner {
      * @param code
      * @return
      */
-    public static List<PythonVariable> execAndReturnAllVariables(String code) {
+    public static PythonVariables execAndReturnAllVariables(String code) {
         simpleExec(getWrappedCode(code));
         return getAllVariables();
     }
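With the return type widened to PythonVariables, callers can look results up by name instead of scanning a List. A usage sketch (illustrative only, not from the commit; variable names are invented):

    // execAndReturnAllVariables now returns PythonVariables (new class, below),
    // which adds get(String) on top of ArrayList<PythonVariable>.
    PythonVariables vars = PythonExecutioner.execAndReturnAllVariables("a = 5\nb = a * 2");
    PythonVariable b = vars.get("b");   // lookup by name instead of by index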
@@ -279,25 +285,22 @@ public class PythonExecutioner {
     private static synchronized void initPythonPath() {
         try {
             String path = System.getProperty(DEFAULT_PYTHON_PATH_PROPERTY);
+
+            List<File> packagesList = new ArrayList<>();
+            packagesList.addAll(Arrays.asList(cachePackages()));
+            for (PythonType type: PythonTypes.get()){
+                packagesList.addAll(Arrays.asList(type.packages()));
+            }
+            //// TODO: fix in javacpp
+            packagesList.add(new File(python.cachePackage(), "site-packages"));
+
+            File[] packages = packagesList.toArray(new File[0]);
+
             if (path == null) {
-                File[] packages = cachePackages();
-
-                //// TODO: fix in javacpp
-                File sitePackagesWindows = new File(python.cachePackage(), "site-packages");
-                File[] packages2 = new File[packages.length + 1];
-                for (int i = 0; i < packages.length; i++) {
-                    //System.out.println(packages[i].getAbsolutePath());
-                    packages2[i] = packages[i];
-                }
-                packages2[packages.length] = sitePackagesWindows;
-                //System.out.println(sitePackagesWindows.getAbsolutePath());
-                packages = packages2;
-                //////////
-
                 Py_SetPath(packages);
             } else {
                 StringBuffer sb = new StringBuffer();
-                File[] packages = cachePackages();
                 JavaCppPathType pathAppendValue = JavaCppPathType.valueOf(System.getProperty(JAVACPP_PYTHON_APPEND_TYPE, DEFAULT_APPEND_TYPE).toUpperCase());
                 switch (pathAppendValue) {
                     case BEFORE:
@@ -15,7 +15,7 @@
  ******************************************************************************/


-package org.eclipse.python4j;
+package org.nd4j.python4j;

 import org.bytedeco.cpython.PyObject;
 import org.bytedeco.javacpp.Pointer;
@@ -14,11 +14,10 @@
  * SPDX-License-Identifier: Apache-2.0
  ******************************************************************************/

-package org.eclipse.python4j;
+package org.nd4j.python4j;


 import org.bytedeco.cpython.PyThreadState;
-import org.omg.SendingContext.RunTime;

 import java.util.concurrent.atomic.AtomicBoolean;
@@ -90,4 +89,8 @@ public class PythonGIL implements AutoCloseable {
         PyEval_SaveThread();
         PyEval_RestoreThread(mainThreadState);
     }

+    public static boolean locked(){
+        return acquired.get();
+    }
 }
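locked() simply exposes the GIL state already tracked in the acquired flag. An illustrative sketch (not part of the commit):

    // Inside a lock() block the flag reads true; outside it reads false.
    try (PythonGIL gil = PythonGIL.lock()) {
        boolean held = PythonGIL.locked();   // true here
        PythonExecutioner.exec("print('GIL held')");
    }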
@@ -14,31 +14,34 @@
  * SPDX-License-Identifier: Apache-2.0
  ******************************************************************************/

-package org.eclipse.python4j;
+package org.nd4j.python4j;


 import lombok.Builder;
 import lombok.Data;
-import lombok.NoArgsConstructor;
+import lombok.extern.slf4j.Slf4j;

 import javax.annotation.Nonnull;
 import java.util.List;
+import java.util.concurrent.atomic.AtomicBoolean;


-@Data
-@NoArgsConstructor
 /**
  * PythonJob is the right abstraction for executing multiple python scripts
  * in a multi thread stateful environment. The setup-and-run mode allows your
  * "setup" code (imports, model loading etc) to be executed only once.
  */
+@Data
+@Slf4j
 public class PythonJob {


     private String code;
     private String name;
     private String context;
-    private boolean setupRunMode;
+    private final boolean setupRunMode;
     private PythonObject runF;
+    private final AtomicBoolean setupDone = new AtomicBoolean(false);

     static {
         new PythonExecutioner();

@@ -63,7 +66,6 @@ public class PythonJob {
         if (PythonContextManager.hasContext(context)) {
             throw new PythonException("Unable to create python job " + name + ". Context " + context + " already exists!");
         }
-        if (setupRunMode) setup();
     }


@@ -71,17 +73,18 @@ public class PythonJob {
      * Clears all variables in current context and calls setup()
      */
     public void clearState(){
-        String context = this.context;
-        PythonContextManager.setContext("main");
-        PythonContextManager.deleteContext(context);
-        this.context = context;
+        PythonContextManager.setContext(this.context);
+        PythonContextManager.reset();
+        setupDone.set(false);
         setup();
     }

     public void setup(){
+        if (setupDone.get()) return;
         try (PythonGIL gil = PythonGIL.lock()) {
             PythonContextManager.setContext(context);
             PythonObject runF = PythonExecutioner.getVariable("run");

             if (runF == null || runF.isNone() || !Python.callable(runF)) {
                 PythonExecutioner.exec(code);
                 runF = PythonExecutioner.getVariable("run");

@@ -98,10 +101,12 @@ public class PythonJob {
             if (!setupF.isNone()) {
                 setupF.call();
             }
+            setupDone.set(true);
         }
     }

     public void exec(List<PythonVariable> inputs, List<PythonVariable> outputs) {
+        if (setupRunMode)setup();
         try (PythonGIL gil = PythonGIL.lock()) {
             try (PythonGC _ = PythonGC.watch()) {
                 PythonContextManager.setContext(context);

@@ -139,6 +144,7 @@ public class PythonJob {
     }

     public List<PythonVariable> execAndReturnAllVariables(List<PythonVariable> inputs){
+        if (setupRunMode)setup();
        try (PythonGIL gil = PythonGIL.lock()) {
             try (PythonGC _ = PythonGC.watch()) {
                 PythonContextManager.setContext(context);
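Moving setup() out of the constructor and behind the setupDone flag makes job initialization lazy: it now happens on the first exec() call rather than at construction time. A minimal usage sketch modeled on the "c = a + b" script in PythonJobTest (illustrative only; names are invented, imports omitted):

    PythonJob job = new PythonJob("add_job", "c = a + b", false);  // false = no setup-and-run
    List<PythonVariable> inputs = new ArrayList<>();
    inputs.add(new PythonVariable<>("a", PythonTypes.INT, 2));
    inputs.add(new PythonVariable<>("b", PythonTypes.INT, 3));
    List<PythonVariable> outputs = new ArrayList<>();
    outputs.add(new PythonVariable<>("c", PythonTypes.INT));
    job.exec(inputs, outputs);                       // acquires the GIL internally
    System.out.println(outputs.get(0).getValue());   // 5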
@@ -14,7 +14,7 @@
  * SPDX-License-Identifier: Apache-2.0
  ******************************************************************************/

-package org.eclipse.python4j;
+package org.nd4j.python4j;


 import org.bytedeco.cpython.PyObject;
@@ -147,7 +147,8 @@ public class PythonObject {
     }
     PythonObject pyArgs;
     PythonObject pyKwargs;
-    if (args == null) {
+
+    if (args == null || args.isEmpty()) {
         pyArgs = new PythonObject(PyTuple_New(0));
     } else {
         PythonObject argsList = PythonTypes.convert(args);

@@ -158,6 +159,7 @@ public class PythonObject {
     } else {
         pyKwargs = PythonTypes.convert(kwargs);
     }
+
     PythonObject ret = new PythonObject(
             PyObject_Call(
                     nativePythonObject,

@@ -165,7 +167,9 @@ public class PythonObject {
                     pyKwargs == null ? null : pyKwargs.nativePythonObject
             )
     );
+
     PythonGC.keep(ret);
+
     return ret;
 }
@@ -241,4 +245,48 @@ public class PythonObject {
         PyObject_SetItem(nativePythonObject, key.nativePythonObject, value.nativePythonObject);
     }

+    public PythonObject abs(){
+        return new PythonObject(PyNumber_Absolute(nativePythonObject));
+    }
+    public PythonObject add(PythonObject pythonObject){
+        return new PythonObject(PyNumber_Add(nativePythonObject, pythonObject.nativePythonObject));
+    }
+    public PythonObject sub(PythonObject pythonObject){
+        return new PythonObject(PyNumber_Subtract(nativePythonObject, pythonObject.nativePythonObject));
+    }
+    public PythonObject mod(PythonObject pythonObject){
+        return new PythonObject(PyNumber_Divmod(nativePythonObject, pythonObject.nativePythonObject));
+    }
+    public PythonObject mul(PythonObject pythonObject){
+        return new PythonObject(PyNumber_Multiply(nativePythonObject, pythonObject.nativePythonObject));
+    }
+    public PythonObject trueDiv(PythonObject pythonObject){
+        return new PythonObject(PyNumber_TrueDivide(nativePythonObject, pythonObject.nativePythonObject));
+    }
+    public PythonObject floorDiv(PythonObject pythonObject){
+        return new PythonObject(PyNumber_FloorDivide(nativePythonObject, pythonObject.nativePythonObject));
+    }
+    public PythonObject matMul(PythonObject pythonObject){
+        return new PythonObject(PyNumber_MatrixMultiply(nativePythonObject, pythonObject.nativePythonObject));
+    }
+
+    public void addi(PythonObject pythonObject){
+        PyNumber_InPlaceAdd(nativePythonObject, pythonObject.nativePythonObject);
+    }
+    public void subi(PythonObject pythonObject){
+        PyNumber_InPlaceSubtract(nativePythonObject, pythonObject.nativePythonObject);
+    }
+    public void muli(PythonObject pythonObject){
+        PyNumber_InPlaceMultiply(nativePythonObject, pythonObject.nativePythonObject);
+    }
+    public void trueDivi(PythonObject pythonObject){
+        PyNumber_InPlaceTrueDivide(nativePythonObject, pythonObject.nativePythonObject);
+    }
+    public void floorDivi(PythonObject pythonObject){
+        PyNumber_InPlaceFloorDivide(nativePythonObject, pythonObject.nativePythonObject);
+    }
+    public void matMuli(PythonObject pythonObject){
+        PyNumber_InPlaceMatrixMultiply(nativePythonObject, pythonObject.nativePythonObject);
+    }
 }
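These helpers mirror CPython's number protocol one-to-one; note that mod() as committed delegates to PyNumber_Divmod rather than PyNumber_Remainder, so it returns a (quotient, remainder) tuple, not the remainder alone. An illustrative sketch (not part of the commit; it assumes PythonTypes.INT converts from Long):

    try (PythonGIL gil = PythonGIL.lock()) {
        PythonObject two = PythonTypes.INT.toPython(2L);
        PythonObject three = PythonTypes.INT.toPython(3L);
        PythonObject five = two.add(three);   // PyNumber_Add, new object
        two.addi(three);                      // PyNumber_InPlaceAdd, mutates two
    }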
@@ -0,0 +1,127 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+
+package org.nd4j.python4j;
+
+import org.apache.commons.io.IOUtils;
+import org.bytedeco.javacpp.Loader;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+
+public class PythonProcess {
+    private static String pythonExecutable = Loader.load(org.bytedeco.cpython.python.class);
+    public static String runAndReturn(String... arguments)throws IOException, InterruptedException{
+        String[] allArgs = new String[arguments.length + 1];
+        for (int i = 0; i < arguments.length; i++){
+            allArgs[i + 1] = arguments[i];
+        }
+        allArgs[0] = pythonExecutable;
+        ProcessBuilder pb = new ProcessBuilder(allArgs);
+        Process process = pb.start();
+        String out = IOUtils.toString(process.getInputStream(), StandardCharsets.UTF_8);
+        process.waitFor();
+        return out;
+    }
+
+    public static void run(String... arguments)throws IOException, InterruptedException{
+        String[] allArgs = new String[arguments.length + 1];
+        for (int i = 0; i < arguments.length; i++){
+            allArgs[i + 1] = arguments[i];
+        }
+        allArgs[0] = pythonExecutable;
+        ProcessBuilder pb = new ProcessBuilder(allArgs);
+        pb.inheritIO().start().waitFor();
+    }
+    public static void pipInstall(String packageName) throws PythonException{
+        try{
+            run("-m", "pip", "install", packageName);
+        }catch(Exception e){
+            throw new PythonException("Error installing package " + packageName, e);
+        }
+    }
+
+    public static void pipInstall(String packageName, String version){
+        pipInstall(packageName + "==" + version);
+    }
+
+    public static void pipUninstall(String packageName) throws PythonException{
+        try{
+            run("-m", "pip", "uninstall", packageName);
+        }catch(Exception e){
+            throw new PythonException("Error uninstalling package " + packageName, e);
+        }
+    }
+    public static void pipInstallFromGit(String gitRepoUrl){
+        if (!gitRepoUrl.contains("://")){
+            gitRepoUrl = "git://" + gitRepoUrl;
+        }
+        try{
+            run("-m", "pip", "install", "git+", gitRepoUrl);
+        }catch(Exception e){
+            throw new PythonException("Error installing package from " + gitRepoUrl, e);
+        }
+    }
+
+    public static String getPackageVersion(String packageName){
+        String out;
+        try{
+            out = runAndReturn("-m", "pip", "show", packageName);
+        } catch (Exception e){
+            throw new PythonException("Error finding version for package " + packageName, e);
+        }
+
+        if (!out.contains("Version: ")){
+            throw new PythonException("Can't find package " + packageName);
+        }
+        String pkgVersion = out.split("Version: ")[1].split(System.lineSeparator())[0];
+        return pkgVersion;
+    }
+
+    public static boolean isPackageInstalled(String packageName){
+        try{
+            String out = runAndReturn("-m", "pip", "show", packageName);
+            return !out.isEmpty();
+        }catch (Exception e){
+            throw new PythonException("Error checking if package is installed: " +packageName, e);
+        }
+    }
+
+    public static void pipInstallFromRequirementsTxt(String path){
+        try{
+            run("-m", "pip", "install","-r", path);
+        }catch (Exception e){
+            throw new PythonException("Error installing packages from " + path, e);
+        }
+    }
+
+    public static void pipInstallFromSetupScript(String path, boolean inplace){
+        try{
+            run(path, inplace?"develop":"install");
+        }catch (Exception e){
+            throw new PythonException("Error installing package from " + path, e);
+        }
+    }
+}
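PythonProcess shells out to the javacpp-bundled interpreter for pip operations. One observation worth flagging: pipInstallFromGit passes "git+" and the URL as two separate arguments, which pip will see as two distinct arguments; the likely intent was the single argument "git+" + gitRepoUrl. A hedged sketch of typical use (illustrative only; the package name and version are examples, not from the commit):

    // Ensure a dependency is present before running a script.
    if (!PythonProcess.isPackageInstalled("requests")) {       // wraps: pip show requests
        PythonProcess.pipInstall("requests", "2.24.0");        // wraps: pip install requests==2.24.0
    }
    String v = PythonProcess.getPackageVersion("requests");    // parsed from pip show output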
@@ -14,9 +14,11 @@
  * SPDX-License-Identifier: Apache-2.0
  ******************************************************************************/

-package org.eclipse.python4j;
+package org.nd4j.python4j;


+import java.io.File;
+
 public abstract class PythonType<T> {

     private final String name;

@@ -43,5 +45,25 @@ public abstract class PythonType<T> {
         return name;
     }

+    @Override
+    public boolean equals(Object obj){
+        if (!(obj instanceof PythonType)){
+            return false;
+        }
+        PythonType other = (PythonType)obj;
+        return this.getClass().equals(other.getClass()) && this.name.equals(other.name);
+    }
+
+    public PythonObject pythonType(){
+        return null;
+    }
+
+    public File[] packages(){
+        return new File[0];
+    }
+
+    public void init(){ //not to be called from constructor
+
+    }
+
 }
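pythonType(), packages(), and init() give external type implementations hook points, and together with the ServiceLoader lookup added to PythonTypes below they let third-party types be discovered from the classpath. A hypothetical sketch (illustrative only; the class, registered name, and provider file are invented, and the exact set of abstract methods is assumed from the built-in types):

    // Shipped in a jar alongside a provider-configuration file named
    // META-INF/services/org.nd4j.python4j.PythonType that lists this class.
    public class UuidType extends PythonType<UUID> {
        public UuidType() { super("uuid.UUID", UUID.class); }
        @Override public UUID toJava(PythonObject po) {
            return UUID.fromString(po.toString());
        }
        @Override public PythonObject toPython(UUID u) {
            return PythonTypes.STR.toPython(u.toString());  // simplification
        }
        @Override public boolean accepts(Object o) { return o instanceof UUID; }
        @Override public UUID adapt(Object o) { return (UUID) o; }
    }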
@@ -14,11 +14,18 @@
  * SPDX-License-Identifier: Apache-2.0
  ******************************************************************************/

-package org.eclipse.python4j;
+package org.nd4j.python4j;


 import org.bytedeco.cpython.PyObject;
+import org.bytedeco.javacpp.BytePointer;
+import org.bytedeco.javacpp.Pointer;
+import sun.nio.ch.DirectBuffer;
+
+import java.lang.reflect.Field;
+import java.nio.Buffer;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
 import java.util.*;

 import static org.bytedeco.cpython.global.python.*;

@@ -28,7 +35,7 @@ public class PythonTypes {


     private static List<PythonType> getPrimitiveTypes() {
-        return Arrays.<PythonType>asList(STR, INT, FLOAT, BOOL);
+        return Arrays.<PythonType>asList(STR, INT, FLOAT, BOOL, BYTES);
     }

     private static List<PythonType> getCollectionTypes() {

@@ -36,8 +43,13 @@ public class PythonTypes {
     }

     private static List<PythonType> getExternalTypes() {
-        //TODO service loader
-        return new ArrayList<>();
+        List<PythonType> ret = new ArrayList<>();
+        ServiceLoader<PythonType> sl = ServiceLoader.load(PythonType.class);
+        Iterator<PythonType> iter = sl.iterator();
+        while (iter.hasNext()) {
+            ret.add(iter.next());
+        }
+        return ret;
     }

     public static List<PythonType> get() {

@@ -48,15 +60,17 @@ public class PythonTypes {
         return ret;
     }

-    public static PythonType get(String name) {
+    public static <T> PythonType<T> get(String name) {
         for (PythonType pt : get()) {
             if (pt.getName().equals(name)) { // TODO use map instead?
                 return pt;
             }
         }
         throw new PythonException("Unknown python type: " + name);
     }


     public static PythonType getPythonTypeForJavaObject(Object javaObject) {
         for (PythonType pt : get()) {
             if (pt.accepts(javaObject)) {
@@ -66,7 +80,7 @@ public class PythonTypes {
         throw new PythonException("Unable to find python type for java type: " + javaObject.getClass());
     }

-    public static PythonType getPythonTypeForPythonObject(PythonObject pythonObject) {
+    public static <T> PythonType<T> getPythonTypeForPythonObject(PythonObject pythonObject) {
         PyObject pyType = PyObject_Type(pythonObject.getNativePythonObject());
         try {
             String pyTypeStr = PythonTypes.STR.toJava(new PythonObject(pyType, false));

@@ -75,6 +89,14 @@ public class PythonTypes {
             String pyTypeStr2 = "<class '" + pt.getName() + "'>";
             if (pyTypeStr.equals(pyTypeStr2)) {
                 return pt;
+            } else {
+                try (PythonGC gc = PythonGC.watch()) {
+                    PythonObject pyType2 = pt.pythonType();
+                    if (pyType2 != null && Python.isinstance(pythonObject, pyType2)) {
+                        return pt;
+                    }
+                }
             }
         }
         throw new PythonException("Unable to find converter for python object of type " + pyTypeStr);
@@ -212,12 +234,53 @@ public class PythonTypes {

     public static final PythonType<List> LIST = new PythonType<List>("list", List.class) {

+        @Override
+        public boolean accepts(Object javaObject) {
+            return (javaObject instanceof List || javaObject.getClass().isArray());
+        }
+
         @Override
         public List adapt(Object javaObject) {
             if (javaObject instanceof List) {
                 return (List) javaObject;
-            } else if (javaObject instanceof Object[]) {
-                return Arrays.asList((Object[]) javaObject);
+            } else if (javaObject.getClass().isArray()) {
+                List<Object> ret = new ArrayList<>();
+                if (javaObject instanceof Object[]) {
+                    Object[] arr = (Object[]) javaObject;
+                    return new ArrayList<>(Arrays.asList(arr));
+                } else if (javaObject instanceof short[]) {
+                    short[] arr = (short[]) javaObject;
+                    for (short x : arr) ret.add(x);
+                    return ret;
+                } else if (javaObject instanceof int[]) {
+                    int[] arr = (int[]) javaObject;
+                    for (int x : arr) ret.add(x);
+                    return ret;
+                } else if (javaObject instanceof byte[]){
+                    byte[] arr = (byte[]) javaObject;
+                    for (int x : arr) ret.add(x);
+                    return ret;
+                } else if (javaObject instanceof long[]) {
+                    long[] arr = (long[]) javaObject;
+                    for (long x : arr) ret.add(x);
+                    return ret;
+                } else if (javaObject instanceof float[]) {
+                    float[] arr = (float[]) javaObject;
+                    for (float x : arr) ret.add(x);
+                    return ret;
+                } else if (javaObject instanceof double[]) {
+                    double[] arr = (double[]) javaObject;
+                    for (double x : arr) ret.add(x);
+                    return ret;
+                } else if (javaObject instanceof boolean[]) {
+                    boolean[] arr = (boolean[]) javaObject;
+                    for (boolean x : arr) ret.add(x);
+                    return ret;
+                } else {
+                    throw new PythonException("Unsupported array type: " + javaObject.getClass().toString());
+                }
+
             } else {
                 throw new PythonException("Cannot cast object of type " + javaObject.getClass().getName() + " to List");
             }
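LIST.adapt now boxes primitive arrays element by element instead of handling only Object[]. One quirk preserved from the commit: byte[] elements are boxed as Integer, because the loop variable is declared int. An illustrative sketch (not part of the commit; GIL handling omitted):

    // Primitive arrays now convert without manual boxing.
    List adapted = PythonTypes.LIST.adapt(new int[]{1, 2, 3});   // [1, 2, 3]
    PythonObject pyList = PythonTypes.LIST.toPython(adapted);    // a Python list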
@@ -327,7 +390,13 @@ public class PythonTypes {
             }
             Object v = javaObject.get(k);
             PythonObject pyVal;
-            pyVal = PythonTypes.convert(v);
+            if (v instanceof PythonObject) {
+                pyVal = (PythonObject) v;
+            } else if (v instanceof PyObject) {
+                pyVal = new PythonObject((PyObject) v);
+            } else {
+                pyVal = PythonTypes.convert(v);
+            }
             int errCode = PyDict_SetItem(pyDict, pyKey.getNativePythonObject(), pyVal.getNativePythonObject());
             if (errCode != 0) {
                 String keyStr = pyKey.toString();
@@ -341,4 +410,127 @@ public class PythonTypes {
             return new PythonObject(pyDict);
         }
     };

+
+    public static final PythonType<byte[]> BYTES = new PythonType<byte[]>("bytes", byte[].class) {
+        @Override
+        public byte[] toJava(PythonObject pythonObject) {
+            try (PythonGC gc = PythonGC.watch()) {
+                if (!(Python.isinstance(pythonObject, Python.bytesType()))) {
+                    throw new PythonException("Expected bytes. Received: " + pythonObject);
+                }
+                PythonObject pySize = Python.len(pythonObject);
+                byte[] ret = new byte[pySize.toInt()];
+                for (int i = 0; i < ret.length; i++) {
+                    ret[i] = (byte)pythonObject.get(i).toInt();
+                }
+                return ret;
+            }
+        }
+
+        @Override
+        public PythonObject toPython(byte[] javaObject) {
+            try(PythonGC gc = PythonGC.watch()){
+                PythonObject ret = Python.bytes(LIST.toPython(LIST.adapt(javaObject)));
+                PythonGC.keep(ret);
+                return ret;
+            }
+        }
+        @Override
+        public boolean accepts(Object javaObject) {
+            return javaObject instanceof byte[];
+        }
+        @Override
+        public byte[] adapt(Object javaObject) {
+            if (javaObject instanceof byte[]){
+                return (byte[])javaObject;
+            }
+            throw new PythonException("Cannot cast object of type " + javaObject.getClass().getName() + " to byte[]");
+        }
+    };
+
+    /**
+     * Crashes on Adopt OpenJDK
+     * Use implementation in python4j-numpy instead for zero-copy byte buffers.
+     */
+    // public static final PythonType<BytePointer> MEMORYVIEW = new PythonType<BytePointer>("memoryview", BytePointer.class) {
+    //     @Override
+    //     public BytePointer toJava(PythonObject pythonObject) {
+    //         try (PythonGC gc = PythonGC.watch()) {
+    //             if (!(Python.isinstance(pythonObject, Python.memoryviewType()))) {
+    //                 throw new PythonException("Expected memoryview. Received: " + pythonObject);
+    //             }
+    //             PythonObject pySize = Python.len(pythonObject);
+    //             PythonObject ctypes = Python.importModule("ctypes");
+    //             PythonObject charType = ctypes.attr("c_char");
+    //             PythonObject charArrayType = new PythonObject(PyNumber_Multiply(charType.getNativePythonObject(),
+    //                     pySize.getNativePythonObject()));
+    //             PythonObject fromBuffer = charArrayType.attr("from_buffer");
+    //             if (pythonObject.attr("readonly").toBoolean()) {
+    //                 pythonObject = Python.bytearray(pythonObject);
+    //             }
+    //             PythonObject arr = fromBuffer.call(pythonObject);
+    //             PythonObject cast = ctypes.attr("cast");
+    //             PythonObject voidPtrType = ctypes.attr("c_void_p");
+    //             PythonObject voidPtr = cast.call(arr, voidPtrType);
+    //             long address = voidPtr.attr("value").toLong();
+    //             long size = pySize.toLong();
+    //             try {
+    //                 Field addressField = Buffer.class.getDeclaredField("address");
+    //                 addressField.setAccessible(true);
+    //                 Field capacityField = Buffer.class.getDeclaredField("capacity");
+    //                 capacityField.setAccessible(true);
+    //                 ByteBuffer buff = ByteBuffer.allocateDirect(0).order(ByteOrder.nativeOrder());
+    //                 addressField.setLong(buff, address);
+    //                 capacityField.setInt(buff, (int) size);
+    //                 BytePointer ret = new BytePointer(buff);
+    //                 ret.limit(size);
+    //                 return ret;
+    //
+    //             } catch (Exception e) {
+    //                 throw new RuntimeException(e);
+    //             }
+    //
+    //         }
+    //     }
+    //
+    //     @Override
+    //     public PythonObject toPython(BytePointer javaObject) {
+    //         long address = javaObject.address();
+    //         long size = javaObject.limit();
+    //         try (PythonGC gc = PythonGC.watch()) {
+    //             PythonObject ctypes = Python.importModule("ctypes");
+    //             PythonObject charType = ctypes.attr("c_char");
+    //             PythonObject pySize = new PythonObject(size);
+    //             PythonObject charArrayType = new PythonObject(PyNumber_Multiply(charType.getNativePythonObject(),
+    //                     pySize.getNativePythonObject()));
+    //             PythonObject fromAddress = charArrayType.attr("from_address");
+    //             PythonObject arr = fromAddress.call(new PythonObject(address));
+    //             PythonObject memoryView = Python.memoryview(arr).attr("cast").call("b");
+    //             PythonGC.keep(memoryView);
+    //             return memoryView;
+    //         }
+    //
+    //     }
+    //
+    //     @Override
+    //     public boolean accepts(Object javaObject) {
+    //         return javaObject instanceof Pointer || javaObject instanceof DirectBuffer;
+    //     }
+    //
+    //     @Override
+    //     public BytePointer adapt(Object javaObject) {
+    //         if (javaObject instanceof BytePointer) {
+    //             return (BytePointer) javaObject;
+    //         } else if (javaObject instanceof Pointer) {
+    //             return new BytePointer((Pointer) javaObject);
+    //         } else if (javaObject instanceof DirectBuffer) {
+    //             return new BytePointer((ByteBuffer) javaObject);
+    //         } else {
+    //             throw new PythonException("Cannot cast object of type " + javaObject.getClass().getName() + " to BytePointer");
+    //         }
+    //     }
+    // };
+
 }
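The BYTES type copies byte-by-byte in both directions; the commented-out MEMORYVIEW type was the zero-copy path, parked here because, per its javadoc, it crashes on Adopt OpenJDK. A round-trip sketch (illustrative only, not part of the commit):

    // b'abc' -> Python bytes -> byte[]; each element is copied individually.
    PythonObject pyBytes = PythonTypes.BYTES.toPython(new byte[]{97, 98, 99});
    byte[] back = PythonTypes.BYTES.toJava(pyBytes);   // {97, 98, 99}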
@@ -14,7 +14,7 @@
  * SPDX-License-Identifier: Apache-2.0
  ******************************************************************************/

-package org.eclipse.python4j;
+package org.nd4j.python4j;

 @lombok.Data
 public class PythonVariable<T> {
@@ -0,0 +1,47 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.nd4j.python4j;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Some syntax sugar for lookup by name
+ */
+public class PythonVariables extends ArrayList<PythonVariable> {
+    public PythonVariable get(String variableName) {
+        for (PythonVariable pyVar: this){
+            if (pyVar.getName().equals(variableName)){
+                return pyVar;
+            }
+        }
+        return null;
+    }
+
+    public <T> boolean add(String variableName, PythonType<T> variableType, Object value){
+        return this.add(new PythonVariable<>(variableName, variableType, value));
+    }
+
+    public PythonVariables(PythonVariable... variables){
+        this(Arrays.asList(variables));
+    }
+    public PythonVariables(List<PythonVariable> list){
+        super();
+        addAll(list);
+    }
+}
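PythonVariables is a thin ArrayList subclass; get(String) is a linear scan, which is fine for the handful of variables a script typically returns. A usage sketch (illustrative only; note the varargs constructor also covers the no-arg case):

    PythonVariables vars = new PythonVariables();
    vars.add("x", PythonTypes.INT, 42L);   // new convenience overload
    PythonVariable x = vars.get("x");      // lookup by name, null if absent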
@@ -15,9 +15,12 @@
  ******************************************************************************/


-import org.eclipse.python4j.*;
 import org.junit.Assert;
 import org.junit.Test;
+import org.nd4j.python4j.PythonContextManager;
+import org.nd4j.python4j.PythonExecutioner;
+import org.nd4j.python4j.PythonTypes;
+import org.nd4j.python4j.PythonVariable;

 import javax.annotation.concurrent.NotThreadSafe;
 import java.util.*;
@@ -15,9 +15,9 @@
  ******************************************************************************/


-import org.eclipse.python4j.PythonException;
-import org.eclipse.python4j.PythonObject;
-import org.eclipse.python4j.PythonTypes;
+import org.nd4j.python4j.PythonException;
+import org.nd4j.python4j.PythonObject;
+import org.nd4j.python4j.PythonTypes;
 import org.junit.Assert;
 import org.junit.Test;
@@ -16,9 +16,9 @@
  ******************************************************************************/


-import org.eclipse.python4j.Python;
-import org.eclipse.python4j.PythonContextManager;
-import org.eclipse.python4j.PythonExecutioner;
+import org.nd4j.python4j.Python;
+import org.nd4j.python4j.PythonContextManager;
+import org.nd4j.python4j.PythonExecutioner;
 import org.junit.Assert;
 import org.junit.Test;
 import javax.annotation.concurrent.NotThreadSafe;
@@ -14,9 +14,9 @@
  * SPDX-License-Identifier: Apache-2.0
  ******************************************************************************/

-import org.eclipse.python4j.Python;
-import org.eclipse.python4j.PythonGC;
-import org.eclipse.python4j.PythonObject;
+import org.nd4j.python4j.Python;
+import org.nd4j.python4j.PythonGC;
+import org.nd4j.python4j.PythonObject;
 import org.junit.Assert;
 import org.junit.Test;
@@ -49,6 +49,6 @@ public class PythonGCTest {
         PythonObject pyObjCount3 = Python.len(getObjects.call());
         long objCount3 = pyObjCount3.toLong();
         diff = objCount3 - objCount2;
-        Assert.assertEquals(2, diff);// 2 objects created during function call
+        Assert.assertTrue(diff <= 2);// 2 objects created during function call
     }
 }
@@ -14,10 +14,10 @@
  * SPDX-License-Identifier: Apache-2.0
  ******************************************************************************/

-import org.eclipse.python4j.PythonContextManager;
-import org.eclipse.python4j.PythonJob;
-import org.eclipse.python4j.PythonTypes;
-import org.eclipse.python4j.PythonVariable;
+import org.nd4j.python4j.PythonContextManager;
+import org.nd4j.python4j.PythonJob;
+import org.nd4j.python4j.PythonTypes;
+import org.nd4j.python4j.PythonVariable;
 import org.junit.Test;

 import java.util.ArrayList;
@ -30,7 +30,7 @@ import static org.junit.Assert.assertEquals;
|
||||||
public class PythonJobTest {
|
public class PythonJobTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testPythonJobBasic() throws Exception{
|
public void testPythonJobBasic(){
|
||||||
PythonContextManager.deleteNonMainContexts();
|
PythonContextManager.deleteNonMainContexts();
|
||||||
|
|
||||||
String code = "c = a + b";
|
String code = "c = a + b";
|
||||||
|
@ -65,7 +65,7 @@ public class PythonJobTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testPythonJobReturnAllVariables()throws Exception{
|
public void testPythonJobReturnAllVariables(){
|
||||||
PythonContextManager.deleteNonMainContexts();
|
PythonContextManager.deleteNonMainContexts();
|
||||||
|
|
||||||
String code = "c = a + b";
|
String code = "c = a + b";
|
||||||
|
@ -101,7 +101,7 @@ public class PythonJobTest {
|
||||||
|
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testMultiplePythonJobsParallel()throws Exception{
|
public void testMultiplePythonJobsParallel(){
|
||||||
PythonContextManager.deleteNonMainContexts();
|
PythonContextManager.deleteNonMainContexts();
|
||||||
String code1 = "c = a + b";
|
String code1 = "c = a + b";
|
||||||
PythonJob job1 = new PythonJob("job1", code1, false);
|
PythonJob job1 = new PythonJob("job1", code1, false);
|
||||||
|
@@ -150,7 +150,7 @@ public class PythonJobTest {


     @Test
-    public void testPythonJobSetupRun()throws Exception{
+    public void testPythonJobSetupRun(){

         PythonContextManager.deleteNonMainContexts();
         String code = "five=None\n" +
@@ -189,7 +189,7 @@ public class PythonJobTest {

     }
     @Test
-    public void testPythonJobSetupRunAndReturnAllVariables()throws Exception{
+    public void testPythonJobSetupRunAndReturnAllVariables(){
         PythonContextManager.deleteNonMainContexts();
         String code = "five=None\n" +
                 "c=None\n"+
@@ -225,7 +225,7 @@ public class PythonJobTest {
     }

     @Test
-    public void testMultiplePythonJobsSetupRunParallel()throws Exception{
+    public void testMultiplePythonJobsSetupRunParallel(){
         PythonContextManager.deleteNonMainContexts();

         String code1 = "five=None\n" +
@@ -14,10 +14,9 @@
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

-import org.eclipse.python4j.*;
+import org.nd4j.python4j.*;
 import org.junit.Assert;
 import org.junit.Test;

 import javax.annotation.concurrent.NotThreadSafe;
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -15,12 +15,13 @@
 ******************************************************************************/


-import org.eclipse.python4j.PythonException;
-import org.eclipse.python4j.PythonObject;
-import org.eclipse.python4j.PythonTypes;
+import org.nd4j.python4j.*;
 import org.junit.Assert;
 import org.junit.Test;

+import java.util.ArrayList;
+import java.util.List;

 public class PythonPrimitiveTypesTest {

     @Test
@@ -78,5 +79,18 @@ public class PythonPrimitiveTypesTest {

         Assert.assertEquals(b, b3);
     }
+    @Test
+    public void testBytes() {
+        byte[] bytes = new byte[]{97, 98, 99};
+        List<PythonVariable> inputs = new ArrayList<>();
+        inputs.add(new PythonVariable<>("buff", PythonTypes.BYTES, bytes));
+        List<PythonVariable> outputs = new ArrayList<>();
+        outputs.add(new PythonVariable<>("s1", PythonTypes.STR));
+        outputs.add(new PythonVariable<>("buff2", PythonTypes.BYTES));
+        String code = "s1 = ''.join(chr(c) for c in buff)\nbuff2=b'def'";
+        PythonExecutioner.exec(code, inputs, outputs);
+        Assert.assertEquals("abc", outputs.get(0).getValue());
+        Assert.assertArrayEquals(new byte[]{100, 101, 102}, (byte[]) outputs.get(1).getValue());
+    }

 }
@@ -4,7 +4,7 @@
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
     <parent>
         <artifactId>python4j-parent</artifactId>
-        <groupId>org.eclipse</groupId>
+        <groupId>org.nd4j</groupId>
         <version>1.0.0-SNAPSHOT</version>
     </parent>
     <modelVersion>4.0.0</modelVersion>
@@ -28,15 +28,50 @@
             <version>${nd4j.version}</version>
             <scope>test</scope>
         </dependency>
+        <dependency>
+            <groupId>org.nd4j</groupId>
+            <artifactId>python4j-core</artifactId>
+            <version>1.0.0-SNAPSHOT</version>
+        </dependency>
     </dependencies>

     <profiles>
         <profile>
             <id>test-nd4j-native</id>
+            <dependencies>
+                <dependency>
+                    <groupId>org.nd4j</groupId>
+                    <artifactId>nd4j-native</artifactId>
+                    <version>${nd4j.version}</version>
+                    <scope>test</scope>
+                </dependency>
+                <dependency>
+                    <groupId>org.deeplearning4j</groupId>
+                    <artifactId>dl4j-test-resources</artifactId>
+                    <version>${nd4j.version}</version>
+                    <scope>test</scope>
+                </dependency>
+            </dependencies>
         </profile>

         <profile>
             <id>test-nd4j-cuda-10.2</id>
+            <dependencies>
+                <dependency>
+                    <groupId>org.nd4j</groupId>
+                    <artifactId>nd4j-cuda-10.1</artifactId>
+                    <version>${nd4j.version}</version>
+                    <scope>test</scope>
+                </dependency>
+                <dependency>
+                    <groupId>org.deeplearning4j</groupId>
+                    <artifactId>dl4j-test-resources</artifactId>
+                    <version>${nd4j.version}</version>
+                    <scope>test</scope>
+                </dependency>
+            </dependencies>
         </profile>
     </profiles>


</project>
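
With these profiles in place, the test backend is selected at build time via Maven's profile switch, e.g. mvn test -Ptest-nd4j-native or mvn test -Ptest-nd4j-cuda-10.2 (assuming the standard profile-activation setup; note the cuda-10.2 profile above pulls in the nd4j-cuda-10.1 artifact).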
@@ -0,0 +1,299 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+
+package org.nd4j.python4j;
+
+import lombok.extern.slf4j.Slf4j;
+import org.bytedeco.cpython.PyObject;
+import org.bytedeco.cpython.PyTypeObject;
+import org.bytedeco.javacpp.Pointer;
+import org.bytedeco.javacpp.SizeTPointer;
+import org.bytedeco.numpy.PyArrayObject;
+import org.bytedeco.numpy.global.numpy;
+import org.nd4j.linalg.api.buffer.DataBuffer;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.concurrency.AffinityManager;
+import org.nd4j.linalg.api.memory.MemoryWorkspace;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.shape.Shape;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.nativeblas.NativeOpsHolder;
+
+import java.io.File;
+import java.util.*;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+import static org.bytedeco.cpython.global.python.*;
+import static org.bytedeco.cpython.global.python.Py_DecRef;
+import static org.bytedeco.numpy.global.numpy.*;
+import static org.bytedeco.numpy.global.numpy.NPY_ARRAY_CARRAY;
+import static org.bytedeco.numpy.global.numpy.PyArray_Type;
+
+@Slf4j
+public class NumpyArray extends PythonType<INDArray> {
+
+    public static final NumpyArray INSTANCE;
+    private static final AtomicBoolean init = new AtomicBoolean(false);
+    private static final Map<String, DataBuffer> cache = new HashMap<>();
+
+    static {
+        new PythonExecutioner();
+        INSTANCE = new NumpyArray();
+    }
+
+    @Override
+    public File[] packages(){
+        try{
+            return new File[]{numpy.cachePackage()};
+        }catch(Exception e){
+            throw new PythonException(e);
+        }
+    }
+
+    public synchronized void init() {
+        if (init.get()) return;
+        init.set(true);
+        if (PythonGIL.locked()) {
+            throw new PythonException("Can not initialize numpy - GIL already acquired.");
+        }
+        int err = numpy._import_array();
+        if (err < 0){
+            System.out.println("Numpy import failed!");
+            throw new PythonException("Numpy import failed!");
+        }
+    }
+
+    public NumpyArray() {
+        super("numpy.ndarray", INDArray.class);
+    }
+
+    @Override
+    public INDArray toJava(PythonObject pythonObject) {
+        log.info("Converting PythonObject to INDArray...");
+        PyObject np = PyImport_ImportModule("numpy");
+        PyObject ndarray = PyObject_GetAttrString(np, "ndarray");
+        if (PyObject_IsInstance(pythonObject.getNativePythonObject(), ndarray) != 1) {
+            Py_DecRef(ndarray);
+            Py_DecRef(np);
+            throw new PythonException("Object is not a numpy array! Use Python.ndarray() to convert object to a numpy array.");
+        }
+        Py_DecRef(ndarray);
+        Py_DecRef(np);
+        PyArrayObject npArr = new PyArrayObject(pythonObject.getNativePythonObject());
+        long[] shape = new long[PyArray_NDIM(npArr)];
+        SizeTPointer shapePtr = PyArray_SHAPE(npArr);
+        if (shapePtr != null)
+            shapePtr.get(shape, 0, shape.length);
+        long[] strides = new long[shape.length];
+        SizeTPointer stridesPtr = PyArray_STRIDES(npArr);
+        if (stridesPtr != null)
+            stridesPtr.get(strides, 0, strides.length);
+        int npdtype = PyArray_TYPE(npArr);
+
+        DataType dtype;
+        switch (npdtype) {
+            case NPY_DOUBLE:
+                dtype = DataType.DOUBLE;
+                break;
+            case NPY_FLOAT:
+                dtype = DataType.FLOAT;
+                break;
+            case NPY_SHORT:
+                dtype = DataType.SHORT;
+                break;
+            case NPY_INT:
+                dtype = DataType.INT32;
+                break;
+            case NPY_LONG:
+                dtype = DataType.INT64;
+                break;
+            case NPY_UINT:
+                dtype = DataType.UINT32;
+                break;
+            case NPY_BYTE:
+                dtype = DataType.INT8;
+                break;
+            case NPY_UBYTE:
+                dtype = DataType.UINT8;
+                break;
+            case NPY_BOOL:
+                dtype = DataType.BOOL;
+                break;
+            case NPY_HALF:
+                dtype = DataType.FLOAT16;
+                break;
+            case NPY_LONGLONG:
+                dtype = DataType.INT64;
+                break;
+            case NPY_USHORT:
+                dtype = DataType.UINT16;
+                break;
+            case NPY_ULONG:
+            case NPY_ULONGLONG:
+                dtype = DataType.UINT64;
+                break;
+            default:
+                throw new PythonException("Unsupported array data type: " + npdtype);
+        }
+        long size = 1;
+        for (int i = 0; i < shape.length; size *= shape[i++]) ;
+
+        INDArray ret;
+        long address = PyArray_DATA(npArr).address();
+        String key = address + "_" + size + "_" + dtype;
+        DataBuffer buff = cache.get(key);
+        if (buff == null) {
+            try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
+                Pointer ptr = NativeOpsHolder.getInstance().getDeviceNativeOps().pointerForAddress(address);
+                ptr = ptr.limit(size);
+                ptr = ptr.capacity(size);
+                buff = Nd4j.createBuffer(ptr, size, dtype);
+                cache.put(key, buff);
+            }
+        }
+        int elemSize = buff.getElementSize();
+        long[] nd4jStrides = new long[strides.length];
+        for (int i = 0; i < strides.length; i++) {
+            nd4jStrides[i] = strides[i] / elemSize;
+        }
+        ret = Nd4j.create(buff, shape, nd4jStrides, 0, Shape.getOrder(shape, nd4jStrides, 1), dtype);
+        Nd4j.getAffinityManager().tagLocation(ret, AffinityManager.Location.HOST);
+        log.info("Done.");
+        return ret;
+    }
+
+    @Override
+    public PythonObject toPython(INDArray indArray) {
+        log.info("Converting INDArray to PythonObject...");
+        DataType dataType = indArray.dataType();
+        DataBuffer buff = indArray.data();
+        String key = buff.pointer().address() + "_" + buff.length() + "_" + dataType;
+        cache.put(key, buff);
+        int numpyType;
+        String ctype;
+        switch (dataType) {
+            case DOUBLE:
+                numpyType = NPY_DOUBLE;
+                ctype = "c_double";
+                break;
+            case FLOAT:
+            case BFLOAT16:
+                numpyType = NPY_FLOAT;
+                ctype = "c_float";
+                break;
+            case SHORT:
+                numpyType = NPY_SHORT;
+                ctype = "c_short";
+                break;
+            case INT:
+                numpyType = NPY_INT;
+                ctype = "c_int";
+                break;
+            case LONG:
+                numpyType = NPY_INT64;
+                ctype = "c_int64";
+                break;
+            case UINT16:
+                numpyType = NPY_USHORT;
+                ctype = "c_uint16";
+                break;
+            case UINT32:
+                numpyType = NPY_UINT;
+                ctype = "c_uint";
+                break;
+            case UINT64:
+                numpyType = NPY_UINT64;
+                ctype = "c_uint64";
+                break;
+            case BOOL:
+                numpyType = NPY_BOOL;
+                ctype = "c_bool";
+                break;
+            case BYTE:
+                numpyType = NPY_BYTE;
+                ctype = "c_byte";
+                break;
+            case UBYTE:
+                numpyType = NPY_UBYTE;
+                ctype = "c_ubyte";
+                break;
+            case HALF:
+                numpyType = NPY_HALF;
+                ctype = "c_short";
+                break;
+            default:
+                throw new RuntimeException("Unsupported dtype: " + dataType);
+        }
+
+        long[] shape = indArray.shape();
+        INDArray inputArray = indArray;
+        if (dataType == DataType.BFLOAT16) {
+            log.warn("Creating copy of array as bfloat16 is not supported by numpy.");
+            inputArray = indArray.castTo(DataType.FLOAT);
+        }
+
+        //Sync to host memory in the case of CUDA, before passing the host memory pointer to Python
+        Nd4j.getAffinityManager().ensureLocation(inputArray, AffinityManager.Location.HOST);
+
+        // PyArray_Type() call causes jvm crash in linux cpu if GIL is acquired by non main thread.
+        // Using Interpreter for now:
+
+        // try(PythonContextManager.Context context = new PythonContextManager.Context("__np_array_converter")){
+        //     log.info("Stringing exec...");
+        //     String code = "import ctypes\nimport numpy as np\n" +
+        //             "cArr = (ctypes." + ctype + "*" + indArray.length() + ")"+
+        //             ".from_address(" + indArray.data().pointer().address() + ")\n"+
+        //             "npArr = np.frombuffer(cArr, dtype=" + ((numpyType == NPY_HALF) ? "'half'" : "ctypes." + ctype)+
+        //             ").reshape(" + Arrays.toString(indArray.shape()) + ")";
+        //     PythonExecutioner.exec(code);
+        //     log.info("exec done.");
+        //     PythonObject ret = PythonExecutioner.getVariable("npArr");
+        //     Py_IncRef(ret.getNativePythonObject());
+        //     return ret;
+        //
+        // }
+        log.info("NUMPY: PyArray_Type()");
+        PyTypeObject pyTypeObject = PyArray_Type();
+
+        log.info("NUMPY: PyArray_New()");
+        PyObject npArr = PyArray_New(pyTypeObject, shape.length, new SizeTPointer(shape),
+                numpyType, null,
+                inputArray.data().addressPointer(),
+                0, NPY_ARRAY_CARRAY, null);
+        log.info("Done.");
+        return new PythonObject(npArr);
+    }
+
+    @Override
+    public boolean accepts(Object javaObject) {
+        return javaObject instanceof INDArray;
+    }
+
+    @Override
+    public INDArray adapt(Object javaObject) {
+        if (javaObject instanceof INDArray) {
+            return (INDArray) javaObject;
+        }
+        throw new PythonException("Cannot cast object of type " + javaObject.getClass().getName() + " to INDArray");
+    }
+}
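
A minimal round-trip sketch of the new type in use (illustrative only; the wrapper class name is hypothetical, and it assumes the executioner has initialized numpy, as the tests later in this diff do):

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.python4j.NumpyArray;
    import org.nd4j.python4j.PythonGIL;
    import org.nd4j.python4j.PythonObject;

    public class NumpyArrayRoundTrip {
        public static void main(String[] args) {
            try (PythonGIL gil = PythonGIL.lock()) {                 // hold the GIL while touching CPython
                INDArray arr = Nd4j.ones(2, 3);
                PythonObject py = NumpyArray.INSTANCE.toPython(arr); // zero-copy: wraps the same host buffer as a numpy.ndarray
                INDArray back = NumpyArray.INSTANCE.toJava(py);      // wraps the numpy data pointer back as an INDArray
                System.out.println(arr.equals(back));                // expected true - both views share the host memory
            }
        }
    }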
@@ -0,0 +1 @@
+org.nd4j.python4j.NumpyArray
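
The single line above is the service registration (a META-INF/services entry, by the look of it) through which PythonTypes discovers the numpy type via Java's ServiceLoader. Once registered, both lookup paths resolve to NumpyArray.INSTANCE, as the service-loader test later in this diff asserts:

    PythonType<INDArray> byName = PythonTypes.get("numpy.ndarray");                       // lookup by Python type name
    PythonType<?> byValue = PythonTypes.getPythonTypeForJavaObject(Nd4j.zeros(1));        // lookup by Java object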
@@ -0,0 +1,169 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+
+import org.nd4j.python4j.*;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.nativeblas.OpaqueDataBuffer;
+
+import javax.annotation.concurrent.NotThreadSafe;
+import java.lang.reflect.Method;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.List;
+
+@NotThreadSafe
+@RunWith(Parameterized.class)
+public class PythonNumpyBasicTest {
+    private DataType dataType;
+    private long[] shape;
+
+    public PythonNumpyBasicTest(DataType dataType, long[] shape, String dummyArg) {
+        this.dataType = dataType;
+        this.shape = shape;
+    }
+
+    @Parameterized.Parameters(name = "{index}: Testing with DataType={0}, shape={2}")
+    public static Collection params() {
+        DataType[] types = new DataType[] {
+                DataType.BOOL,
+                DataType.FLOAT16,
+                DataType.BFLOAT16,
+                DataType.FLOAT,
+                DataType.DOUBLE,
+                DataType.INT8,
+                DataType.INT16,
+                DataType.INT32,
+                DataType.INT64,
+                DataType.UINT8,
+                DataType.UINT16,
+                DataType.UINT32,
+                DataType.UINT64
+        };
+
+        long[][] shapes = new long[][]{
+                new long[]{2, 3},
+                new long[]{3},
+                new long[]{1},
+                new long[]{} // scalar
+        };
+
+        List<Object[]> ret = new ArrayList<>();
+        for (DataType type: types){
+            for (long[] shape: shapes){
+                ret.add(new Object[]{type, shape, Arrays.toString(shape)});
+            }
+        }
+        return ret;
+    }
+
+    @Test
+    public void testConversion(){
+        INDArray arr = Nd4j.zeros(dataType, shape);
+        PythonObject npArr = PythonTypes.convert(arr);
+        INDArray arr2 = PythonTypes.<INDArray>getPythonTypeForPythonObject(npArr).toJava(npArr);
+        if (dataType == DataType.BFLOAT16){
+            arr = arr.castTo(DataType.FLOAT);
+        }
+        Assert.assertEquals(arr, arr2);
+    }
+
+    @Test
+    public void testExecution(){
+        List<PythonVariable> inputs = new ArrayList<>();
+        INDArray x = Nd4j.ones(dataType, shape);
+        INDArray y = Nd4j.zeros(dataType, shape);
+        INDArray z = (dataType == DataType.BOOL) ? x : x.mul(y.add(2));
+        z = (dataType == DataType.BFLOAT16) ? z.castTo(DataType.FLOAT) : z;
+        PythonType<INDArray> arrType = PythonTypes.get("numpy.ndarray");
+        inputs.add(new PythonVariable<>("x", arrType, x));
+        inputs.add(new PythonVariable<>("y", arrType, y));
+        List<PythonVariable> outputs = new ArrayList<>();
+        PythonVariable<INDArray> output = new PythonVariable<>("z", arrType);
+        outputs.add(output);
+        String code = (dataType == DataType.BOOL) ? "z = x" : "z = x * (y + 2)";
+        if (shape.length == 0){ // scalar special case
+            code += "\nimport numpy as np\nz = np.asarray(float(z), dtype=x.dtype)";
+        }
+        PythonExecutioner.exec(code, inputs, outputs);
+        INDArray z2 = output.getValue();
+
+        Assert.assertEquals(z.dataType(), z2.dataType());
+        Assert.assertEquals(z, z2);
+    }
+
+    @Test
+    public void testInplaceExecution(){
+        if (dataType == DataType.BOOL || dataType == DataType.BFLOAT16) return;
+        if (shape.length == 0) return;
+        List<PythonVariable> inputs = new ArrayList<>();
+        INDArray x = Nd4j.ones(dataType, shape);
+        INDArray y = Nd4j.zeros(dataType, shape);
+        INDArray z = x.mul(y.add(2));
+        // Nd4j.getAffinityManager().ensureLocation(z, AffinityManager.Location.HOST);
+        PythonType<INDArray> arrType = PythonTypes.get("numpy.ndarray");
+        inputs.add(new PythonVariable<>("x", arrType, x));
+        inputs.add(new PythonVariable<>("y", arrType, y));
+        List<PythonVariable> outputs = new ArrayList<>();
+        PythonVariable<INDArray> output = new PythonVariable<>("x", arrType);
+        outputs.add(output);
+        String code = "x *= y + 2";
+        PythonExecutioner.exec(code, inputs, outputs);
+        INDArray z2 = output.getValue();
+        Assert.assertEquals(x.dataType(), z2.dataType());
+        Assert.assertEquals(z.dataType(), z2.dataType());
+        Assert.assertEquals(x, z2);
+        Assert.assertEquals(z, z2);
+        Assert.assertEquals(x.data().pointer().address(), z2.data().pointer().address());
+        if ("CUDA".equalsIgnoreCase(Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"))){
+            Assert.assertEquals(getDeviceAddress(x), getDeviceAddress(z2));
+        }
+    }
+
+    private static long getDeviceAddress(INDArray array){
+        if (!"CUDA".equalsIgnoreCase(Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"))){
+            throw new IllegalStateException("Cannot get device pointer for non-CUDA device");
+        }
+
+        //Use reflection here as OpaqueDataBuffer is only available on BaseCudaDataBuffer and BaseCpuDataBuffer - not DataBuffer/BaseDataBuffer
+        // due to it being defined in nd4j-native-api, not nd4j-api
+        try {
+            Class<?> c = Class.forName("org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer");
+            Method m = c.getMethod("getOpaqueDataBuffer");
+            OpaqueDataBuffer db = (OpaqueDataBuffer) m.invoke(array.data());
+            long address = db.specialBuffer().address();
+            return address;
+        } catch (Throwable t){
+            throw new RuntimeException("Error getting OpaqueDataBuffer", t);
+        }
+    }
+}
@@ -0,0 +1,96 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+
+import org.nd4j.python4j.PythonException;
+import org.nd4j.python4j.PythonObject;
+import org.nd4j.python4j.PythonTypes;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.factory.Nd4j;
+
+import javax.annotation.concurrent.NotThreadSafe;
+import java.util.*;
+
+
+@NotThreadSafe
+@RunWith(Parameterized.class)
+public class PythonNumpyCollectionsTest {
+    private DataType dataType;
+
+    public PythonNumpyCollectionsTest(DataType dataType){
+        this.dataType = dataType;
+    }
+
+    @Parameterized.Parameters(name = "{index}: Testing with DataType={0}")
+    public static DataType[] params() {
+        return new DataType[]{
+                DataType.BOOL,
+                DataType.FLOAT16,
+                //DataType.BFLOAT16,
+                DataType.FLOAT,
+                DataType.DOUBLE,
+                DataType.INT8,
+                DataType.INT16,
+                DataType.INT32,
+                DataType.INT64,
+                DataType.UINT8,
+                DataType.UINT16,
+                DataType.UINT32,
+                DataType.UINT64
+        };
+    }
+    @Test
+    public void testPythonDictFromMap() throws PythonException {
+        Map map = new HashMap();
+        map.put("a", 1);
+        map.put(1, "a");
+        map.put("arr", Nd4j.ones(dataType, 2, 3));
+        map.put("list1", Arrays.asList(1, 2.0, 3, 4f, Nd4j.zeros(dataType, 3, 2)));
+        Map innerMap = new HashMap();
+        innerMap.put("b", 2);
+        innerMap.put(2, "b");
+        innerMap.put(5, Nd4j.ones(dataType, 5));
+        map.put("innermap", innerMap);
+        map.put("list2", Arrays.asList(4, "5", innerMap, false, true));
+        PythonObject dict = PythonTypes.convert(map);
+        Map map2 = PythonTypes.DICT.toJava(dict);
+        Assert.assertEquals(map.toString(), map2.toString());
+    }
+
+    @Test
+    public void testPythonListFromList() throws PythonException{
+        List<Object> list = new ArrayList<>();
+        list.add(1);
+        list.add("2");
+        list.add(Nd4j.ones(dataType, 2, 3));
+        list.add(Arrays.asList("a",
+                Nd4j.ones(dataType, 1, 2), 1.0, 2f, 10, true, false,
+                Nd4j.zeros(dataType, 3, 2)));
+        Map map = new HashMap();
+        map.put("a", 1);
+        map.put(1, "a");
+        map.put(5, Nd4j.ones(dataType, 4, 5));
+        map.put("list1", Arrays.asList(1, 2.0, 3, 4f, Nd4j.zeros(dataType, 3, 1)));
+        list.add(map);
+        PythonObject dict = PythonTypes.convert(list);
+        List list2 = PythonTypes.LIST.toJava(dict);
+        Assert.assertEquals(list.toString(), list2.toString());
+    }
+}
@@ -0,0 +1,55 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+
+import org.nd4j.python4j.Python;
+import org.nd4j.python4j.PythonGC;
+import org.nd4j.python4j.PythonObject;
+import org.junit.Assert;
+import org.junit.Test;
+import org.nd4j.linalg.factory.Nd4j;
+
+import javax.annotation.concurrent.NotThreadSafe;
+
+
+@NotThreadSafe
+public class PythonNumpyGCTest {
+
+    @Test
+    public void testGC(){
+        PythonObject gcModule = Python.importModule("gc");
+        PythonObject getObjects = gcModule.attr("get_objects");
+        PythonObject pyObjCount1 = Python.len(getObjects.call());
+        long objCount1 = pyObjCount1.toLong();
+        PythonObject pyList = Python.list();
+        pyList.attr("append").call(new PythonObject(Nd4j.linspace(1, 10, 10)));
+        pyList.attr("append").call(1.0);
+        pyList.attr("append").call(true);
+        PythonObject pyObjCount2 = Python.len(getObjects.call());
+        long objCount2 = pyObjCount2.toLong();
+        long diff = objCount2 - objCount1;
+        Assert.assertTrue(diff > 2);
+        try(PythonGC gc = PythonGC.watch()){
+            PythonObject pyList2 = Python.list();
+            pyList2.attr("append").call(new PythonObject(Nd4j.linspace(1, 10, 10)));
+            pyList2.attr("append").call(1.0);
+            pyList2.attr("append").call(true);
+        }
+        PythonObject pyObjCount3 = Python.len(getObjects.call());
+        long objCount3 = pyObjCount3.toLong();
+        diff = objCount3 - objCount2;
+        Assert.assertTrue(diff <= 2); // 2 objects created during function call
+    }
+}
@@ -0,0 +1,22 @@
+import org.nd4j.python4j.NumpyArray;
+import org.nd4j.python4j.Python;
+import org.nd4j.python4j.PythonGC;
+import org.nd4j.python4j.PythonObject;
+import org.junit.Assert;
+import org.junit.Test;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+
+public class PythonNumpyImportTest {
+
+    @Test
+    public void testNumpyImport(){
+        try(PythonGC gc = PythonGC.watch()){
+            PythonObject np = Python.importModule("numpy");
+            PythonObject zeros = np.attr("zeros").call(5);
+            INDArray arr = NumpyArray.INSTANCE.toJava(zeros);
+            Assert.assertEquals(arr, Nd4j.zeros(DataType.DOUBLE, 5));
+        }
+    }
+}
@@ -0,0 +1,303 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.python4j.*;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+
+
+@javax.annotation.concurrent.NotThreadSafe
+@RunWith(Parameterized.class)
+public class PythonNumpyJobTest {
+    private DataType dataType;
+
+    public PythonNumpyJobTest(DataType dataType){
+        this.dataType = dataType;
+    }
+
+    @Parameterized.Parameters(name = "{index}: Testing with DataType={0}")
+    public static DataType[] params() {
+        return new DataType[]{
+                DataType.BOOL,
+                DataType.FLOAT16,
+                DataType.BFLOAT16,
+                DataType.FLOAT,
+                DataType.DOUBLE,
+                DataType.INT8,
+                DataType.INT16,
+                DataType.INT32,
+                DataType.INT64,
+                DataType.UINT8,
+                DataType.UINT16,
+                DataType.UINT32,
+                DataType.UINT64
+        };
+    }
+
+    @Test
+    public void testNumpyJobBasic(){
+        PythonContextManager.deleteNonMainContexts();
+        List<PythonVariable> inputs = new ArrayList<>();
+        INDArray x = Nd4j.ones(dataType, 2, 3);
+        INDArray y = Nd4j.zeros(dataType, 2, 3);
+        INDArray z = (dataType == DataType.BOOL) ? x : x.mul(y.add(2));
+        z = (dataType == DataType.BFLOAT16) ? z.castTo(DataType.FLOAT) : z;
+        PythonType<INDArray> arrType = PythonTypes.get("numpy.ndarray");
+        inputs.add(new PythonVariable<>("x", arrType, x));
+        inputs.add(new PythonVariable<>("y", arrType, y));
+        List<PythonVariable> outputs = new ArrayList<>();
+        PythonVariable<INDArray> output = new PythonVariable<>("z", arrType);
+        outputs.add(output);
+        String code = (dataType == DataType.BOOL) ? "z = x" : "z = x * (y + 2)";
+
+        PythonJob job = new PythonJob("job1", code, false);
+
+        job.exec(inputs, outputs);
+
+        INDArray z2 = output.getValue();
+
+        if (dataType == DataType.BFLOAT16){
+            z2 = z2.castTo(DataType.FLOAT);
+        }
+
+        Assert.assertEquals(z, z2);
+    }
+
+    @Test
+    public void testNumpyJobReturnAllVariables(){
+        PythonContextManager.deleteNonMainContexts();
+        List<PythonVariable> inputs = new ArrayList<>();
+        INDArray x = Nd4j.ones(dataType, 2, 3);
+        INDArray y = Nd4j.zeros(dataType, 2, 3);
+        INDArray z = (dataType == DataType.BOOL) ? x : x.mul(y.add(2));
+        PythonType<INDArray> arrType = PythonTypes.get("numpy.ndarray");
+        inputs.add(new PythonVariable<>("x", arrType, x));
+        inputs.add(new PythonVariable<>("y", arrType, y));
+        String code = (dataType == DataType.BOOL) ? "z = x" : "z = x * (y + 2)";
+
+        PythonJob job = new PythonJob("job1", code, false);
+        List<PythonVariable> outputs = job.execAndReturnAllVariables(inputs);
+
+        INDArray x2 = (INDArray) outputs.get(0).getValue();
+        INDArray y2 = (INDArray) outputs.get(1).getValue();
+        INDArray z2 = (INDArray) outputs.get(2).getValue();
+
+        if (dataType == DataType.BFLOAT16){
+            x = x.castTo(DataType.FLOAT);
+            y = y.castTo(DataType.FLOAT);
+            z = z.castTo(DataType.FLOAT);
+        }
+        Assert.assertEquals(x, x2);
+        Assert.assertEquals(y, y2);
+        Assert.assertEquals(z, z2);
+    }
+
+
+    @Test
+    public void testMultipleNumpyJobsParallel(){
+        PythonContextManager.deleteNonMainContexts();
+        String code1 = (dataType == DataType.BOOL) ? "z = x" : "z = x + y";
+        PythonJob job1 = new PythonJob("job1", code1, false);
+
+        String code2 = (dataType == DataType.BOOL) ? "z = y" : "z = x - y";
+        PythonJob job2 = new PythonJob("job2", code2, false);
+
+        List<PythonVariable> inputs = new ArrayList<>();
+        INDArray x = Nd4j.ones(dataType, 2, 3);
+        INDArray y = Nd4j.zeros(dataType, 2, 3);
+        INDArray z1 = (dataType == DataType.BOOL) ? x : x.add(y);
+        z1 = (dataType == DataType.BFLOAT16) ? z1.castTo(DataType.FLOAT) : z1;
+        INDArray z2 = (dataType == DataType.BOOL) ? y : x.sub(y);
+        z2 = (dataType == DataType.BFLOAT16) ? z2.castTo(DataType.FLOAT) : z2;
+        PythonType<INDArray> arrType = PythonTypes.get("numpy.ndarray");
+        inputs.add(new PythonVariable<>("x", arrType, x));
+        inputs.add(new PythonVariable<>("y", arrType, y));
+
+        List<PythonVariable> outputs = new ArrayList<>();
+        outputs.add(new PythonVariable<>("z", arrType));
+
+        job1.exec(inputs, outputs);
+        assertEquals(z1, outputs.get(0).getValue());
+
+        job2.exec(inputs, outputs);
+        assertEquals(z2, outputs.get(0).getValue());
+    }
+
+
+    @Test
+    public synchronized void testNumpyJobSetupRun(){
+        if (dataType == DataType.BOOL) return;
+        PythonContextManager.deleteNonMainContexts();
+        String code = "five=None\n" +
+                "def setup():\n" +
+                "    global five\n" +
+                "    five = 5\n\n" +
+                "def run(a, b):\n" +
+                "    c = a + b + five\n" +
+                "    return {'c':c}\n\n";
+
+        PythonJob job = new PythonJob("job1", code, true);
+
+        List<PythonVariable> inputs = new ArrayList<>();
+        inputs.add(new PythonVariable<>("a", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(2)));
+        inputs.add(new PythonVariable<>("b", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3)));
+
+        List<PythonVariable> outputs = new ArrayList<>();
+        outputs.add(new PythonVariable<>("c", NumpyArray.INSTANCE));
+        job.exec(inputs, outputs);
+
+        assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16) ? DataType.FLOAT : dataType, 2, 3).mul(10),
+                outputs.get(0).getValue());
+
+
+        inputs = new ArrayList<>();
+        inputs.add(new PythonVariable<>("a", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3)));
+        inputs.add(new PythonVariable<>("b", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(4)));
+
+        outputs = new ArrayList<>();
+        outputs.add(new PythonVariable<>("c", NumpyArray.INSTANCE));
+
+        job.exec(inputs, outputs);
+
+        assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16) ? DataType.FLOAT : dataType, 2, 3).mul(12),
+                outputs.get(0).getValue());
+    }
+
+    @Test
+    public void testNumpyJobSetupRunAndReturnAllVariables(){
+        if (dataType == DataType.BOOL) return;
+        PythonContextManager.deleteNonMainContexts();
+        String code = "five=None\n" +
+                "c=None\n" +
+                "def setup():\n" +
+                "    global five\n" +
+                "    five = 5\n\n" +
+                "def run(a, b):\n" +
+                "    global c\n" +
+                "    c = a + b + five\n";
+        PythonJob job = new PythonJob("job1", code, true);
+
+        List<PythonVariable> inputs = new ArrayList<>();
+        inputs.add(new PythonVariable<>("a", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(2)));
+        inputs.add(new PythonVariable<>("b", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3)));
+
+        List<PythonVariable> outputs = job.execAndReturnAllVariables(inputs);
+
+        assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16) ? DataType.FLOAT : dataType, 2, 3).mul(10),
+                outputs.get(1).getValue());
+
+
+        inputs = new ArrayList<>();
+        inputs.add(new PythonVariable<>("a", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3)));
+        inputs.add(new PythonVariable<>("b", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(4)));
+
+        outputs = job.execAndReturnAllVariables(inputs);
+
+        assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16) ? DataType.FLOAT : dataType, 2, 3).mul(12),
+                outputs.get(1).getValue());
+    }
+
+    @Test
+    public void testMultipleNumpyJobsSetupRunParallel(){
+        if (dataType == DataType.BOOL) return;
+        PythonContextManager.deleteNonMainContexts();
+
+        String code1 = "five=None\n" +
+                "def setup():\n" +
+                "    global five\n" +
+                "    five = 5\n\n" +
+                "def run(a, b):\n" +
+                "    c = a + b + five\n" +
+                "    return {'c':c}\n\n";
+        PythonJob job1 = new PythonJob("job1", code1, true);
+
+        String code2 = "five=None\n" +
+                "def setup():\n" +
+                "    global five\n" +
+                "    five = 5\n\n" +
+                "def run(a, b):\n" +
+                "    c = a + b - five\n" +
+                "    return {'c':c}\n\n";
+        PythonJob job2 = new PythonJob("job2", code2, true);
+
+        List<PythonVariable> inputs = new ArrayList<>();
+        inputs.add(new PythonVariable<>("a", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(2)));
+        inputs.add(new PythonVariable<>("b", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3)));
+
+        List<PythonVariable> outputs = new ArrayList<>();
+        outputs.add(new PythonVariable<>("c", NumpyArray.INSTANCE));
+
+        job1.exec(inputs, outputs);
+        assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16) ? DataType.FLOAT : dataType, 2, 3).mul(10),
+                outputs.get(0).getValue());
+
+        job2.exec(inputs, outputs);
+        assertEquals(Nd4j.zeros((dataType == DataType.BFLOAT16) ? DataType.FLOAT : dataType, 2, 3),
+                outputs.get(0).getValue());
+
+
+        inputs = new ArrayList<>();
+        inputs.add(new PythonVariable<>("a", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3)));
+        inputs.add(new PythonVariable<>("b", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(4)));
+
+        outputs = new ArrayList<>();
+        outputs.add(new PythonVariable<>("c", NumpyArray.INSTANCE));
+
+        job1.exec(inputs, outputs);
+        assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16) ? DataType.FLOAT : dataType, 2, 3).mul(12),
+                outputs.get(0).getValue());
+
+        job2.exec(inputs, outputs);
+        assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16) ? DataType.FLOAT : dataType, 2, 3).mul(2),
+                outputs.get(0).getValue());
+    }
+
+}
@@ -0,0 +1,194 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+import org.nd4j.python4j.*;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+
+import javax.annotation.concurrent.NotThreadSafe;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+
+@NotThreadSafe
+@RunWith(Parameterized.class)
+public class PythonNumpyMultiThreadTest {
+    private DataType dataType;
+
+    public PythonNumpyMultiThreadTest(DataType dataType) {
+        this.dataType = dataType;
+    }
+
+    @Parameterized.Parameters(name = "{index}: Testing with DataType={0}")
+    public static DataType[] params() {
+        return new DataType[]{
+//                DataType.BOOL,
+//                DataType.FLOAT16,
+//                DataType.BFLOAT16,
+                DataType.FLOAT,
+                DataType.DOUBLE,
+//                DataType.INT8,
+//                DataType.INT16,
+                DataType.INT32,
+                DataType.INT64,
+//                DataType.UINT8,
+//                DataType.UINT16,
+//                DataType.UINT32,
+//                DataType.UINT64
+        };
+    }
+
+
+    @Test
+    public void testMultiThreading1() throws Throwable {
+        final List<Throwable> exceptions = Collections.synchronizedList(new ArrayList<Throwable>());
+        Runnable runnable = new Runnable() {
+            @Override
+            public void run() {
+                try (PythonGIL gil = PythonGIL.lock()) {
+                    try (PythonGC gc = PythonGC.watch()) {
+                        List<PythonVariable> inputs = new ArrayList<>();
+                        inputs.add(new PythonVariable<>("x", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3)));
+                        inputs.add(new PythonVariable<>("y", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(4)));
+                        PythonVariable out = new PythonVariable<>("z", NumpyArray.INSTANCE);
+                        String code = "z = x + y";
+                        PythonExecutioner.exec(code, inputs, Collections.singletonList(out));
+                        Assert.assertEquals(Nd4j.ones(dataType, 2, 3).mul(7), out.getValue());
+                    }
+                } catch (Throwable e) {
+                    exceptions.add(e);
+                }
+            }
+        };
+
+        int numThreads = 10;
+        Thread[] threads = new Thread[numThreads];
+        for (int i = 0; i < threads.length; i++) {
+            threads[i] = new Thread(runnable);
+        }
+        for (int i = 0; i < threads.length; i++) {
+            threads[i].start();
+        }
+        Thread.sleep(100);
+        for (int i = 0; i < threads.length; i++) {
+            threads[i].join();
+        }
+        if (!exceptions.isEmpty()) {
+            throw (exceptions.get(0));
+        }
+    }
+
+    @Test
+    public void testMultiThreading2() throws Throwable {
+        final List<Throwable> exceptions = Collections.synchronizedList(new ArrayList<Throwable>());
+        Runnable runnable = new Runnable() {
+            @Override
+            public void run() {
+                try (PythonGIL gil = PythonGIL.lock()) {
+                    try (PythonGC gc = PythonGC.watch()) {
+                        PythonContextManager.reset();
+                        List<PythonVariable> inputs = new ArrayList<>();
+                        inputs.add(new PythonVariable<>("x", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3)));
+                        inputs.add(new PythonVariable<>("y", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(4)));
+                        String code = "z = x + y";
+                        List<PythonVariable> outputs = PythonExecutioner.execAndReturnAllVariables(code, inputs);
+                        Assert.assertEquals(Nd4j.ones(dataType, 2, 3).mul(3), outputs.get(0).getValue());
+                        Assert.assertEquals(Nd4j.ones(dataType, 2, 3).mul(4), outputs.get(1).getValue());
+                        Assert.assertEquals(Nd4j.ones(dataType, 2, 3).mul(7), outputs.get(2).getValue());
+                    }
+                } catch (Throwable e) {
+                    exceptions.add(e);
+                }
+            }
+        };
+
+        int numThreads = 10;
+        Thread[] threads = new Thread[numThreads];
+        for (int i = 0; i < threads.length; i++) {
+            threads[i] = new Thread(runnable);
+        }
+        for (int i = 0; i < threads.length; i++) {
+            threads[i].start();
+        }
+        Thread.sleep(100);
+        for (int i = 0; i < threads.length; i++) {
+            threads[i].join();
+        }
+        if (!exceptions.isEmpty()) {
+            throw (exceptions.get(0));
+        }
+    }
+
+    @Test
+    public void testMultiThreading3() throws Throwable {
+        PythonContextManager.deleteNonMainContexts();
+
+        String code = "c = a + b";
+        final PythonJob job = new PythonJob("job1", code, false);
+
+        final List<Throwable> exceptions = Collections.synchronizedList(new ArrayList<Throwable>());
+
+        class JobThread extends Thread {
+            private INDArray a, b, c;
+
+            public JobThread(INDArray a, INDArray b, INDArray c) {
+                this.a = a;
+                this.b = b;
+                this.c = c;
+            }
+
+            @Override
+            public void run() {
+                try {
+                    PythonVariable<INDArray> out = new PythonVariable<>("c", NumpyArray.INSTANCE);
+                    job.exec(Arrays.<PythonVariable>asList(new PythonVariable<>("a", NumpyArray.INSTANCE, a),
+                            new PythonVariable<>("b", NumpyArray.INSTANCE, b)),
+                            Collections.<PythonVariable>singletonList(out));
+                    Assert.assertEquals(c, out.getValue());
+                } catch (Exception e) {
+                    exceptions.add(e);
+                }
+            }
+        }
+        int numThreads = 10;
+        JobThread[] threads = new JobThread[numThreads];
+        for (int i = 0; i < threads.length; i++) {
+            threads[i] = new JobThread(Nd4j.zeros(dataType, 2, 3).add(i), Nd4j.zeros(dataType, 2, 3).add(i + 3),
+                    Nd4j.zeros(dataType, 2, 3).add(2 * i + 3));
+        }
+
+        for (int i = 0; i < threads.length; i++) {
+            threads[i].start();
+        }
+        Thread.sleep(100);
+        for (int i = 0; i < threads.length; i++) {
+            threads[i].join();
+        }
+
+        if (!exceptions.isEmpty()) {
+            throw (exceptions.get(0));
+        }
+    }
+}
@@ -1,5 +1,5 @@
 /*******************************************************************************
- * Copyright (c) 2015-2019 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
@@ -14,15 +14,22 @@
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

-package org.deeplearning4j.rl4j.learning.sync.qlearning;

-import org.deeplearning4j.rl4j.network.dqn.IDQN;
+import org.junit.Assert;
+import org.junit.Test;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.python4j.NumpyArray;
+import org.nd4j.python4j.PythonTypes;

-/**
- * An interface that is an extension of {@link QNetworkSource} for all implementations capable of supplying a target Q-Network
- *
- * @author Alexandre Boulanger
- */
-public interface TargetQNetworkSource extends QNetworkSource {
-    IDQN getTargetQNetwork();
+import javax.annotation.concurrent.NotThreadSafe;
+
+@NotThreadSafe
+public class PythonNumpyServiceLoaderTest {
+
+    @Test
+    public void testServiceLoader(){
+        Assert.assertEquals(NumpyArray.INSTANCE, PythonTypes.<INDArray>get("numpy.ndarray"));
+        Assert.assertEquals(NumpyArray.INSTANCE, PythonTypes.getPythonTypeForJavaObject(Nd4j.zeros(1)));
+    }
 }
@ -17,7 +17,6 @@ package org.deeplearning4j.rl4j.agent.update;
|
||||||
|
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
import org.deeplearning4j.rl4j.learning.sync.Transition;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
|
|
||||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN;
|
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.ITDTargetAlgorithm;
|
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.ITDTargetAlgorithm;
|
||||||
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.StandardDQN;
|
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.StandardDQN;
|
||||||
|
@ -28,13 +27,10 @@ import java.util.List;
|
||||||
|
|
||||||
// Temporary class that will be replaced with a more generic class that delegates gradient computation
|
// Temporary class that will be replaced with a more generic class that delegates gradient computation
|
||||||
// and network update to sub components.
|
// and network update to sub components.
|
||||||
public class DQNNeuralNetUpdateRule implements IUpdateRule<Transition<Integer>>, TargetQNetworkSource {
|
public class DQNNeuralNetUpdateRule implements IUpdateRule<Transition<Integer>> {
|
||||||
|
|
||||||
@Getter
|
|
||||||
private final IDQN qNetwork;
|
private final IDQN qNetwork;
|
||||||
|
private final IDQN targetQNetwork;
|
||||||
@Getter
|
|
||||||
private IDQN targetQNetwork;
|
|
||||||
private final int targetUpdateFrequency;
|
private final int targetUpdateFrequency;
|
||||||
|
|
||||||
private final ITDTargetAlgorithm<Integer> tdTargetAlgorithm;
|
private final ITDTargetAlgorithm<Integer> tdTargetAlgorithm;
|
||||||
|
@ -47,8 +43,8 @@ public class DQNNeuralNetUpdateRule implements IUpdateRule<Transition<Integer>>,
|
||||||
this.targetQNetwork = qNetwork.clone();
|
this.targetQNetwork = qNetwork.clone();
|
||||||
this.targetUpdateFrequency = targetUpdateFrequency;
|
this.targetUpdateFrequency = targetUpdateFrequency;
|
||||||
tdTargetAlgorithm = isDoubleDQN
|
tdTargetAlgorithm = isDoubleDQN
|
||||||
? new DoubleDQN(this, gamma, errorClamp)
|
? new DoubleDQN(qNetwork, targetQNetwork, gamma, errorClamp)
|
||||||
: new StandardDQN(this, gamma, errorClamp);
|
: new StandardDQN(qNetwork, targetQNetwork, gamma, errorClamp);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -56,7 +52,7 @@ public class DQNNeuralNetUpdateRule implements IUpdateRule<Transition<Integer>>,
|
||||||
DataSet targets = tdTargetAlgorithm.computeTDTargets(trainingBatch);
|
DataSet targets = tdTargetAlgorithm.computeTDTargets(trainingBatch);
|
||||||
qNetwork.fit(targets.getFeatures(), targets.getLabels());
|
qNetwork.fit(targets.getFeatures(), targets.getLabels());
|
||||||
if(++updateCount % targetUpdateFrequency == 0) {
|
if(++updateCount % targetUpdateFrequency == 0) {
|
||||||
targetQNetwork = qNetwork.clone();
|
targetQNetwork.copy(qNetwork);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
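Note the change inside the update: the target network is no longer replaced by a fresh clone() but refreshed in place with copy(qNetwork). Because DoubleDQN and StandardDQN now receive the target network reference directly in the constructor, an in-place copy keeps that reference valid; reassignment would leave them reading stale weights. A minimal sketch of the pattern (hypothetical Net interface, not rl4j's IDQN):

interface Net {
    Net cloneNet();                 // deep copy, new instance
    void copyParamsFrom(Net other); // in-place parameter copy
}

class TargetSyncSketch {
    private final Net qNetwork;
    private final Net targetQNetwork; // final: the reference never changes
    private final int targetUpdateFrequency;
    private int updateCount = 0;

    TargetSyncSketch(Net qNetwork, int targetUpdateFrequency) {
        this.qNetwork = qNetwork;
        this.targetQNetwork = qNetwork.cloneNet(); // cloned once, at construction
        this.targetUpdateFrequency = targetUpdateFrequency;
    }

    void afterFit() {
        if (++updateCount % targetUpdateFrequency == 0) {
            // In-place copy: collaborators holding targetQNetwork see fresh weights.
            targetQNetwork.copyParamsFrom(qNetwork);
        }
    }
}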
@ -16,8 +16,7 @@

package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;
|
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;

import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
|
import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;

/**
|
/**

@ -28,7 +27,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;

 */
|
 */
public abstract class BaseDQNAlgorithm extends BaseTDTargetAlgorithm {
|
public abstract class BaseDQNAlgorithm extends BaseTDTargetAlgorithm {

    private final TargetQNetworkSource qTargetNetworkSource;
|
    private final IOutputNeuralNet targetQNetwork;

    /**
|
    /**
     * In literature, this corresponds to Q{net}(s(t+1), a)
|
     * In literature, this corresponds to Q{net}(s(t+1), a)

@ -40,23 +39,21 @@ public abstract class BaseDQNAlgorithm extends BaseTDTargetAlgorithm {

     */
|
     */
    protected INDArray targetQNetworkNextObservation;
|
    protected INDArray targetQNetworkNextObservation;

    protected BaseDQNAlgorithm(TargetQNetworkSource qTargetNetworkSource, double gamma) {
|
    protected BaseDQNAlgorithm(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma) {
        super(qTargetNetworkSource, gamma);
|
        super(qNetwork, gamma);
        this.qTargetNetworkSource = qTargetNetworkSource;
|
        this.targetQNetwork = targetQNetwork;
    }
|
    }

    protected BaseDQNAlgorithm(TargetQNetworkSource qTargetNetworkSource, double gamma, double errorClamp) {
|
    protected BaseDQNAlgorithm(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma, double errorClamp) {
        super(qTargetNetworkSource, gamma, errorClamp);
|
        super(qNetwork, gamma, errorClamp);
        this.qTargetNetworkSource = qTargetNetworkSource;
|
        this.targetQNetwork = targetQNetwork;
    }
|
    }

    @Override
|
    @Override
    protected void initComputation(INDArray observations, INDArray nextObservations) {
|
    protected void initComputation(INDArray observations, INDArray nextObservations) {
        super.initComputation(observations, nextObservations);
|
        super.initComputation(observations, nextObservations);

        qNetworkNextObservation = qNetworkSource.getQNetwork().output(nextObservations);
|
        qNetworkNextObservation = qNetwork.output(nextObservations);

        IDQN targetQNetwork = qTargetNetworkSource.getTargetQNetwork();
|
        targetQNetworkNextObservation = targetQNetwork.output(nextObservations);
|
        targetQNetworkNextObservation = targetQNetwork.output(nextObservations);
    }
|
    }
}
|
}
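Since both TD-target constructors now take their networks directly, a caller no longer implements a source interface to wire them up. A sketch of the call site using the constructors shown above (the factory class itself is illustrative, not rl4j code):

import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.DoubleDQN;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.ITDTargetAlgorithm;
import org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm.StandardDQN;
import org.deeplearning4j.rl4j.network.IOutputNeuralNet;

class TDTargetFactory {
    // Any IOutputNeuralNet pair works; no TargetQNetworkSource implementation needed.
    static ITDTargetAlgorithm<Integer> create(boolean isDoubleDQN,
                                              IOutputNeuralNet net,
                                              IOutputNeuralNet targetNet,
                                              double gamma) {
        return isDoubleDQN ? new DoubleDQN(net, targetNet, gamma)
                           : new StandardDQN(net, targetNet, gamma);
    }
}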
@ -17,7 +17,7 @@

package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;
|
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;

import org.deeplearning4j.rl4j.learning.sync.Transition;
|
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.qlearning.QNetworkSource;
|
import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
|
import org.nd4j.linalg.dataset.api.DataSet;

@ -30,7 +30,7 @@ import java.util.List;

 */
|
 */
public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Integer> {
|
public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Integer> {

    protected final QNetworkSource qNetworkSource;
|
    protected final IOutputNeuralNet qNetwork;
    protected final double gamma;
|
    protected final double gamma;

    private final double errorClamp;
|
    private final double errorClamp;

@ -38,12 +38,12 @@ public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Intege

    /**
|
    /**
     *
|
     *
     * @param qNetworkSource The source of the Q-Network
|
     * @param qNetwork The Q-Network
     * @param gamma The discount factor
|
     * @param gamma The discount factor
     * @param errorClamp Will prevent the new Q-Value from being farther than <i>errorClamp</i> away from the previous value. Double.NaN will disable the clamping.
|
     * @param errorClamp Will prevent the new Q-Value from being farther than <i>errorClamp</i> away from the previous value. Double.NaN will disable the clamping.
     */
|
     */
    protected BaseTDTargetAlgorithm(QNetworkSource qNetworkSource, double gamma, double errorClamp) {
|
    protected BaseTDTargetAlgorithm(IOutputNeuralNet qNetwork, double gamma, double errorClamp) {
        this.qNetworkSource = qNetworkSource;
|
        this.qNetwork = qNetwork;
        this.gamma = gamma;
|
        this.gamma = gamma;

        this.errorClamp = errorClamp;
|
        this.errorClamp = errorClamp;

@ -52,12 +52,12 @@ public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Intege

    /**
|
    /**
     *
|
     *
     * @param qNetworkSource The source of the Q-Network
|
     * @param qNetwork The Q-Network
     * @param gamma The discount factor
|
     * @param gamma The discount factor
     * Note: Error clamping is disabled with this ctor
|
     * Note: Error clamping is disabled with this ctor
     */
|
     */
    protected BaseTDTargetAlgorithm(QNetworkSource qNetworkSource, double gamma) {
|
    protected BaseTDTargetAlgorithm(IOutputNeuralNet qNetwork, double gamma) {
        this(qNetworkSource, gamma, Double.NaN);
|
        this(qNetwork, gamma, Double.NaN);
    }
|
    }

    /**
|
    /**

@ -89,8 +89,7 @@ public abstract class BaseTDTargetAlgorithm implements ITDTargetAlgorithm<Intege

        initComputation(observations, nextObservations);
|
        initComputation(observations, nextObservations);

        INDArray updatedQValues = qNetworkSource.getQNetwork().output(observations);
|
        INDArray updatedQValues = qNetwork.output(observations);

        for (int i = 0; i < size; ++i) {
|
        for (int i = 0; i < size; ++i) {
            Transition<Integer> transition = transitions.get(i);
|
            Transition<Integer> transition = transitions.get(i);
            double yTarget = computeTarget(i, transition.getReward(), transition.isTerminal());
|
            double yTarget = computeTarget(i, transition.getReward(), transition.isTerminal());
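The errorClamp javadoc above describes keeping each new target within a fixed distance of the previous Q-value, with Double.NaN disabling the behaviour. A standalone sketch of that rule (illustrative, not the class's actual code):

class ErrorClampSketch {
    // Keeps yTarget within +/- errorClamp of previousQ; NaN disables clamping,
    // matching the Double.NaN default used by the two-argument constructor.
    static double clamp(double previousQ, double yTarget, double errorClamp) {
        if (Double.isNaN(errorClamp)) {
            return yTarget;
        }
        return Math.max(previousQ - errorClamp, Math.min(previousQ + errorClamp, yTarget));
    }
}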
@ -16,7 +16,7 @@

package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;
|
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;

import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
|
import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;

@ -32,12 +32,12 @@ public class DoubleDQN extends BaseDQNAlgorithm {

    // In literature, this corresponds to: max_{a}Q(s_{t+1}, a)
|
    // In literature, this corresponds to: max_{a}Q(s_{t+1}, a)
    private INDArray maxActionsFromQNetworkNextObservation;
|
    private INDArray maxActionsFromQNetworkNextObservation;

    public DoubleDQN(TargetQNetworkSource qTargetNetworkSource, double gamma) {
|
    public DoubleDQN(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma) {
        super(qTargetNetworkSource, gamma);
|
        super(qNetwork, targetQNetwork, gamma);
    }
|
    }

    public DoubleDQN(TargetQNetworkSource qTargetNetworkSource, double gamma, double errorClamp) {
|
    public DoubleDQN(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma, double errorClamp) {
        super(qTargetNetworkSource, gamma, errorClamp);
|
        super(qNetwork, targetQNetwork, gamma, errorClamp);
    }
|
    }

    @Override
|
    @Override
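For reference, the rule DoubleDQN implements: the online network selects the next action, the target network evaluates it. A scalar sketch over plain arrays (rl4j operates on INDArray batches instead):

class DoubleDqnTargetSketch {
    // y = r                                            if the transition is terminal
    // y = r + gamma * Q_target(s', argmax_a Q(s', a))  otherwise
    static double target(double reward, boolean isTerminal, double gamma,
                         double[] qNextOnline, double[] qNextTarget) {
        if (isTerminal) {
            return reward;
        }
        int best = 0;
        for (int a = 1; a < qNextOnline.length; a++) {
            if (qNextOnline[a] > qNextOnline[best]) {
                best = a; // action selection by the online network
            }
        }
        return reward + gamma * qNextTarget[best]; // evaluation by the target network
    }
}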
@ -16,7 +16,7 @@

package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;
|
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;

import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
|
import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;

@ -32,12 +32,12 @@ public class StandardDQN extends BaseDQNAlgorithm {

    // In literature, this corresponds to: max_{a}Q_{tar}(s_{t+1}, a)
|
    // In literature, this corresponds to: max_{a}Q_{tar}(s_{t+1}, a)
    private INDArray maxActionsFromQTargetNextObservation;
|
    private INDArray maxActionsFromQTargetNextObservation;

    public StandardDQN(TargetQNetworkSource qTargetNetworkSource, double gamma) {
|
    public StandardDQN(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma) {
        super(qTargetNetworkSource, gamma);
|
        super(qNetwork, targetQNetwork, gamma);
    }
|
    }

    public StandardDQN(TargetQNetworkSource qTargetNetworkSource, double gamma, double errorClamp) {
|
    public StandardDQN(IOutputNeuralNet qNetwork, IOutputNeuralNet targetQNetwork, double gamma, double errorClamp) {
        super(qTargetNetworkSource, gamma, errorClamp);
|
        super(qNetwork, targetQNetwork, gamma, errorClamp);
    }
|
    }

    @Override
|
    @Override
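The standard rule, for contrast: the target network both selects and evaluates the next action. Again a scalar sketch:

class StandardDqnTargetSketch {
    // y = r + gamma * max_a Q_target(s', a) for non-terminal transitions
    static double target(double reward, boolean isTerminal, double gamma, double[] qNextTarget) {
        if (isTerminal) {
            return reward;
        }
        double max = qNextTarget[0];
        for (int a = 1; a < qNextTarget.length; a++) {
            max = Math.max(max, qNextTarget[a]);
        }
        return reward + gamma * max;
    }
}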
@ -1,28 +1,38 @@

/*******************************************************************************
|
/*******************************************************************************
 * Copyright (c) 2015-2019 Skymind, Inc.
|
 * Copyright (c) 2020 Konduit K.K.
 *
|
 *
 * This program and the accompanying materials are made available under the
|
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
|
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
|
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
|
 *
 * Unless required by applicable law or agreed to in writing, software
|
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
|
 * License for the specific language governing permissions and limitations
 * under the License.
|
 * under the License.
 *
|
 *
 * SPDX-License-Identifier: Apache-2.0
|
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
|
 ******************************************************************************/
|
package org.deeplearning4j.rl4j.network;
package org.deeplearning4j.rl4j.learning.sync.qlearning;
|

|
import org.deeplearning4j.rl4j.observation.Observation;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
|
import org.nd4j.linalg.api.ndarray.INDArray;

/**
|
/**
 * An interface for all implementations capable of supplying a Q-Network
|
 * An interface defining the output aspect of a {@link NeuralNet}.
 *
|
 */
 * @author Alexandre Boulanger
|
public interface IOutputNeuralNet {
 */
|
    /**
public interface QNetworkSource {
|
     * Compute the output for the supplied observation.
    IDQN getQNetwork();
|
     * @param observation An {@link Observation}
}
|
     * @return The output of the network
|
     */
|
    INDArray output(Observation observation);

|
    /**
|
     * Compute the output for the supplied batch.
|
     * @param batch
|
     * @return The output of the network
|
     */
|
    INDArray output(INDArray batch);
|
}
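The new interface is small enough to implement directly when a full NeuralNet is not needed, e.g. for test doubles. A minimal sketch, assuming Observation#getData() exposes the backing INDArray:

import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
import org.deeplearning4j.rl4j.observation.Observation;
import org.nd4j.linalg.api.ndarray.INDArray;

// Echoes its input, like the Mockito stubs in the tests further down.
class EchoOutputNet implements IOutputNeuralNet {
    @Override
    public INDArray output(Observation observation) {
        return output(observation.getData()); // getData() assumed to return the raw INDArray
    }

    @Override
    public INDArray output(INDArray batch) {
        return batch;
    }
}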
@ -17,6 +17,7 @@

package org.deeplearning4j.rl4j.network.dqn;
|
package org.deeplearning4j.rl4j.network.dqn;

import org.deeplearning4j.nn.gradient.Gradient;
|
import org.deeplearning4j.nn.gradient.Gradient;
|
import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
import org.deeplearning4j.rl4j.network.NeuralNet;
|
import org.deeplearning4j.rl4j.network.NeuralNet;
import org.deeplearning4j.rl4j.observation.Observation;
|
import org.deeplearning4j.rl4j.observation.Observation;
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;

@ -27,7 +28,7 @@ import org.nd4j.linalg.api.ndarray.INDArray;

 * This neural net quantifies the value of each action given a state
|
 * This neural net quantifies the value of each action given a state
 *
|
 *
 */
|
 */
public interface IDQN<NN extends IDQN> extends NeuralNet<NN> {
|
public interface IDQN<NN extends IDQN> extends NeuralNet<NN>, IOutputNeuralNet {

    boolean isRecurrent();
|
    boolean isRecurrent();

@ -37,9 +38,6 @@ public interface IDQN<NN extends IDQN> extends NeuralNet<NN> {

    void fit(INDArray input, INDArray[] labels);
|
    void fit(INDArray input, INDArray[] labels);

    INDArray output(INDArray batch);
|
    INDArray output(Observation observation);
|

    INDArray[] outputAll(INDArray batch);
|
    INDArray[] outputAll(INDArray batch);

    NN clone();
|
    NN clone();
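Removing the two output(...) declarations is source-compatible: IDQN now inherits them from IOutputNeuralNet, and a DQN can be handed to code that only needs outputs. A sketch of that narrowing:

import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
import org.deeplearning4j.rl4j.network.dqn.IDQN;
import org.nd4j.linalg.api.ndarray.INDArray;

class OutputViewSketch {
    static INDArray qValues(IDQN<?> dqn, INDArray batch) {
        IOutputNeuralNet outputsOnly = dqn; // the narrower, output-only view
        return outputsOnly.output(batch);   // same signature IDQN used to declare itself
    }
}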
@ -1,10 +1,13 @@

package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;
|
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;

import org.deeplearning4j.rl4j.learning.sync.Transition;
|
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.support.MockDQN;
|
import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
import org.deeplearning4j.rl4j.learning.sync.support.MockTargetQNetworkSource;
|
import org.deeplearning4j.rl4j.observation.Observation;
|
import org.deeplearning4j.rl4j.observation.Observation;
|
import org.junit.Before;
import org.junit.Test;
|
import org.junit.Test;
|
import org.junit.runner.RunWith;
|
import org.mockito.Mock;
|
import org.mockito.junit.MockitoJUnitRunner;
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
|
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;

@ -13,16 +16,29 @@ import java.util.ArrayList;

import java.util.List;
|
import java.util.List;

import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
import static org.mockito.ArgumentMatchers.any;
|
import static org.mockito.Mockito.when;

|
@RunWith(MockitoJUnitRunner.class)
public class DoubleDQNTest {
|
public class DoubleDQNTest {

|
    @Mock
|
    IOutputNeuralNet qNetworkMock;

|
    @Mock
|
    IOutputNeuralNet targetQNetworkMock;

|
    @Before
|
    public void setup() {
|
        when(qNetworkMock.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]);
|
    }

    @Test
|
    @Test
    public void when_isTerminal_expect_rewardValueAtIdx0() {
|
    public void when_isTerminal_expect_rewardValueAtIdx0() {

        // Assemble
|
        // Assemble
        MockDQN qNetwork = new MockDQN();
|
        when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]);
        MockDQN targetQNetwork = new MockDQN();
|
        MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
|

        List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
|
        List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
            {
|
            {

@ -31,7 +47,7 @@ public class DoubleDQNTest {

            }
|
            }
        };
|
        };

        DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5);
|
        DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5);

        // Act
|
        // Act
        DataSet result = sut.computeTDTargets(transitions);
|
        DataSet result = sut.computeTDTargets(transitions);

@ -46,9 +62,7 @@ public class DoubleDQNTest {

    public void when_isNotTerminal_expect_rewardPlusEstimatedQValue() {
|
    public void when_isNotTerminal_expect_rewardPlusEstimatedQValue() {

        // Assemble
|
        // Assemble
        MockDQN qNetwork = new MockDQN();
|
        when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> ((INDArray)i.getArguments()[0]).mul(-1.0));
        MockDQN targetQNetwork = new MockDQN(-1.0);
|
        MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
|

        List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
|
        List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
            {
|
            {

@ -57,7 +71,7 @@ public class DoubleDQNTest {

            }
|
            }
        };
|
        };

        DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5);
|
        DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5);

        // Act
|
        // Act
        DataSet result = sut.computeTDTargets(transitions);
|
        DataSet result = sut.computeTDTargets(transitions);

@ -72,9 +86,7 @@ public class DoubleDQNTest {

    public void when_batchHasMoreThanOne_expect_everySampleEvaluated() {
|
    public void when_batchHasMoreThanOne_expect_everySampleEvaluated() {

        // Assemble
|
        // Assemble
        MockDQN qNetwork = new MockDQN();
|
        when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> ((INDArray)i.getArguments()[0]).mul(-1.0));
        MockDQN targetQNetwork = new MockDQN(-1.0);
|
        MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
|

        List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
|
        List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
            {
|
            {

@ -87,7 +99,7 @@ public class DoubleDQNTest {

            }
|
            }
        };
|
        };

        DoubleDQN sut = new DoubleDQN(targetQNetworkSource, 0.5);
|
        DoubleDQN sut = new DoubleDQN(qNetworkMock, targetQNetworkMock, 0.5);

        // Act
|
        // Act
        DataSet result = sut.computeTDTargets(transitions);
|
        DataSet result = sut.computeTDTargets(transitions);
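The thenAnswer(i -> i.getArguments()[0]) stubbing above makes the mocked network echo its input, so expected TD targets can be derived by hand from the observations alone. A self-contained illustration of the stubbing pattern:

import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import static org.junit.Assert.assertEquals;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class EchoStubSketch {
    @Test
    public void stubEchoesItsInput() {
        IOutputNeuralNet net = mock(IOutputNeuralNet.class);
        // The stub returns the call's own argument: Q(s, .) == s.
        when(net.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]);

        INDArray observation = Nd4j.create(new double[]{1.1, 2.2}).reshape(1, 2);
        assertEquals(observation, net.output(observation));
    }
}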
@ -1,10 +1,13 @@

package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;
|
package org.deeplearning4j.rl4j.learning.sync.qlearning.discrete.TDTargetAlgorithm;

import org.deeplearning4j.rl4j.learning.sync.Transition;
|
import org.deeplearning4j.rl4j.learning.sync.Transition;
import org.deeplearning4j.rl4j.learning.sync.support.MockDQN;
|
import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
import org.deeplearning4j.rl4j.learning.sync.support.MockTargetQNetworkSource;
|
import org.deeplearning4j.rl4j.observation.Observation;
|
import org.deeplearning4j.rl4j.observation.Observation;
|
import org.junit.Before;
import org.junit.Test;
|
import org.junit.Test;
|
import org.junit.runner.RunWith;
|
import org.mockito.Mock;
|
import org.mockito.junit.MockitoJUnitRunner;
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
|
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;

@ -12,17 +15,31 @@ import org.nd4j.linalg.factory.Nd4j;

import java.util.ArrayList;
|
import java.util.ArrayList;
import java.util.List;
|
import java.util.List;

import static org.junit.Assert.*;
|
import static org.junit.Assert.assertEquals;
|
import static org.mockito.ArgumentMatchers.any;
|
import static org.mockito.Mockito.when;

|
@RunWith(MockitoJUnitRunner.class)
public class StandardDQNTest {
|
public class StandardDQNTest {

|
    @Mock
|
    IOutputNeuralNet qNetworkMock;

|
    @Mock
|
    IOutputNeuralNet targetQNetworkMock;

|
    @Before
|
    public void setup() {
|
        when(qNetworkMock.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]);
|
        when(targetQNetworkMock.output(any(INDArray.class))).thenAnswer(i -> i.getArguments()[0]);
|
    }

    @Test
|
    @Test
    public void when_isTerminal_expect_rewardValueAtIdx0() {
|
    public void when_isTerminal_expect_rewardValueAtIdx0() {

        // Assemble
|
        // Assemble
        MockDQN qNetwork = new MockDQN();
|
        MockDQN targetQNetwork = new MockDQN();
|
        MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
|

        List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
|
        List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
            {
|
            {
                add(buildTransition(buildObservation(new double[]{1.1, 2.2}),
|
                add(buildTransition(buildObservation(new double[]{1.1, 2.2}),

@ -30,7 +47,7 @@ public class StandardDQNTest {

            }
|
            }
        };
|
        };

        StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5);
|
        StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5);

        // Act
|
        // Act
        DataSet result = sut.computeTDTargets(transitions);
|
        DataSet result = sut.computeTDTargets(transitions);

@ -45,10 +62,6 @@ public class StandardDQNTest {

    public void when_isNotTerminal_expect_rewardPlusEstimatedQValue() {
|
    public void when_isNotTerminal_expect_rewardPlusEstimatedQValue() {

        // Assemble
|
        // Assemble
        MockDQN qNetwork = new MockDQN();
|
        MockDQN targetQNetwork = new MockDQN();
|
        MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
|

        List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
|
        List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
            {
|
            {
                add(buildTransition(buildObservation(new double[]{1.1, 2.2}),
|
                add(buildTransition(buildObservation(new double[]{1.1, 2.2}),

@ -56,7 +69,7 @@ public class StandardDQNTest {

            }
|
            }
        };
|
        };

        StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5);
|
        StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5);

        // Act
|
        // Act
        DataSet result = sut.computeTDTargets(transitions);
|
        DataSet result = sut.computeTDTargets(transitions);

@ -71,10 +84,6 @@ public class StandardDQNTest {

    public void when_batchHasMoreThanOne_expect_everySampleEvaluated() {
|
    public void when_batchHasMoreThanOne_expect_everySampleEvaluated() {

        // Assemble
|
        // Assemble
        MockDQN qNetwork = new MockDQN();
|
        MockDQN targetQNetwork = new MockDQN();
|
        MockTargetQNetworkSource targetQNetworkSource = new MockTargetQNetworkSource(qNetwork, targetQNetwork);
|

        List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
|
        List<Transition<Integer>> transitions = new ArrayList<Transition<Integer>>() {
            {
|
            {
                add(buildTransition(buildObservation(new double[]{1.1, 2.2}),
|
                add(buildTransition(buildObservation(new double[]{1.1, 2.2}),

@ -86,7 +95,7 @@ public class StandardDQNTest {

            }
|
            }
        };
|
        };

        StandardDQN sut = new StandardDQN(targetQNetworkSource, 0.5);
|
        StandardDQN sut = new StandardDQN(qNetworkMock, targetQNetworkMock, 0.5);

        // Act
|
        // Act
        DataSet result = sut.computeTDTargets(transitions);
|
        DataSet result = sut.computeTDTargets(transitions);
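With both stubs echoing their input, the values these tests assert follow directly from the transition data. A hand-worked example with hypothetical numbers (not the actual fixtures): for a non-terminal transition, y = r + gamma * max(Q_target(s', .)) = r + gamma * max(s').

class StandardDqnArithmeticSketch {
    public static void main(String[] args) {
        double gamma = 0.5;
        double reward = 1.0;
        double[] nextObservation = {11.0, 22.0}; // Q_target(s', .) == s' under the echo stub
        double y = reward + gamma * Math.max(nextObservation[0], nextObservation[1]);
        System.out.println(y); // 12.0
    }
}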
@ -1,26 +0,0 @@

package org.deeplearning4j.rl4j.learning.sync.support;

import org.deeplearning4j.rl4j.learning.sync.qlearning.TargetQNetworkSource;
import org.deeplearning4j.rl4j.network.dqn.IDQN;

public class MockTargetQNetworkSource implements TargetQNetworkSource {

    private final IDQN qNetwork;
    private final IDQN targetQNetwork;

    public MockTargetQNetworkSource(IDQN qNetwork, IDQN targetQNetwork) {
        this.qNetwork = qNetwork;
        this.targetQNetwork = targetQNetwork;
    }

    @Override
    public IDQN getTargetQNetwork() {
        return targetQNetwork;
    }

    @Override
    public IDQN getQNetwork() {
        return qNetwork;
    }
}
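With the TD-target algorithms taking IOutputNeuralNet directly, this hand-rolled double has no remaining callers; the @Mock fields used in the tests above are the drop-in replacement. A sketch of the equivalent setup:

import org.deeplearning4j.rl4j.network.IOutputNeuralNet;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;

import static org.junit.Assert.assertNotNull;

@RunWith(MockitoJUnitRunner.class)
public class MockReplacementSketch {
    @Mock
    IOutputNeuralNet qNetworkMock;       // stands in for getQNetwork()

    @Mock
    IOutputNeuralNet targetQNetworkMock; // stands in for getTargetQNetwork()

    @Test
    public void mocksAreInjected() {
        assertNotNull(qNetworkMock);
        assertNotNull(targetQNetworkMock);
    }
}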