python4j-numpy (#475)

* 'initial'

* 'impl'

* tests

* <T>

* more tests

* scalar fixes

* lazy setup jobs

* more tests

* multithreading wip

* multithreading fix

* bytebuffer working

* nits

* inplace exec fixes

* attempt linux cpu fix

* rollback

* list fixes

* disable gc

* log

* bump jcpp + fixes

* #8985 GradientSharingTrainingTest ignore for logged issue

Signed-off-by: Alex Black <blacka101@gmail.com>

* memview fixes

* fix?

Co-authored-by: Alex Black <blacka101@gmail.com>
master
Fariz Rahman 2020-06-16 05:43:10 +04:00 committed by GitHub
parent bb0492f47d
commit 9ca679e080
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 1828 additions and 52 deletions

View File

@ -141,7 +141,7 @@ public class GradientSharingTrainingTest extends BaseSparkTest {
SparkComputationGraph sparkNet = new SparkComputationGraph(sc, conf, tm);
sparkNet.setCollectTrainingStats(tm.getIsCollectTrainingStats());
System.out.println(Arrays.toString(sparkNet.getNetwork().params().get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
// System.out.println(Arrays.toString(sparkNet.getNetwork().params().get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
File f = testDir.newFolder();
DataSetIterator iter = new MnistDataSetIterator(16, true, 12345);
int count = 0;
@ -208,10 +208,10 @@ public class GradientSharingTrainingTest extends BaseSparkTest {
}
INDArray paramsAfter = after.params();
System.out.println(Arrays.toString(paramsBefore.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
System.out.println(Arrays.toString(paramsAfter.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
System.out.println(Arrays.toString(
Transforms.abs(paramsAfter.sub(paramsBefore)).get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
// System.out.println(Arrays.toString(paramsBefore.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
// System.out.println(Arrays.toString(paramsAfter.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
// System.out.println(Arrays.toString(
// Transforms.abs(paramsAfter.sub(paramsBefore)).get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
assertNotEquals(paramsBefore, paramsAfter);
@ -235,7 +235,7 @@ public class GradientSharingTrainingTest extends BaseSparkTest {
}
@Test
@Test @Ignore //AB https://github.com/eclipse/deeplearning4j/issues/8985
public void differentNetsTrainingTest() throws Exception {
int batch = 3;

View File

@ -41,10 +41,14 @@
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<version>1.6.6</version>
</dependency> <dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
<version>${logback.version}</version>
<scope>test</scope>
<scope>test</scope>
</dependency>
<dependency>
<groupId>junit</groupId>
@ -62,5 +66,10 @@
<artifactId>jsr305</artifactId>
<version>3.0.2</version>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<version>1.6.6</version>
</dependency>
</dependencies>
</project>

View File

@ -39,6 +39,5 @@
<artifactId>cpython-platform</artifactId>
<version>${cpython-platform.version}</version>
</dependency>
</dependencies>
</project>

View File

@ -19,8 +19,10 @@ package org.eclipse.python4j;
import javax.lang.model.SourceVersion;
import java.io.Closeable;
import java.util.HashSet;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicBoolean;
/**
@ -46,6 +48,31 @@ public class PythonContextManager {
init();
}
public static class Context implements Closeable{
private final String name;
private final String previous;
private final boolean temp;
public Context(){
name = "temp_" + UUID.randomUUID().toString().replace("-", "_");
temp = true;
previous = getCurrentContext();
setContext(name);
}
public Context(String name){
this.name = name;
temp = false;
previous = getCurrentContext();
setContext(name);
}
@Override
public void close(){
setContext(previous);
if (temp) deleteContext(name);
}
}
private static void init() {
if (init.get()) return;
new PythonExecutioner();
@ -190,6 +217,7 @@ public class PythonContextManager {
setContext(tempContext);
deleteContext(currContext);
setContext(currContext);
deleteContext(tempContext);
}
/**

View File

@ -25,6 +25,7 @@ import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
@ -42,7 +43,6 @@ public class PythonExecutioner {
private final static String DEFAULT_PYTHON_PATH_PROPERTY = "org.eclipse.python4j.path";
private final static String JAVACPP_PYTHON_APPEND_TYPE = "org.eclipse.python4j.path.append";
private final static String DEFAULT_APPEND_TYPE = "before";
static {
init();
}
@ -55,6 +55,11 @@ public class PythonExecutioner {
initPythonPath();
PyEval_InitThreads();
Py_InitializeEx(0);
for (PythonType type: PythonTypes.get()){
type.init();
}
// Constructors of custom types may contain initialization code that should
// run on the main the thread.
}
/**
@ -110,6 +115,8 @@ public class PythonExecutioner {
getVariables(Arrays.asList(pyVars));
}
/**
* Gets the variable with the given name from the interpreter.
*
@ -205,9 +212,9 @@ public class PythonExecutioner {
*
* @return
*/
public static List<PythonVariable> getAllVariables() {
public static PythonVariables getAllVariables() {
PythonGIL.assertThreadSafe();
List<PythonVariable> ret = new ArrayList<>();
PythonVariables ret = new PythonVariables();
PyObject main = PyImport_ImportModule("__main__");
PyObject globals = PyModule_GetDict(main);
PyObject keys = PyDict_Keys(globals);
@ -259,7 +266,7 @@ public class PythonExecutioner {
* @param inputs
* @return
*/
public static List<PythonVariable> execAndReturnAllVariables(String code, List<PythonVariable> inputs) {
public static PythonVariables execAndReturnAllVariables(String code, List<PythonVariable> inputs) {
setVariables(inputs);
simpleExec(getWrappedCode(code));
return getAllVariables();
@ -271,7 +278,7 @@ public class PythonExecutioner {
* @param code
* @return
*/
public static List<PythonVariable> execAndReturnAllVariables(String code) {
public static PythonVariables execAndReturnAllVariables(String code) {
simpleExec(getWrappedCode(code));
return getAllVariables();
}
@ -279,25 +286,22 @@ public class PythonExecutioner {
private static synchronized void initPythonPath() {
try {
String path = System.getProperty(DEFAULT_PYTHON_PATH_PROPERTY);
List<File> packagesList = new ArrayList<>();
packagesList.addAll(Arrays.asList(cachePackages()));
for (PythonType type: PythonTypes.get()){
packagesList.addAll(Arrays.asList(type.packages()));
}
//// TODO: fix in javacpp
packagesList.add(new File(python.cachePackage(), "site-packages"));
File[] packages = packagesList.toArray(new File[0]);
if (path == null) {
File[] packages = cachePackages();
//// TODO: fix in javacpp
File sitePackagesWindows = new File(python.cachePackage(), "site-packages");
File[] packages2 = new File[packages.length + 1];
for (int i = 0; i < packages.length; i++) {
//System.out.println(packages[i].getAbsolutePath());
packages2[i] = packages[i];
}
packages2[packages.length] = sitePackagesWindows;
//System.out.println(sitePackagesWindows.getAbsolutePath());
packages = packages2;
//////////
Py_SetPath(packages);
} else {
StringBuffer sb = new StringBuffer();
File[] packages = cachePackages();
JavaCppPathType pathAppendValue = JavaCppPathType.valueOf(System.getProperty(JAVACPP_PYTHON_APPEND_TYPE, DEFAULT_APPEND_TYPE).toUpperCase());
switch (pathAppendValue) {
case BEFORE:

View File

@ -90,4 +90,8 @@ public class PythonGIL implements AutoCloseable {
PyEval_SaveThread();
PyEval_RestoreThread(mainThreadState);
}
public static boolean locked(){
return acquired.get();
}
}

View File

@ -20,25 +20,29 @@ package org.eclipse.python4j;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import javax.annotation.Nonnull;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
@Data
@NoArgsConstructor
/**
* PythonJob is the right abstraction for executing multiple python scripts
* in a multi thread stateful environment. The setup-and-run mode allows your
* "setup" code (imports, model loading etc) to be executed only once.
*/
@Data
@Slf4j
public class PythonJob {
private String code;
private String name;
private String context;
private boolean setupRunMode;
private final boolean setupRunMode;
private PythonObject runF;
private final AtomicBoolean setupDone = new AtomicBoolean(false);
static {
new PythonExecutioner();
@ -63,7 +67,6 @@ public class PythonJob {
if (PythonContextManager.hasContext(context)) {
throw new PythonException("Unable to create python job " + name + ". Context " + context + " already exists!");
}
if (setupRunMode) setup();
}
@ -71,17 +74,18 @@ public class PythonJob {
* Clears all variables in current context and calls setup()
*/
public void clearState(){
String context = this.context;
PythonContextManager.setContext("main");
PythonContextManager.deleteContext(context);
this.context = context;
PythonContextManager.setContext(this.context);
PythonContextManager.reset();
setupDone.set(false);
setup();
}
public void setup(){
if (setupDone.get()) return;
try (PythonGIL gil = PythonGIL.lock()) {
PythonContextManager.setContext(context);
PythonObject runF = PythonExecutioner.getVariable("run");
if (runF == null || runF.isNone() || !Python.callable(runF)) {
PythonExecutioner.exec(code);
runF = PythonExecutioner.getVariable("run");
@ -98,10 +102,12 @@ public class PythonJob {
if (!setupF.isNone()) {
setupF.call();
}
setupDone.set(true);
}
}
public void exec(List<PythonVariable> inputs, List<PythonVariable> outputs) {
if (setupRunMode)setup();
try (PythonGIL gil = PythonGIL.lock()) {
try (PythonGC _ = PythonGC.watch()) {
PythonContextManager.setContext(context);
@ -139,6 +145,7 @@ public class PythonJob {
}
public List<PythonVariable> execAndReturnAllVariables(List<PythonVariable> inputs){
if (setupRunMode)setup();
try (PythonGIL gil = PythonGIL.lock()) {
try (PythonGC _ = PythonGC.watch()) {
PythonContextManager.setContext(context);

View File

@ -147,7 +147,8 @@ public class PythonObject {
}
PythonObject pyArgs;
PythonObject pyKwargs;
if (args == null) {
if (args == null || args.isEmpty()) {
pyArgs = new PythonObject(PyTuple_New(0));
} else {
PythonObject argsList = PythonTypes.convert(args);
@ -158,6 +159,7 @@ public class PythonObject {
} else {
pyKwargs = PythonTypes.convert(kwargs);
}
PythonObject ret = new PythonObject(
PyObject_Call(
nativePythonObject,
@ -165,7 +167,9 @@ public class PythonObject {
pyKwargs == null ? null : pyKwargs.nativePythonObject
)
);
PythonGC.keep(ret);
return ret;
}
@ -241,4 +245,48 @@ public class PythonObject {
PyObject_SetItem(nativePythonObject, key.nativePythonObject, value.nativePythonObject);
}
public PythonObject abs(){
return new PythonObject(PyNumber_Absolute(nativePythonObject));
}
public PythonObject add(PythonObject pythonObject){
return new PythonObject(PyNumber_Add(nativePythonObject, pythonObject.nativePythonObject));
}
public PythonObject sub(PythonObject pythonObject){
return new PythonObject(PyNumber_Subtract(nativePythonObject, pythonObject.nativePythonObject));
}
public PythonObject mod(PythonObject pythonObject){
return new PythonObject(PyNumber_Divmod(nativePythonObject, pythonObject.nativePythonObject));
}
public PythonObject mul(PythonObject pythonObject){
return new PythonObject(PyNumber_Multiply(nativePythonObject, pythonObject.nativePythonObject));
}
public PythonObject trueDiv(PythonObject pythonObject){
return new PythonObject(PyNumber_TrueDivide(nativePythonObject, pythonObject.nativePythonObject));
}
public PythonObject floorDiv(PythonObject pythonObject){
return new PythonObject(PyNumber_FloorDivide(nativePythonObject, pythonObject.nativePythonObject));
}
public PythonObject matMul(PythonObject pythonObject){
return new PythonObject(PyNumber_MatrixMultiply(nativePythonObject, pythonObject.nativePythonObject));
}
public void addi(PythonObject pythonObject){
PyNumber_InPlaceAdd(nativePythonObject, pythonObject.nativePythonObject);
}
public void subi(PythonObject pythonObject){
PyNumber_InPlaceSubtract(nativePythonObject, pythonObject.nativePythonObject);
}
public void muli(PythonObject pythonObject){
PyNumber_InPlaceMultiply(nativePythonObject, pythonObject.nativePythonObject);
}
public void trueDivi(PythonObject pythonObject){
PyNumber_InPlaceTrueDivide(nativePythonObject, pythonObject.nativePythonObject);
}
public void floorDivi(PythonObject pythonObject){
PyNumber_InPlaceFloorDivide(nativePythonObject, pythonObject.nativePythonObject);
}
public void matMuli(PythonObject pythonObject){
PyNumber_InPlaceMatrixMultiply(nativePythonObject, pythonObject.nativePythonObject);
}
}

View File

@ -0,0 +1,127 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.eclipse.python4j;
import org.apache.commons.io.IOUtils;
import org.bytedeco.javacpp.Loader;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
public class PythonProcess {
private static String pythonExecutable = Loader.load(org.bytedeco.cpython.python.class);
public static String runAndReturn(String... arguments)throws IOException, InterruptedException{
String[] allArgs = new String[arguments.length + 1];
for (int i = 0; i < arguments.length; i++){
allArgs[i + 1] = arguments[i];
}
allArgs[0] = pythonExecutable;
ProcessBuilder pb = new ProcessBuilder(allArgs);
Process process = pb.start();
String out = IOUtils.toString(process.getInputStream(), StandardCharsets.UTF_8);
process.waitFor();
return out;
}
public static void run(String... arguments)throws IOException, InterruptedException{
String[] allArgs = new String[arguments.length + 1];
for (int i = 0; i < arguments.length; i++){
allArgs[i + 1] = arguments[i];
}
allArgs[0] = pythonExecutable;
ProcessBuilder pb = new ProcessBuilder(allArgs);
pb.inheritIO().start().waitFor();
}
public static void pipInstall(String packageName) throws PythonException{
try{
run("-m", "pip", "install", packageName);
}catch(Exception e){
throw new PythonException("Error installing package " + packageName, e);
}
}
public static void pipInstall(String packageName, String version){
pipInstall(packageName + "==" + version);
}
public static void pipUninstall(String packageName) throws PythonException{
try{
run("-m", "pip", "uninstall", packageName);
}catch(Exception e){
throw new PythonException("Error uninstalling package " + packageName, e);
}
}
public static void pipInstallFromGit(String gitRepoUrl){
if (!gitRepoUrl.contains("://")){
gitRepoUrl = "git://" + gitRepoUrl;
}
try{
run("-m", "pip", "install", "git+", gitRepoUrl);
}catch(Exception e){
throw new PythonException("Error installing package from " + gitRepoUrl, e);
}
}
public static String getPackageVersion(String packageName){
String out;
try{
out = runAndReturn("-m", "pip", "show", packageName);
} catch (Exception e){
throw new PythonException("Error finding version for package " + packageName, e);
}
if (!out.contains("Version: ")){
throw new PythonException("Can't find package " + packageName);
}
String pkgVersion = out.split("Version: ")[1].split(System.lineSeparator())[0];
return pkgVersion;
}
public static boolean isPackageInstalled(String packageName){
try{
String out = runAndReturn("-m", "pip", "show", packageName);
return !out.isEmpty();
}catch (Exception e){
throw new PythonException("Error checking if package is installed: " +packageName, e);
}
}
public static void pipInstallFromRequirementsTxt(String path){
try{
run("-m", "pip", "install","-r", path);
}catch (Exception e){
throw new PythonException("Error installing packages from " + path, e);
}
}
public static void pipInstallFromSetupScript(String path, boolean inplace){
try{
run(path, inplace?"develop":"install");
}catch (Exception e){
throw new PythonException("Error installing package from " + path, e);
}
}
}

View File

@ -17,6 +17,8 @@
package org.eclipse.python4j;
import java.io.File;
public abstract class PythonType<T> {
private final String name;
@ -43,5 +45,25 @@ public abstract class PythonType<T> {
return name;
}
@Override
public boolean equals(Object obj){
if (!(obj instanceof PythonType)){
return false;
}
PythonType other = (PythonType)obj;
return this.getClass().equals(other.getClass()) && this.name.equals(other.name);
}
public PythonObject pythonType(){
return null;
}
public File[] packages(){
return new File[0];
}
public void init(){ //not to be called from constructor
}
}

View File

@ -18,7 +18,16 @@ package org.eclipse.python4j;
import org.bytedeco.cpython.PyObject;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.Loader;
import org.bytedeco.javacpp.Pointer;
import sun.misc.Unsafe;
import sun.nio.ch.DirectBuffer;
import java.lang.reflect.Field;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.*;
import static org.bytedeco.cpython.global.python.*;
@ -28,7 +37,7 @@ public class PythonTypes {
private static List<PythonType> getPrimitiveTypes() {
return Arrays.<PythonType>asList(STR, INT, FLOAT, BOOL);
return Arrays.<PythonType>asList(STR, INT, FLOAT, BOOL, MEMORYVIEW);
}
private static List<PythonType> getCollectionTypes() {
@ -36,8 +45,13 @@ public class PythonTypes {
}
private static List<PythonType> getExternalTypes() {
//TODO service loader
return new ArrayList<>();
List<PythonType> ret = new ArrayList<>();
ServiceLoader<PythonType> sl = ServiceLoader.load(PythonType.class);
Iterator<PythonType> iter = sl.iterator();
while (iter.hasNext()) {
ret.add(iter.next());
}
return ret;
}
public static List<PythonType> get() {
@ -48,15 +62,17 @@ public class PythonTypes {
return ret;
}
public static PythonType get(String name) {
public static <T> PythonType<T> get(String name) {
for (PythonType pt : get()) {
if (pt.getName().equals(name)) { // TODO use map instead?
return pt;
}
}
throw new PythonException("Unknown python type: " + name);
}
public static PythonType getPythonTypeForJavaObject(Object javaObject) {
for (PythonType pt : get()) {
if (pt.accepts(javaObject)) {
@ -66,7 +82,7 @@ public class PythonTypes {
throw new PythonException("Unable to find python type for java type: " + javaObject.getClass());
}
public static PythonType getPythonTypeForPythonObject(PythonObject pythonObject) {
public static <T> PythonType<T> getPythonTypeForPythonObject(PythonObject pythonObject) {
PyObject pyType = PyObject_Type(pythonObject.getNativePythonObject());
try {
String pyTypeStr = PythonTypes.STR.toJava(new PythonObject(pyType, false));
@ -75,6 +91,14 @@ public class PythonTypes {
String pyTypeStr2 = "<class '" + pt.getName() + "'>";
if (pyTypeStr.equals(pyTypeStr2)) {
return pt;
} else {
try (PythonGC gc = PythonGC.watch()) {
PythonObject pyType2 = pt.pythonType();
if (pyType2 != null && Python.isinstance(pythonObject, pyType2)) {
return pt;
}
}
}
}
throw new PythonException("Unable to find converter for python object of type " + pyTypeStr);
@ -212,12 +236,49 @@ public class PythonTypes {
public static final PythonType<List> LIST = new PythonType<List>("list", List.class) {
@Override
public boolean accepts(Object javaObject) {
return (javaObject instanceof List || javaObject.getClass().isArray());
}
@Override
public List adapt(Object javaObject) {
if (javaObject instanceof List) {
return (List) javaObject;
} else if (javaObject instanceof Object[]) {
return Arrays.asList((Object[]) javaObject);
} else if (javaObject.getClass().isArray()) {
List<Object> ret = new ArrayList<>();
if (javaObject instanceof Object[]) {
Object[] arr = (Object[]) javaObject;
return new ArrayList<>(Arrays.asList(arr));
} else if (javaObject instanceof short[]) {
short[] arr = (short[]) javaObject;
for (short x : arr) ret.add(x);
return ret;
} else if (javaObject instanceof int[]) {
int[] arr = (int[]) javaObject;
for (int x : arr) ret.add(x);
return ret;
} else if (javaObject instanceof long[]) {
long[] arr = (long[]) javaObject;
for (long x : arr) ret.add(x);
return ret;
} else if (javaObject instanceof float[]) {
float[] arr = (float[]) javaObject;
for (float x : arr) ret.add(x);
return ret;
} else if (javaObject instanceof double[]) {
double[] arr = (double[]) javaObject;
for (double x : arr) ret.add(x);
return ret;
} else if (javaObject instanceof boolean[]) {
boolean[] arr = (boolean[]) javaObject;
for (boolean x : arr) ret.add(x);
return ret;
} else {
throw new PythonException("Unsupported array type: " + javaObject.getClass().toString());
}
} else {
throw new PythonException("Cannot cast object of type " + javaObject.getClass().getName() + " to List");
}
@ -327,7 +388,13 @@ public class PythonTypes {
}
Object v = javaObject.get(k);
PythonObject pyVal;
pyVal = PythonTypes.convert(v);
if (v instanceof PythonObject) {
pyVal = (PythonObject) v;
} else if (v instanceof PyObject) {
pyVal = new PythonObject((PyObject) v);
} else {
pyVal = PythonTypes.convert(v);
}
int errCode = PyDict_SetItem(pyDict, pyKey.getNativePythonObject(), pyVal.getNativePythonObject());
if (errCode != 0) {
String keyStr = pyKey.toString();
@ -341,4 +408,85 @@ public class PythonTypes {
return new PythonObject(pyDict);
}
};
public static final PythonType<BytePointer> MEMORYVIEW = new PythonType<BytePointer>("memoryview", BytePointer.class) {
@Override
public BytePointer toJava(PythonObject pythonObject) {
try (PythonGC gc = PythonGC.watch()) {
if (!(Python.isinstance(pythonObject, Python.memoryviewType()))) {
throw new PythonException("Expected memoryview. Received: " + pythonObject);
}
PythonObject pySize = Python.len(pythonObject);
PythonObject ctypes = Python.importModule("ctypes");
PythonObject charType = ctypes.attr("c_char");
PythonObject charArrayType = new PythonObject(PyNumber_Multiply(charType.getNativePythonObject(),
pySize.getNativePythonObject()));
PythonObject fromBuffer = charArrayType.attr("from_buffer");
if (pythonObject.attr("readonly").toBoolean()) {
pythonObject = Python.bytearray(pythonObject);
}
PythonObject arr = fromBuffer.call(pythonObject);
PythonObject cast = ctypes.attr("cast");
PythonObject voidPtrType = ctypes.attr("c_void_p");
PythonObject voidPtr = cast.call(arr, voidPtrType);
long address = voidPtr.attr("value").toLong();
long size = pySize.toLong();
try {
Field addressField = Buffer.class.getDeclaredField("address");
addressField.setAccessible(true);
Field capacityField = Buffer.class.getDeclaredField("capacity");
capacityField.setAccessible(true);
ByteBuffer buff = ByteBuffer.allocateDirect(0).order(ByteOrder.nativeOrder());
addressField.setLong(buff, address);
capacityField.setInt(buff, (int) size);
BytePointer ret = new BytePointer(buff);
ret.limit(size);
return ret;
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}
@Override
public PythonObject toPython(BytePointer javaObject) {
long address = javaObject.address();
long size = javaObject.limit();
try (PythonGC gc = PythonGC.watch()) {
PythonObject ctypes = Python.importModule("ctypes");
PythonObject charType = ctypes.attr("c_char");
PythonObject pySize = new PythonObject(size);
PythonObject charArrayType = new PythonObject(PyNumber_Multiply(charType.getNativePythonObject(),
pySize.getNativePythonObject()));
PythonObject fromAddress = charArrayType.attr("from_address");
PythonObject arr = fromAddress.call(new PythonObject(address));
PythonObject memoryView = Python.memoryview(arr).attr("cast").call("b");
PythonGC.keep(memoryView);
return memoryView;
}
}
@Override
public boolean accepts(Object javaObject) {
return javaObject instanceof Pointer || javaObject instanceof DirectBuffer;
}
@Override
public BytePointer adapt(Object javaObject) {
if (javaObject instanceof BytePointer) {
return (BytePointer) javaObject;
} else if (javaObject instanceof Pointer) {
return new BytePointer((Pointer) javaObject);
} else if (javaObject instanceof DirectBuffer) {
return new BytePointer((ByteBuffer) javaObject);
} else {
throw new PythonException("Cannot cast object of type " + javaObject.getClass().getName() + " to BytePointer");
}
}
};
}

View File

@ -0,0 +1,47 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.eclipse.python4j;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
* Some syntax sugar for lookup by name
*/
public class PythonVariables extends ArrayList<PythonVariable> {
public PythonVariable get(String variableName) {
for (PythonVariable pyVar: this){
if (pyVar.getName().equals(variableName)){
return pyVar;
}
}
return null;
}
public <T> boolean add(String variableName, PythonType<T> variableType, Object value){
return this.add(new PythonVariable<>(variableName, variableType, value));
}
public PythonVariables(PythonVariable... variables){
this(Arrays.asList(variables));
}
public PythonVariables(List<PythonVariable> list){
super();
addAll(list);
}
}

View File

@ -0,0 +1,113 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.Loader;
import org.eclipse.python4j.*;
import org.junit.Assert;
import org.junit.Test;
import sun.nio.ch.DirectBuffer;
import javax.annotation.concurrent.NotThreadSafe;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.util.*;
@NotThreadSafe
public class PythonBufferTest {
@Test
public void testBuffer() {
ByteBuffer buff = ByteBuffer.allocateDirect(3);
buff.put((byte) 97);
buff.put((byte) 98);
buff.put((byte) 99);
buff.rewind();
BytePointer bp = new BytePointer(buff);
List<PythonVariable> inputs = new ArrayList<>();
inputs.add(new PythonVariable<>("buff", PythonTypes.MEMORYVIEW, buff));
List<PythonVariable> outputs = new ArrayList<>();
outputs.add(new PythonVariable<>("s1", PythonTypes.STR));
outputs.add(new PythonVariable<>("s2", PythonTypes.STR));
String code = "s1 = ''.join(chr(c) for c in buff)\nbuff[2] += 2\ns2 = ''.join(chr(c) for c in buff)";
PythonExecutioner.exec(code, inputs, outputs);
Assert.assertEquals("abc", outputs.get(0).getValue());
Assert.assertEquals("abe", outputs.get(1).getValue());
Assert.assertEquals(101, buff.get(2));
}
@Test
public void testBuffer2() {
ByteBuffer buff = ByteBuffer.allocateDirect(3);
buff.put((byte) 97);
buff.put((byte) 98);
buff.put((byte) 99);
buff.rewind();
BytePointer bp = new BytePointer(buff);
List<PythonVariable> inputs = new ArrayList<>();
inputs.add(new PythonVariable<>("buff", PythonTypes.MEMORYVIEW, bp));
List<PythonVariable> outputs = new ArrayList<>();
outputs.add(new PythonVariable<>("s1", PythonTypes.STR));
outputs.add(new PythonVariable<>("s2", PythonTypes.STR));
String code = "s1 = ''.join(chr(c) for c in buff)\nbuff[2] += 2\ns2 = ''.join(chr(c) for c in buff)";
PythonExecutioner.exec(code, inputs, outputs);
Assert.assertEquals("abc", outputs.get(0).getValue());
Assert.assertEquals("abe", outputs.get(1).getValue());
Assert.assertEquals(101, buff.get(2));
}
@Test
public void testBuffer3() {
ByteBuffer buff = ByteBuffer.allocateDirect(3);
buff.put((byte) 97);
buff.put((byte) 98);
buff.put((byte) 99);
buff.rewind();
BytePointer bp = new BytePointer(buff);
List<PythonVariable> inputs = new ArrayList<>();
inputs.add(new PythonVariable<>("buff", PythonTypes.MEMORYVIEW, bp));
List<PythonVariable> outputs = new ArrayList<>();
outputs.add(new PythonVariable<>("s1", PythonTypes.STR));
outputs.add(new PythonVariable<>("s2", PythonTypes.STR));
outputs.add(new PythonVariable<>("buff2", PythonTypes.MEMORYVIEW));
String code = "s1 = ''.join(chr(c) for c in buff)\nbuff[2] += 2\ns2 = ''.join(chr(c) for c in buff)\nbuff2=buff[1:]";
PythonExecutioner.exec(code, inputs, outputs);
Assert.assertEquals("abc", outputs.get(0).getValue());
Assert.assertEquals("abe", outputs.get(1).getValue());
Assert.assertEquals(101, buff.get(2));
BytePointer outBuffer = (BytePointer) outputs.get(2).getValue();
Assert.assertEquals(2, outBuffer.capacity());
Assert.assertEquals((byte)98, outBuffer.get(0));
Assert.assertEquals((byte)101, outBuffer.get(1));
}
}

View File

@ -49,6 +49,6 @@ public class PythonGCTest {
PythonObject pyObjCount3 = Python.len(getObjects.call());
long objCount3 = pyObjCount3.toLong();
diff = objCount3 - objCount2;
Assert.assertEquals(2, diff);// 2 objects created during function call
Assert.assertTrue(diff <= 2);// 2 objects created during function call
}
}

View File

@ -30,7 +30,7 @@ import static org.junit.Assert.assertEquals;
public class PythonJobTest {
@Test
public void testPythonJobBasic() throws Exception{
public void testPythonJobBasic(){
PythonContextManager.deleteNonMainContexts();
String code = "c = a + b";
@ -65,7 +65,7 @@ public class PythonJobTest {
}
@Test
public void testPythonJobReturnAllVariables()throws Exception{
public void testPythonJobReturnAllVariables(){
PythonContextManager.deleteNonMainContexts();
String code = "c = a + b";
@ -101,7 +101,7 @@ public class PythonJobTest {
@Test
public void testMultiplePythonJobsParallel()throws Exception{
public void testMultiplePythonJobsParallel(){
PythonContextManager.deleteNonMainContexts();
String code1 = "c = a + b";
PythonJob job1 = new PythonJob("job1", code1, false);
@ -150,7 +150,7 @@ public class PythonJobTest {
@Test
public void testPythonJobSetupRun()throws Exception{
public void testPythonJobSetupRun(){
PythonContextManager.deleteNonMainContexts();
String code = "five=None\n" +
@ -189,7 +189,7 @@ public class PythonJobTest {
}
@Test
public void testPythonJobSetupRunAndReturnAllVariables()throws Exception{
public void testPythonJobSetupRunAndReturnAllVariables(){
PythonContextManager.deleteNonMainContexts();
String code = "five=None\n" +
"c=None\n"+
@ -225,7 +225,7 @@ public class PythonJobTest {
}
@Test
public void testMultiplePythonJobsSetupRunParallel()throws Exception{
public void testMultiplePythonJobsSetupRunParallel(){
PythonContextManager.deleteNonMainContexts();
String code1 = "five=None\n" +

View File

@ -28,15 +28,50 @@
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.eclipse</groupId>
<artifactId>python4j-core</artifactId>
<version>1.0.0-SNAPSHOT</version>
</dependency>
</dependencies>
<profiles>
<profile>
<id>test-nd4j-native</id>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>dl4j-test-resources</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</profile>
<profile>
<id>test-nd4j-cuda-10.2</id>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-10.1</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>dl4j-test-resources</artifactId>
<version>${nd4j.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</profile>
</profiles>
</project>

View File

@ -0,0 +1,303 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.eclipse.python4j;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.ArrayUtils;
import org.bytedeco.cpython.PyObject;
import org.bytedeco.cpython.PyTypeObject;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.SizeTPointer;
import org.bytedeco.numpy.PyArrayObject;
import org.bytedeco.numpy.global.numpy;
import org.nd4j.linalg.api.buffer.BaseDataBuffer;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.MemoryWorkspaceManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.nativeblas.NativeOps;
import org.nd4j.nativeblas.NativeOpsHolder;
import java.io.File;
import java.util.*;
import java.util.concurrent.atomic.AtomicBoolean;
import static org.bytedeco.cpython.global.python.*;
import static org.bytedeco.cpython.global.python.Py_DecRef;
import static org.bytedeco.numpy.global.numpy.*;
import static org.bytedeco.numpy.global.numpy.NPY_ARRAY_CARRAY;
import static org.bytedeco.numpy.global.numpy.PyArray_Type;
@Slf4j
public class NumpyArray extends PythonType<INDArray> {
public static final NumpyArray INSTANCE;
private static final AtomicBoolean init = new AtomicBoolean(false);
private static final Map<String, DataBuffer> cache = new HashMap<>();
static {
new PythonExecutioner();
INSTANCE = new NumpyArray();
}
@Override
public File[] packages(){
try{
return new File[]{numpy.cachePackage()};
}catch(Exception e){
throw new PythonException(e);
}
}
public synchronized void init() {
if (init.get()) return;
init.set(true);
if (PythonGIL.locked()) {
throw new PythonException("Can not initialize numpy - GIL already acquired.");
}
int err = numpy._import_array();
if (err < 0){
System.out.println("Numpy import failed!");
throw new PythonException("Numpy import failed!");
}
}
public NumpyArray() {
super("numpy.ndarray", INDArray.class);
}
@Override
public INDArray toJava(PythonObject pythonObject) {
log.info("Converting PythonObject to INDArray...");
PyObject np = PyImport_ImportModule("numpy");
PyObject ndarray = PyObject_GetAttrString(np, "ndarray");
if (PyObject_IsInstance(pythonObject.getNativePythonObject(), ndarray) != 1) {
Py_DecRef(ndarray);
Py_DecRef(np);
throw new PythonException("Object is not a numpy array! Use Python.ndarray() to convert object to a numpy array.");
}
Py_DecRef(ndarray);
Py_DecRef(np);
PyArrayObject npArr = new PyArrayObject(pythonObject.getNativePythonObject());
long[] shape = new long[PyArray_NDIM(npArr)];
SizeTPointer shapePtr = PyArray_SHAPE(npArr);
if (shapePtr != null)
shapePtr.get(shape, 0, shape.length);
long[] strides = new long[shape.length];
SizeTPointer stridesPtr = PyArray_STRIDES(npArr);
if (stridesPtr != null)
stridesPtr.get(strides, 0, strides.length);
int npdtype = PyArray_TYPE(npArr);
DataType dtype;
switch (npdtype) {
case NPY_DOUBLE:
dtype = DataType.DOUBLE;
break;
case NPY_FLOAT:
dtype = DataType.FLOAT;
break;
case NPY_SHORT:
dtype = DataType.SHORT;
break;
case NPY_INT:
dtype = DataType.INT32;
break;
case NPY_LONG:
dtype = DataType.INT64;
break;
case NPY_UINT:
dtype = DataType.UINT32;
break;
case NPY_BYTE:
dtype = DataType.INT8;
break;
case NPY_UBYTE:
dtype = DataType.UINT8;
break;
case NPY_BOOL:
dtype = DataType.BOOL;
break;
case NPY_HALF:
dtype = DataType.FLOAT16;
break;
case NPY_LONGLONG:
dtype = DataType.INT64;
break;
case NPY_USHORT:
dtype = DataType.UINT16;
break;
case NPY_ULONG:
case NPY_ULONGLONG:
dtype = DataType.UINT64;
break;
default:
throw new PythonException("Unsupported array data type: " + npdtype);
}
long size = 1;
for (int i = 0; i < shape.length; size *= shape[i++]) ;
INDArray ret;
long address = PyArray_DATA(npArr).address();
String key = address + "_" + size + "_" + dtype;
DataBuffer buff = cache.get(key);
if (buff == null) {
try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
Pointer ptr = NativeOpsHolder.getInstance().getDeviceNativeOps().pointerForAddress(address);
ptr = ptr.limit(size);
ptr = ptr.capacity(size);
buff = Nd4j.createBuffer(ptr, size, dtype);
cache.put(key, buff);
}
}
int elemSize = buff.getElementSize();
long[] nd4jStrides = new long[strides.length];
for (int i = 0; i < strides.length; i++) {
nd4jStrides[i] = strides[i] / elemSize;
}
ret = Nd4j.create(buff, shape, nd4jStrides, 0, Shape.getOrder(shape, nd4jStrides, 1), dtype);
Nd4j.getAffinityManager().tagLocation(ret, AffinityManager.Location.HOST);
log.info("Done.");
return ret;
}
@Override
public PythonObject toPython(INDArray indArray) {
log.info("Converting INDArray to PythonObject...");
DataType dataType = indArray.dataType();
DataBuffer buff = indArray.data();
String key = buff.pointer().address() + "_" + buff.length() + "_" + dataType;
cache.put(key, buff);
int numpyType;
String ctype;
switch (dataType) {
case DOUBLE:
numpyType = NPY_DOUBLE;
ctype = "c_double";
break;
case FLOAT:
case BFLOAT16:
numpyType = NPY_FLOAT;
ctype = "c_float";
break;
case SHORT:
numpyType = NPY_SHORT;
ctype = "c_short";
break;
case INT:
numpyType = NPY_INT;
ctype = "c_int";
break;
case LONG:
numpyType = NPY_INT64;
ctype = "c_int64";
break;
case UINT16:
numpyType = NPY_USHORT;
ctype = "c_uint16";
break;
case UINT32:
numpyType = NPY_UINT;
ctype = "c_uint";
break;
case UINT64:
numpyType = NPY_UINT64;
ctype = "c_uint64";
break;
case BOOL:
numpyType = NPY_BOOL;
ctype = "c_bool";
break;
case BYTE:
numpyType = NPY_BYTE;
ctype = "c_byte";
break;
case UBYTE:
numpyType = NPY_UBYTE;
ctype = "c_ubyte";
break;
case HALF:
numpyType = NPY_HALF;
ctype = "c_short";
break;
default:
throw new RuntimeException("Unsupported dtype: " + dataType);
}
long[] shape = indArray.shape();
INDArray inputArray = indArray;
if (dataType == DataType.BFLOAT16) {
log.warn("Creating copy of array as bfloat16 is not supported by numpy.");
inputArray = indArray.castTo(DataType.FLOAT);
}
//Sync to host memory in the case of CUDA, before passing the host memory pointer to Python
Nd4j.getAffinityManager().ensureLocation(inputArray, AffinityManager.Location.HOST);
// PyArray_Type() call causes jvm crash in linux cpu if GIL is acquired by non main thread.
// Using Interpreter for now:
// try(PythonContextManager.Context context = new PythonContextManager.Context("__np_array_converter")){
// log.info("Stringing exec...");
// String code = "import ctypes\nimport numpy as np\n" +
// "cArr = (ctypes." + ctype + "*" + indArray.length() + ")"+
// ".from_address(" + indArray.data().pointer().address() + ")\n"+
// "npArr = np.frombuffer(cArr, dtype=" + ((numpyType == NPY_HALF) ? "'half'" : "ctypes." + ctype)+
// ").reshape(" + Arrays.toString(indArray.shape()) + ")";
// PythonExecutioner.exec(code);
// log.info("exec done.");
// PythonObject ret = PythonExecutioner.getVariable("npArr");
// Py_IncRef(ret.getNativePythonObject());
// return ret;
//
// }
log.info("NUMPY: PyArray_Type()");
PyTypeObject pyTypeObject = PyArray_Type();
log.info("NUMPY: PyArray_New()");
PyObject npArr = PyArray_New(pyTypeObject, shape.length, new SizeTPointer(shape),
numpyType, null,
inputArray.data().addressPointer(),
0, NPY_ARRAY_CARRAY, null);
log.info("Done.");
return new PythonObject(npArr);
}
@Override
public boolean accepts(Object javaObject) {
return javaObject instanceof INDArray;
}
@Override
public INDArray adapt(Object javaObject) {
if (javaObject instanceof INDArray) {
return (INDArray) javaObject;
}
throw new PythonException("Cannot cast object of type " + javaObject.getClass().getName() + " to INDArray");
}
}

View File

@ -0,0 +1 @@
org.eclipse.python4j.NumpyArray

View File

@ -0,0 +1,170 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
import org.eclipse.python4j.*;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.concurrency.AffinityManager;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.nativeblas.OpaqueDataBuffer;
import javax.annotation.concurrent.NotThreadSafe;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
@NotThreadSafe
@RunWith(Parameterized.class)
public class PythonNumpyBasicTest {
private DataType dataType;
private long[] shape;
public PythonNumpyBasicTest(DataType dataType, long[] shape, String dummyArg) {
this.dataType = dataType;
this.shape = shape;
}
@Parameterized.Parameters(name = "{index}: Testing with DataType={0}, shape={2}")
public static Collection params() {
DataType[] types = new DataType[] {
DataType.BOOL,
DataType.FLOAT16,
DataType.BFLOAT16,
DataType.FLOAT,
DataType.DOUBLE,
DataType.INT8,
DataType.INT16,
DataType.INT32,
DataType.INT64,
DataType.UINT8,
DataType.UINT16,
DataType.UINT32,
DataType.UINT64
};
long[][] shapes = new long[][]{
new long[]{2, 3},
new long[]{3},
new long[]{1},
new long[]{} // scalar
};
List<Object[]> ret = new ArrayList<>();
for (DataType type: types){
for (long[] shape: shapes){
ret.add(new Object[]{type, shape, Arrays.toString(shape)});
}
}
return ret;
}
@Test
public void testConversion(){
INDArray arr = Nd4j.zeros(dataType, shape);
PythonObject npArr = PythonTypes.convert(arr);
INDArray arr2 = PythonTypes.<INDArray>getPythonTypeForPythonObject(npArr).toJava(npArr);
if (dataType == DataType.BFLOAT16){
arr = arr.castTo(DataType.FLOAT);
}
Assert.assertEquals(arr,arr2);
}
@Test
public void testExecution(){
List<PythonVariable> inputs = new ArrayList<>();
INDArray x = Nd4j.ones(dataType, shape);
INDArray y = Nd4j.zeros(dataType, shape);
INDArray z = (dataType == DataType.BOOL)?x:x.mul(y.add(2));
z = (dataType == DataType.BFLOAT16)? z.castTo(DataType.FLOAT): z;
PythonType<INDArray> arrType = PythonTypes.get("numpy.ndarray");
inputs.add(new PythonVariable<>("x", arrType, x));
inputs.add(new PythonVariable<>("y", arrType, y));
List<PythonVariable> outputs = new ArrayList<>();
PythonVariable<INDArray> output = new PythonVariable<>("z", arrType);
outputs.add(output);
String code = (dataType == DataType.BOOL)?"z = x":"z = x * (y + 2)";
if (shape.length == 0){ // scalar special case
code += "\nimport numpy as np\nz = np.asarray(float(z), dtype=x.dtype)";
}
PythonExecutioner.exec(code, inputs, outputs);
INDArray z2 = output.getValue();
Assert.assertEquals(z.dataType(), z2.dataType());
Assert.assertEquals(z, z2);
}
@Test
public void testInplaceExecution(){
if (dataType == DataType.BOOL || dataType == DataType.BFLOAT16)return;
if (shape.length == 0) return;
List<PythonVariable> inputs = new ArrayList<>();
INDArray x = Nd4j.ones(dataType, shape);
INDArray y = Nd4j.zeros(dataType, shape);
INDArray z = x.mul(y.add(2));
// Nd4j.getAffinityManager().ensureLocation(z, AffinityManager.Location.HOST);
PythonType<INDArray> arrType = PythonTypes.get("numpy.ndarray");
inputs.add(new PythonVariable<>("x", arrType, x));
inputs.add(new PythonVariable<>("y", arrType, y));
List<PythonVariable> outputs = new ArrayList<>();
PythonVariable<INDArray> output = new PythonVariable<>("x", arrType);
outputs.add(output);
String code = "x *= y + 2";
PythonExecutioner.exec(code, inputs, outputs);
INDArray z2 = output.getValue();
Assert.assertEquals(x.dataType(), z2.dataType());
Assert.assertEquals(z.dataType(), z2.dataType());
Assert.assertEquals(x, z2);
Assert.assertEquals(z, z2);
Assert.assertEquals(x.data().pointer().address(), z2.data().pointer().address());
if("CUDA".equalsIgnoreCase(Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"))){
Assert.assertEquals(getDeviceAddress(x), getDeviceAddress(z2));
}
}
private static long getDeviceAddress(INDArray array){
if(!"CUDA".equalsIgnoreCase(Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"))){
throw new IllegalStateException("Cannot ge device pointer for non-CUDA device");
}
//Use reflection here as OpaqueDataBuffer is only available on BaseCudaDataBuffer and BaseCpuDataBuffer - not DataBuffer/BaseDataBuffer
// due to it being defined in nd4j-native-api, not nd4j-api
try {
Class<?> c = Class.forName("org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer");
Method m = c.getMethod("getOpaqueDataBuffer");
OpaqueDataBuffer db = (OpaqueDataBuffer) m.invoke(array.data());
long address = db.specialBuffer().address();
return address;
} catch (Throwable t){
throw new RuntimeException("Error getting OpaqueDataBuffer", t);
}
}
}

View File

@ -0,0 +1,96 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
import org.eclipse.python4j.PythonException;
import org.eclipse.python4j.PythonObject;
import org.eclipse.python4j.PythonTypes;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;
import javax.annotation.concurrent.NotThreadSafe;
import java.util.*;
@NotThreadSafe
@RunWith(Parameterized.class)
public class PythonNumpyCollectionsTest {
private DataType dataType;
public PythonNumpyCollectionsTest(DataType dataType){
this.dataType = dataType;
}
@Parameterized.Parameters(name = "{index}: Testing with DataType={0}")
public static DataType[] params() {
return new DataType[]{
DataType.BOOL,
DataType.FLOAT16,
//DataType.BFLOAT16,
DataType.FLOAT,
DataType.DOUBLE,
DataType.INT8,
DataType.INT16,
DataType.INT32,
DataType.INT64,
DataType.UINT8,
DataType.UINT16,
DataType.UINT32,
DataType.UINT64
};
}
@Test
public void testPythonDictFromMap() throws PythonException {
Map map = new HashMap();
map.put("a", 1);
map.put(1, "a");
map.put("arr", Nd4j.ones(dataType, 2, 3));
map.put("list1", Arrays.asList(1, 2.0, 3, 4f, Nd4j.zeros(dataType,3,2)));
Map innerMap = new HashMap();
innerMap.put("b", 2);
innerMap.put(2, "b");
innerMap.put(5, Nd4j.ones(dataType, 5));
map.put("innermap", innerMap);
map.put("list2", Arrays.asList(4, "5", innerMap, false, true));
PythonObject dict = PythonTypes.convert(map);
Map map2 = PythonTypes.DICT.toJava(dict);
Assert.assertEquals(map.toString(), map2.toString());
}
@Test
public void testPythonListFromList() throws PythonException{
List<Object> list = new ArrayList<>();
list.add(1);
list.add("2");
list.add(Nd4j.ones(dataType, 2, 3));
list.add(Arrays.asList("a",
Nd4j.ones(dataType, 1, 2),1.0, 2f, 10, true, false,
Nd4j.zeros(dataType, 3, 2)));
Map map = new HashMap();
map.put("a", 1);
map.put(1, "a");
map.put(5, Nd4j.ones(dataType,4, 5));
map.put("list1", Arrays.asList(1, 2.0, 3, 4f, Nd4j.zeros(dataType, 3, 1)));
list.add(map);
PythonObject dict = PythonTypes.convert(list);
List list2 = PythonTypes.LIST.toJava(dict);
Assert.assertEquals(list.toString(), list2.toString());
}
}

View File

@ -0,0 +1,55 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
import org.eclipse.python4j.Python;
import org.eclipse.python4j.PythonGC;
import org.eclipse.python4j.PythonObject;
import org.junit.Assert;
import org.junit.Test;
import org.nd4j.linalg.factory.Nd4j;
import javax.annotation.concurrent.NotThreadSafe;
@NotThreadSafe
public class PythonNumpyGCTest {
@Test
public void testGC(){
PythonObject gcModule = Python.importModule("gc");
PythonObject getObjects = gcModule.attr("get_objects");
PythonObject pyObjCount1 = Python.len(getObjects.call());
long objCount1 = pyObjCount1.toLong();
PythonObject pyList = Python.list();
pyList.attr("append").call(new PythonObject(Nd4j.linspace(1, 10, 10)));
pyList.attr("append").call(1.0);
pyList.attr("append").call(true);
PythonObject pyObjCount2 = Python.len(getObjects.call());
long objCount2 = pyObjCount2.toLong();
long diff = objCount2 - objCount1;
Assert.assertTrue(diff > 2);
try(PythonGC gc = PythonGC.watch()){
PythonObject pyList2 = Python.list();
pyList2.attr("append").call(new PythonObject(Nd4j.linspace(1, 10, 10)));
pyList2.attr("append").call(1.0);
pyList2.attr("append").call(true);
}
PythonObject pyObjCount3 = Python.len(getObjects.call());
long objCount3 = pyObjCount3.toLong();
diff = objCount3 - objCount2;
Assert.assertTrue(diff <= 2);// 2 objects created during function call
}
}

View File

@ -0,0 +1,22 @@
import org.eclipse.python4j.NumpyArray;
import org.eclipse.python4j.Python;
import org.eclipse.python4j.PythonGC;
import org.eclipse.python4j.PythonObject;
import org.junit.Assert;
import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
public class PythonNumpyImportTest {
@Test
public void testNumpyImport(){
try(PythonGC gc = PythonGC.watch()){
PythonObject np = Python.importModule("numpy");
PythonObject zeros = np.attr("zeros").call(5);
INDArray arr = NumpyArray.INSTANCE.toJava(zeros);
Assert.assertEquals(arr, Nd4j.zeros(DataType.DOUBLE, 5));
}
}
}

View File

@ -0,0 +1,303 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
import org.eclipse.python4j.*;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.assertEquals;
@javax.annotation.concurrent.NotThreadSafe
@RunWith(Parameterized.class)
public class PythonNumpyJobTest {
private DataType dataType;
public PythonNumpyJobTest(DataType dataType){
this.dataType = dataType;
}
@Parameterized.Parameters(name = "{index}: Testing with DataType={0}")
public static DataType[] params() {
return new DataType[]{
DataType.BOOL,
DataType.FLOAT16,
DataType.BFLOAT16,
DataType.FLOAT,
DataType.DOUBLE,
DataType.INT8,
DataType.INT16,
DataType.INT32,
DataType.INT64,
DataType.UINT8,
DataType.UINT16,
DataType.UINT32,
DataType.UINT64
};
}
@Test
public void testNumpyJobBasic(){
PythonContextManager.deleteNonMainContexts();
List<PythonVariable> inputs = new ArrayList<>();
INDArray x = Nd4j.ones(dataType, 2, 3);
INDArray y = Nd4j.zeros(dataType, 2, 3);
INDArray z = (dataType == DataType.BOOL)?x:x.mul(y.add(2));
z = (dataType == DataType.BFLOAT16)? z.castTo(DataType.FLOAT): z;
PythonType<INDArray> arrType = PythonTypes.get("numpy.ndarray");
inputs.add(new PythonVariable<>("x", arrType, x));
inputs.add(new PythonVariable<>("y", arrType, y));
List<PythonVariable> outputs = new ArrayList<>();
PythonVariable<INDArray> output = new PythonVariable<>("z", arrType);
outputs.add(output);
String code = (dataType == DataType.BOOL)?"z = x":"z = x * (y + 2)";
PythonJob job = new PythonJob("job1", code, false);
job.exec(inputs, outputs);
INDArray z2 = output.getValue();
if (dataType == DataType.BFLOAT16){
z2 = z2.castTo(DataType.FLOAT);
}
Assert.assertEquals(z, z2);
}
@Test
public void testNumpyJobReturnAllVariables(){
PythonContextManager.deleteNonMainContexts();
List<PythonVariable> inputs = new ArrayList<>();
INDArray x = Nd4j.ones(dataType, 2, 3);
INDArray y = Nd4j.zeros(dataType, 2, 3);
INDArray z = (dataType == DataType.BOOL)?x:x.mul(y.add(2));
PythonType<INDArray> arrType = PythonTypes.get("numpy.ndarray");
inputs.add(new PythonVariable<>("x", arrType, x));
inputs.add(new PythonVariable<>("y", arrType, y));
String code = (dataType == DataType.BOOL)?"z = x":"z = x * (y + 2)";
PythonJob job = new PythonJob("job1", code, false);
List<PythonVariable> outputs = job.execAndReturnAllVariables(inputs);
INDArray x2 = (INDArray) outputs.get(0).getValue();
INDArray y2 = (INDArray) outputs.get(1).getValue();
INDArray z2 = (INDArray) outputs.get(2).getValue();
if (dataType == DataType.BFLOAT16){
x = x.castTo(DataType.FLOAT);
y = y.castTo(DataType.FLOAT);
z = z.castTo(DataType.FLOAT);
}
Assert.assertEquals(x, x2);
Assert.assertEquals(y, y2);
Assert.assertEquals(z, z2);
}
@Test
public void testMultipleNumpyJobsParallel(){
PythonContextManager.deleteNonMainContexts();
String code1 =(dataType == DataType.BOOL)?"z = x":"z = x + y";
PythonJob job1 = new PythonJob("job1", code1, false);
String code2 =(dataType == DataType.BOOL)?"z = y":"z = x - y";
PythonJob job2 = new PythonJob("job2", code2, false);
List<PythonVariable> inputs = new ArrayList<>();
INDArray x = Nd4j.ones(dataType, 2, 3);
INDArray y = Nd4j.zeros(dataType, 2, 3);
INDArray z1 = (dataType == DataType.BOOL)?x:x.add(y);
z1 = (dataType == DataType.BFLOAT16)? z1.castTo(DataType.FLOAT): z1;
INDArray z2 = (dataType == DataType.BOOL)?y:x.sub(y);
z2 = (dataType == DataType.BFLOAT16)? z2.castTo(DataType.FLOAT): z2;
PythonType<INDArray> arrType = PythonTypes.get("numpy.ndarray");
inputs.add(new PythonVariable<>("x", arrType, x));
inputs.add(new PythonVariable<>("y", arrType, y));
List<PythonVariable> outputs = new ArrayList<>();
outputs.add(new PythonVariable<>("z", arrType));
job1.exec(inputs, outputs);
assertEquals(z1, outputs.get(0).getValue());
job2.exec(inputs, outputs);
assertEquals(z2, outputs.get(0).getValue());
}
@Test
public synchronized void testNumpyJobSetupRun(){
if (dataType == DataType.BOOL)return;
PythonContextManager.deleteNonMainContexts();
String code = "five=None\n" +
"def setup():\n" +
" global five\n"+
" five = 5\n\n" +
"def run(a, b):\n" +
" c = a + b + five\n"+
" return {'c':c}\n\n";
PythonJob job = new PythonJob("job1", code, true);
List<PythonVariable> inputs = new ArrayList<>();
inputs.add(new PythonVariable<>("a", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(2)));
inputs.add(new PythonVariable<>("b", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3)));
List<PythonVariable> outputs = new ArrayList<>();
outputs.add(new PythonVariable<>("c", NumpyArray.INSTANCE));
job.exec(inputs, outputs);
assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16)? DataType.FLOAT: dataType, 2, 3).mul(10),
outputs.get(0).getValue());
inputs = new ArrayList<>();
inputs.add(new PythonVariable<>("a", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3)));
inputs.add(new PythonVariable<>("b", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(4)));
outputs = new ArrayList<>();
outputs.add(new PythonVariable<>("c", NumpyArray.INSTANCE));
job.exec(inputs, outputs);
assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16)? DataType.FLOAT: dataType, 2, 3).mul(12),
outputs.get(0).getValue());
}
@Test
public void testNumpyJobSetupRunAndReturnAllVariables(){
if (dataType == DataType.BOOL)return;
PythonContextManager.deleteNonMainContexts();
String code = "five=None\n" +
"c=None\n"+
"def setup():\n" +
" global five\n"+
" five = 5\n\n" +
"def run(a, b):\n" +
" global c\n" +
" c = a + b + five\n";
PythonJob job = new PythonJob("job1", code, true);
List<PythonVariable> inputs = new ArrayList<>();
inputs.add(new PythonVariable<>("a", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(2)));
inputs.add(new PythonVariable<>("b", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3)));
List<PythonVariable> outputs = job.execAndReturnAllVariables(inputs);
assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16)? DataType.FLOAT: dataType, 2, 3).mul(10),
outputs.get(1).getValue());
inputs = new ArrayList<>();
inputs.add(new PythonVariable<>("a", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3)));
inputs.add(new PythonVariable<>("b", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(4)));
outputs = job.execAndReturnAllVariables(inputs);
assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16)? DataType.FLOAT: dataType, 2, 3).mul(12),
outputs.get(1).getValue());
}
@Test
public void testMultipleNumpyJobsSetupRunParallel(){
if (dataType == DataType.BOOL)return;
PythonContextManager.deleteNonMainContexts();
String code1 = "five=None\n" +
"def setup():\n" +
" global five\n"+
" five = 5\n\n" +
"def run(a, b):\n" +
" c = a + b + five\n"+
" return {'c':c}\n\n";
PythonJob job1 = new PythonJob("job1", code1, true);
String code2 = "five=None\n" +
"def setup():\n" +
" global five\n"+
" five = 5\n\n" +
"def run(a, b):\n" +
" c = a + b - five\n"+
" return {'c':c}\n\n";
PythonJob job2 = new PythonJob("job2", code2, true);
List<PythonVariable> inputs = new ArrayList<>();
inputs.add(new PythonVariable<>("a", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(2)));
inputs.add(new PythonVariable<>("b", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3)));
List<PythonVariable> outputs = new ArrayList<>();
outputs.add(new PythonVariable<>("c", NumpyArray.INSTANCE));
job1.exec(inputs, outputs);
assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16)? DataType.FLOAT: dataType, 2, 3).mul(10),
outputs.get(0).getValue());
job2.exec(inputs, outputs);
assertEquals(Nd4j.zeros((dataType == DataType.BFLOAT16)? DataType.FLOAT: dataType, 2, 3),
outputs.get(0).getValue());
inputs = new ArrayList<>();
inputs.add(new PythonVariable<>("a", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3)));
inputs.add(new PythonVariable<>("b", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(4)));
outputs = new ArrayList<>();
outputs.add(new PythonVariable<>("c", NumpyArray.INSTANCE));
job1.exec(inputs, outputs);
assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16)? DataType.FLOAT: dataType, 2, 3).mul(12),
outputs.get(0).getValue());
job2.exec(inputs, outputs);
assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16)? DataType.FLOAT: dataType, 2, 3).mul(2),
outputs.get(0).getValue());
}
}

View File

@ -0,0 +1,194 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
import org.eclipse.python4j.*;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import javax.annotation.concurrent.NotThreadSafe;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
@NotThreadSafe
@RunWith(Parameterized.class)
public class PythonNumpyMultiThreadTest {
private DataType dataType;
public PythonNumpyMultiThreadTest(DataType dataType) {
this.dataType = dataType;
}
@Parameterized.Parameters(name = "{index}: Testing with DataType={0}")
public static DataType[] params() {
return new DataType[]{
// DataType.BOOL,
// DataType.FLOAT16,
// DataType.BFLOAT16,
DataType.FLOAT,
DataType.DOUBLE,
// DataType.INT8,
// DataType.INT16,
DataType.INT32,
DataType.INT64,
// DataType.UINT8,
// DataType.UINT16,
// DataType.UINT32,
// DataType.UINT64
};
}
@Test
public void testMultiThreading1() throws Throwable {
final List<Throwable> exceptions = Collections.synchronizedList(new ArrayList<Throwable>());
Runnable runnable = new Runnable() {
@Override
public void run() {
try (PythonGIL gil = PythonGIL.lock()) {
try (PythonGC gc = PythonGC.watch()) {
List<PythonVariable> inputs = new ArrayList<>();
inputs.add(new PythonVariable<>("x", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3)));
inputs.add(new PythonVariable<>("y", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(4)));
PythonVariable out = new PythonVariable<>("z", NumpyArray.INSTANCE);
String code = "z = x + y";
PythonExecutioner.exec(code, inputs, Collections.singletonList(out));
Assert.assertEquals(Nd4j.ones(dataType, 2, 3).mul(7), out.getValue());
}
} catch (Throwable e) {
exceptions.add(e);
}
}
};
int numThreads = 10;
Thread[] threads = new Thread[numThreads];
for (int i = 0; i < threads.length; i++) {
threads[i] = new Thread(runnable);
}
for (int i = 0; i < threads.length; i++) {
threads[i].start();
}
Thread.sleep(100);
for (int i = 0; i < threads.length; i++) {
threads[i].join();
}
if (!exceptions.isEmpty()) {
throw (exceptions.get(0));
}
}
@Test
public void testMultiThreading2() throws Throwable {
final List<Throwable> exceptions = Collections.synchronizedList(new ArrayList<Throwable>());
Runnable runnable = new Runnable() {
@Override
public void run() {
try (PythonGIL gil = PythonGIL.lock()) {
try (PythonGC gc = PythonGC.watch()) {
PythonContextManager.reset();
List<PythonVariable> inputs = new ArrayList<>();
inputs.add(new PythonVariable<>("x", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3)));
inputs.add(new PythonVariable<>("y", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(4)));
String code = "z = x + y";
List<PythonVariable> outputs = PythonExecutioner.execAndReturnAllVariables(code, inputs);
Assert.assertEquals(Nd4j.ones(dataType, 2, 3).mul(3), outputs.get(0).getValue());
Assert.assertEquals(Nd4j.ones(dataType, 2, 3).mul(4), outputs.get(1).getValue());
Assert.assertEquals(Nd4j.ones(dataType, 2, 3).mul(7), outputs.get(2).getValue());
}
} catch (Throwable e) {
exceptions.add(e);
}
}
};
int numThreads = 10;
Thread[] threads = new Thread[numThreads];
for (int i = 0; i < threads.length; i++) {
threads[i] = new Thread(runnable);
}
for (int i = 0; i < threads.length; i++) {
threads[i].start();
}
Thread.sleep(100);
for (int i = 0; i < threads.length; i++) {
threads[i].join();
}
if (!exceptions.isEmpty()) {
throw (exceptions.get(0));
}
}
@Test
public void testMultiThreading3() throws Throwable {
PythonContextManager.deleteNonMainContexts();
String code = "c = a + b";
final PythonJob job = new PythonJob("job1", code, false);
final List<Throwable> exceptions = Collections.synchronizedList(new ArrayList<Throwable>());
class JobThread extends Thread {
private INDArray a, b, c;
public JobThread(INDArray a, INDArray b, INDArray c) {
this.a = a;
this.b = b;
this.c = c;
}
@Override
public void run() {
try {
PythonVariable<INDArray> out = new PythonVariable<>("c", NumpyArray.INSTANCE);
job.exec(Arrays.<PythonVariable>asList(new PythonVariable<>("a", NumpyArray.INSTANCE, a),
new PythonVariable<>("b", NumpyArray.INSTANCE, b)),
Collections.<PythonVariable>singletonList(out));
Assert.assertEquals(c, out.getValue());
} catch (Exception e) {
exceptions.add(e);
}
}
}
int numThreads = 10;
JobThread[] threads = new JobThread[numThreads];
for (int i = 0; i < threads.length; i++) {
threads[i] = new JobThread(Nd4j.zeros(dataType, 2, 3).add(i), Nd4j.zeros(dataType, 2, 3).add(i + 3),
Nd4j.zeros(dataType, 2, 3).add(2 * i + 3));
}
for (int i = 0; i < threads.length; i++) {
threads[i].start();
}
Thread.sleep(100);
for (int i = 0; i < threads.length; i++) {
threads[i].join();
}
if (!exceptions.isEmpty()) {
throw (exceptions.get(0));
}
}
}

View File

@ -0,0 +1,41 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
import org.eclipse.python4j.*;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import javax.annotation.concurrent.NotThreadSafe;
import java.util.ArrayList;
import java.util.List;
@NotThreadSafe
public class PythonNumpyServiceLoaderTest {
@Test
public void testServiceLoader(){
Assert.assertEquals(NumpyArray.INSTANCE, PythonTypes.<INDArray>get("numpy.ndarray"));
Assert.assertEquals(NumpyArray.INSTANCE, PythonTypes.getPythonTypeForJavaObject(Nd4j.zeros(1)));
}
}