diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/train/GradientSharingTrainingTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/train/GradientSharingTrainingTest.java index 68a012b72..c1eff1dce 100644 --- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/train/GradientSharingTrainingTest.java +++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/train/GradientSharingTrainingTest.java @@ -141,7 +141,7 @@ public class GradientSharingTrainingTest extends BaseSparkTest { SparkComputationGraph sparkNet = new SparkComputationGraph(sc, conf, tm); sparkNet.setCollectTrainingStats(tm.getIsCollectTrainingStats()); - System.out.println(Arrays.toString(sparkNet.getNetwork().params().get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat())); +// System.out.println(Arrays.toString(sparkNet.getNetwork().params().get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat())); File f = testDir.newFolder(); DataSetIterator iter = new MnistDataSetIterator(16, true, 12345); int count = 0; @@ -208,10 +208,10 @@ public class GradientSharingTrainingTest extends BaseSparkTest { } INDArray paramsAfter = after.params(); - System.out.println(Arrays.toString(paramsBefore.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat())); - System.out.println(Arrays.toString(paramsAfter.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat())); - System.out.println(Arrays.toString( - Transforms.abs(paramsAfter.sub(paramsBefore)).get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat())); +// System.out.println(Arrays.toString(paramsBefore.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat())); +// System.out.println(Arrays.toString(paramsAfter.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat())); +// System.out.println(Arrays.toString( +// Transforms.abs(paramsAfter.sub(paramsBefore)).get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat())); assertNotEquals(paramsBefore, paramsAfter); @@ -235,7 +235,7 @@ public class GradientSharingTrainingTest extends BaseSparkTest { } - @Test + @Test @Ignore //AB https://github.com/eclipse/deeplearning4j/issues/8985 public void differentNetsTrainingTest() throws Exception { int batch = 3; diff --git a/python4j/pom.xml b/python4j/pom.xml index 57af8f1bb..1fe50344f 100644 --- a/python4j/pom.xml +++ b/python4j/pom.xml @@ -41,10 +41,14 @@ provided + org.slf4j + slf4j-api + 1.6.6 + ch.qos.logback logback-classic ${logback.version} - test + test junit @@ -62,5 +66,10 @@ jsr305 3.0.2 + + org.slf4j + slf4j-api + 1.6.6 + \ No newline at end of file diff --git a/python4j/python4j-core/pom.xml b/python4j/python4j-core/pom.xml index b429d8272..e74d32392 100644 --- a/python4j/python4j-core/pom.xml +++ b/python4j/python4j-core/pom.xml @@ -39,6 +39,5 @@ cpython-platform ${cpython-platform.version} - \ No newline at end of file diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonContextManager.java b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonContextManager.java index a34d8a239..5675d0864 100644 --- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonContextManager.java +++ b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonContextManager.java @@ -19,8 +19,10 @@ package org.eclipse.python4j; import javax.lang.model.SourceVersion; +import java.io.Closeable; import java.util.HashSet; import java.util.Set; +import java.util.UUID; import java.util.concurrent.atomic.AtomicBoolean; /** @@ -46,6 +48,31 @@ public class PythonContextManager { init(); } + + public static class Context implements Closeable{ + private final String name; + private final String previous; + private final boolean temp; + public Context(){ + name = "temp_" + UUID.randomUUID().toString().replace("-", "_"); + temp = true; + previous = getCurrentContext(); + setContext(name); + } + public Context(String name){ + this.name = name; + temp = false; + previous = getCurrentContext(); + setContext(name); + } + + @Override + public void close(){ + setContext(previous); + if (temp) deleteContext(name); + } + } + private static void init() { if (init.get()) return; new PythonExecutioner(); @@ -190,6 +217,7 @@ public class PythonContextManager { setContext(tempContext); deleteContext(currContext); setContext(currContext); + deleteContext(tempContext); } /** diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonExecutioner.java b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonExecutioner.java index 57e1a22ae..542778f76 100644 --- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonExecutioner.java +++ b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonExecutioner.java @@ -25,6 +25,7 @@ import java.io.InputStream; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; @@ -42,7 +43,6 @@ public class PythonExecutioner { private final static String DEFAULT_PYTHON_PATH_PROPERTY = "org.eclipse.python4j.path"; private final static String JAVACPP_PYTHON_APPEND_TYPE = "org.eclipse.python4j.path.append"; private final static String DEFAULT_APPEND_TYPE = "before"; - static { init(); } @@ -55,6 +55,11 @@ public class PythonExecutioner { initPythonPath(); PyEval_InitThreads(); Py_InitializeEx(0); + for (PythonType type: PythonTypes.get()){ + type.init(); + } + // Constructors of custom types may contain initialization code that should + // run on the main the thread. } /** @@ -110,6 +115,8 @@ public class PythonExecutioner { getVariables(Arrays.asList(pyVars)); } + + /** * Gets the variable with the given name from the interpreter. * @@ -205,9 +212,9 @@ public class PythonExecutioner { * * @return */ - public static List getAllVariables() { + public static PythonVariables getAllVariables() { PythonGIL.assertThreadSafe(); - List ret = new ArrayList<>(); + PythonVariables ret = new PythonVariables(); PyObject main = PyImport_ImportModule("__main__"); PyObject globals = PyModule_GetDict(main); PyObject keys = PyDict_Keys(globals); @@ -259,7 +266,7 @@ public class PythonExecutioner { * @param inputs * @return */ - public static List execAndReturnAllVariables(String code, List inputs) { + public static PythonVariables execAndReturnAllVariables(String code, List inputs) { setVariables(inputs); simpleExec(getWrappedCode(code)); return getAllVariables(); @@ -271,7 +278,7 @@ public class PythonExecutioner { * @param code * @return */ - public static List execAndReturnAllVariables(String code) { + public static PythonVariables execAndReturnAllVariables(String code) { simpleExec(getWrappedCode(code)); return getAllVariables(); } @@ -279,25 +286,22 @@ public class PythonExecutioner { private static synchronized void initPythonPath() { try { String path = System.getProperty(DEFAULT_PYTHON_PATH_PROPERTY); + + List packagesList = new ArrayList<>(); + packagesList.addAll(Arrays.asList(cachePackages())); + for (PythonType type: PythonTypes.get()){ + packagesList.addAll(Arrays.asList(type.packages())); + } + //// TODO: fix in javacpp + packagesList.add(new File(python.cachePackage(), "site-packages")); + + File[] packages = packagesList.toArray(new File[0]); + if (path == null) { - File[] packages = cachePackages(); - - //// TODO: fix in javacpp - File sitePackagesWindows = new File(python.cachePackage(), "site-packages"); - File[] packages2 = new File[packages.length + 1]; - for (int i = 0; i < packages.length; i++) { - //System.out.println(packages[i].getAbsolutePath()); - packages2[i] = packages[i]; - } - packages2[packages.length] = sitePackagesWindows; - //System.out.println(sitePackagesWindows.getAbsolutePath()); - packages = packages2; - ////////// - Py_SetPath(packages); } else { StringBuffer sb = new StringBuffer(); - File[] packages = cachePackages(); + JavaCppPathType pathAppendValue = JavaCppPathType.valueOf(System.getProperty(JAVACPP_PYTHON_APPEND_TYPE, DEFAULT_APPEND_TYPE).toUpperCase()); switch (pathAppendValue) { case BEFORE: diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonGIL.java b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonGIL.java index 46b3db431..074be294a 100644 --- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonGIL.java +++ b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonGIL.java @@ -90,4 +90,8 @@ public class PythonGIL implements AutoCloseable { PyEval_SaveThread(); PyEval_RestoreThread(mainThreadState); } + + public static boolean locked(){ + return acquired.get(); + } } diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonJob.java b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonJob.java index cdbb1b81d..0818de890 100644 --- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonJob.java +++ b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonJob.java @@ -20,25 +20,29 @@ package org.eclipse.python4j; import lombok.Builder; import lombok.Data; import lombok.NoArgsConstructor; +import lombok.extern.slf4j.Slf4j; import javax.annotation.Nonnull; import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; -@Data -@NoArgsConstructor /** * PythonJob is the right abstraction for executing multiple python scripts * in a multi thread stateful environment. The setup-and-run mode allows your * "setup" code (imports, model loading etc) to be executed only once. */ +@Data +@Slf4j public class PythonJob { + private String code; private String name; private String context; - private boolean setupRunMode; + private final boolean setupRunMode; private PythonObject runF; + private final AtomicBoolean setupDone = new AtomicBoolean(false); static { new PythonExecutioner(); @@ -63,7 +67,6 @@ public class PythonJob { if (PythonContextManager.hasContext(context)) { throw new PythonException("Unable to create python job " + name + ". Context " + context + " already exists!"); } - if (setupRunMode) setup(); } @@ -71,17 +74,18 @@ public class PythonJob { * Clears all variables in current context and calls setup() */ public void clearState(){ - String context = this.context; - PythonContextManager.setContext("main"); - PythonContextManager.deleteContext(context); - this.context = context; + PythonContextManager.setContext(this.context); + PythonContextManager.reset(); + setupDone.set(false); setup(); } public void setup(){ + if (setupDone.get()) return; try (PythonGIL gil = PythonGIL.lock()) { PythonContextManager.setContext(context); PythonObject runF = PythonExecutioner.getVariable("run"); + if (runF == null || runF.isNone() || !Python.callable(runF)) { PythonExecutioner.exec(code); runF = PythonExecutioner.getVariable("run"); @@ -98,10 +102,12 @@ public class PythonJob { if (!setupF.isNone()) { setupF.call(); } + setupDone.set(true); } } public void exec(List inputs, List outputs) { + if (setupRunMode)setup(); try (PythonGIL gil = PythonGIL.lock()) { try (PythonGC _ = PythonGC.watch()) { PythonContextManager.setContext(context); @@ -139,6 +145,7 @@ public class PythonJob { } public List execAndReturnAllVariables(List inputs){ + if (setupRunMode)setup(); try (PythonGIL gil = PythonGIL.lock()) { try (PythonGC _ = PythonGC.watch()) { PythonContextManager.setContext(context); diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonObject.java b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonObject.java index f8ec17ed9..69252a5f7 100644 --- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonObject.java +++ b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonObject.java @@ -147,7 +147,8 @@ public class PythonObject { } PythonObject pyArgs; PythonObject pyKwargs; - if (args == null) { + + if (args == null || args.isEmpty()) { pyArgs = new PythonObject(PyTuple_New(0)); } else { PythonObject argsList = PythonTypes.convert(args); @@ -158,6 +159,7 @@ public class PythonObject { } else { pyKwargs = PythonTypes.convert(kwargs); } + PythonObject ret = new PythonObject( PyObject_Call( nativePythonObject, @@ -165,7 +167,9 @@ public class PythonObject { pyKwargs == null ? null : pyKwargs.nativePythonObject ) ); + PythonGC.keep(ret); + return ret; } @@ -241,4 +245,48 @@ public class PythonObject { PyObject_SetItem(nativePythonObject, key.nativePythonObject, value.nativePythonObject); } + + public PythonObject abs(){ + return new PythonObject(PyNumber_Absolute(nativePythonObject)); + } + public PythonObject add(PythonObject pythonObject){ + return new PythonObject(PyNumber_Add(nativePythonObject, pythonObject.nativePythonObject)); + } + public PythonObject sub(PythonObject pythonObject){ + return new PythonObject(PyNumber_Subtract(nativePythonObject, pythonObject.nativePythonObject)); + } + public PythonObject mod(PythonObject pythonObject){ + return new PythonObject(PyNumber_Divmod(nativePythonObject, pythonObject.nativePythonObject)); + } + public PythonObject mul(PythonObject pythonObject){ + return new PythonObject(PyNumber_Multiply(nativePythonObject, pythonObject.nativePythonObject)); + } + public PythonObject trueDiv(PythonObject pythonObject){ + return new PythonObject(PyNumber_TrueDivide(nativePythonObject, pythonObject.nativePythonObject)); + } + public PythonObject floorDiv(PythonObject pythonObject){ + return new PythonObject(PyNumber_FloorDivide(nativePythonObject, pythonObject.nativePythonObject)); + } + public PythonObject matMul(PythonObject pythonObject){ + return new PythonObject(PyNumber_MatrixMultiply(nativePythonObject, pythonObject.nativePythonObject)); + } + + public void addi(PythonObject pythonObject){ + PyNumber_InPlaceAdd(nativePythonObject, pythonObject.nativePythonObject); + } + public void subi(PythonObject pythonObject){ + PyNumber_InPlaceSubtract(nativePythonObject, pythonObject.nativePythonObject); + } + public void muli(PythonObject pythonObject){ + PyNumber_InPlaceMultiply(nativePythonObject, pythonObject.nativePythonObject); + } + public void trueDivi(PythonObject pythonObject){ + PyNumber_InPlaceTrueDivide(nativePythonObject, pythonObject.nativePythonObject); + } + public void floorDivi(PythonObject pythonObject){ + PyNumber_InPlaceFloorDivide(nativePythonObject, pythonObject.nativePythonObject); + } + public void matMuli(PythonObject pythonObject){ + PyNumber_InPlaceMatrixMultiply(nativePythonObject, pythonObject.nativePythonObject); + } } diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonProcess.java b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonProcess.java new file mode 100644 index 000000000..0ca17fb49 --- /dev/null +++ b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonProcess.java @@ -0,0 +1,127 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +package org.eclipse.python4j; + +import org.apache.commons.io.IOUtils; +import org.bytedeco.javacpp.Loader; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; + +public class PythonProcess { + private static String pythonExecutable = Loader.load(org.bytedeco.cpython.python.class); + public static String runAndReturn(String... arguments)throws IOException, InterruptedException{ + String[] allArgs = new String[arguments.length + 1]; + for (int i = 0; i < arguments.length; i++){ + allArgs[i + 1] = arguments[i]; + } + allArgs[0] = pythonExecutable; + ProcessBuilder pb = new ProcessBuilder(allArgs); + Process process = pb.start(); + String out = IOUtils.toString(process.getInputStream(), StandardCharsets.UTF_8); + process.waitFor(); + return out; + + } + + public static void run(String... arguments)throws IOException, InterruptedException{ + String[] allArgs = new String[arguments.length + 1]; + for (int i = 0; i < arguments.length; i++){ + allArgs[i + 1] = arguments[i]; + } + allArgs[0] = pythonExecutable; + ProcessBuilder pb = new ProcessBuilder(allArgs); + pb.inheritIO().start().waitFor(); + } + public static void pipInstall(String packageName) throws PythonException{ + try{ + run("-m", "pip", "install", packageName); + }catch(Exception e){ + throw new PythonException("Error installing package " + packageName, e); + } + + } + + public static void pipInstall(String packageName, String version){ + pipInstall(packageName + "==" + version); + } + + public static void pipUninstall(String packageName) throws PythonException{ + try{ + run("-m", "pip", "uninstall", packageName); + }catch(Exception e){ + throw new PythonException("Error uninstalling package " + packageName, e); + } + + } + public static void pipInstallFromGit(String gitRepoUrl){ + if (!gitRepoUrl.contains("://")){ + gitRepoUrl = "git://" + gitRepoUrl; + } + try{ + run("-m", "pip", "install", "git+", gitRepoUrl); + }catch(Exception e){ + throw new PythonException("Error installing package from " + gitRepoUrl, e); + } + + } + + public static String getPackageVersion(String packageName){ + String out; + try{ + out = runAndReturn("-m", "pip", "show", packageName); + } catch (Exception e){ + throw new PythonException("Error finding version for package " + packageName, e); + } + + if (!out.contains("Version: ")){ + throw new PythonException("Can't find package " + packageName); + } + String pkgVersion = out.split("Version: ")[1].split(System.lineSeparator())[0]; + return pkgVersion; + } + + public static boolean isPackageInstalled(String packageName){ + try{ + String out = runAndReturn("-m", "pip", "show", packageName); + return !out.isEmpty(); + }catch (Exception e){ + throw new PythonException("Error checking if package is installed: " +packageName, e); + } + + } + + public static void pipInstallFromRequirementsTxt(String path){ + try{ + run("-m", "pip", "install","-r", path); + }catch (Exception e){ + throw new PythonException("Error installing packages from " + path, e); + } + } + + public static void pipInstallFromSetupScript(String path, boolean inplace){ + + try{ + run(path, inplace?"develop":"install"); + }catch (Exception e){ + throw new PythonException("Error installing package from " + path, e); + } + + } + +} \ No newline at end of file diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonType.java b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonType.java index b4806aa37..47b725cd5 100644 --- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonType.java +++ b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonType.java @@ -17,6 +17,8 @@ package org.eclipse.python4j; +import java.io.File; + public abstract class PythonType { private final String name; @@ -43,5 +45,25 @@ public abstract class PythonType { return name; } + @Override + public boolean equals(Object obj){ + if (!(obj instanceof PythonType)){ + return false; + } + PythonType other = (PythonType)obj; + return this.getClass().equals(other.getClass()) && this.name.equals(other.name); + } + + public PythonObject pythonType(){ + return null; + } + + public File[] packages(){ + return new File[0]; + } + + public void init(){ //not to be called from constructor + + } } diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonTypes.java b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonTypes.java index 0dc20f712..cd7ac7d7c 100644 --- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonTypes.java +++ b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonTypes.java @@ -18,7 +18,16 @@ package org.eclipse.python4j; import org.bytedeco.cpython.PyObject; +import org.bytedeco.javacpp.BytePointer; +import org.bytedeco.javacpp.Loader; +import org.bytedeco.javacpp.Pointer; +import sun.misc.Unsafe; +import sun.nio.ch.DirectBuffer; +import java.lang.reflect.Field; +import java.nio.Buffer; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; import java.util.*; import static org.bytedeco.cpython.global.python.*; @@ -28,7 +37,7 @@ public class PythonTypes { private static List getPrimitiveTypes() { - return Arrays.asList(STR, INT, FLOAT, BOOL); + return Arrays.asList(STR, INT, FLOAT, BOOL, MEMORYVIEW); } private static List getCollectionTypes() { @@ -36,8 +45,13 @@ public class PythonTypes { } private static List getExternalTypes() { - //TODO service loader - return new ArrayList<>(); + List ret = new ArrayList<>(); + ServiceLoader sl = ServiceLoader.load(PythonType.class); + Iterator iter = sl.iterator(); + while (iter.hasNext()) { + ret.add(iter.next()); + } + return ret; } public static List get() { @@ -48,15 +62,17 @@ public class PythonTypes { return ret; } - public static PythonType get(String name) { + public static PythonType get(String name) { for (PythonType pt : get()) { if (pt.getName().equals(name)) { // TODO use map instead? return pt; } + } throw new PythonException("Unknown python type: " + name); } + public static PythonType getPythonTypeForJavaObject(Object javaObject) { for (PythonType pt : get()) { if (pt.accepts(javaObject)) { @@ -66,7 +82,7 @@ public class PythonTypes { throw new PythonException("Unable to find python type for java type: " + javaObject.getClass()); } - public static PythonType getPythonTypeForPythonObject(PythonObject pythonObject) { + public static PythonType getPythonTypeForPythonObject(PythonObject pythonObject) { PyObject pyType = PyObject_Type(pythonObject.getNativePythonObject()); try { String pyTypeStr = PythonTypes.STR.toJava(new PythonObject(pyType, false)); @@ -75,6 +91,14 @@ public class PythonTypes { String pyTypeStr2 = ""; if (pyTypeStr.equals(pyTypeStr2)) { return pt; + } else { + try (PythonGC gc = PythonGC.watch()) { + PythonObject pyType2 = pt.pythonType(); + if (pyType2 != null && Python.isinstance(pythonObject, pyType2)) { + return pt; + } + } + } } throw new PythonException("Unable to find converter for python object of type " + pyTypeStr); @@ -212,12 +236,49 @@ public class PythonTypes { public static final PythonType LIST = new PythonType("list", List.class) { + @Override + public boolean accepts(Object javaObject) { + return (javaObject instanceof List || javaObject.getClass().isArray()); + } + @Override public List adapt(Object javaObject) { if (javaObject instanceof List) { return (List) javaObject; - } else if (javaObject instanceof Object[]) { - return Arrays.asList((Object[]) javaObject); + } else if (javaObject.getClass().isArray()) { + List ret = new ArrayList<>(); + if (javaObject instanceof Object[]) { + Object[] arr = (Object[]) javaObject; + return new ArrayList<>(Arrays.asList(arr)); + } else if (javaObject instanceof short[]) { + short[] arr = (short[]) javaObject; + for (short x : arr) ret.add(x); + return ret; + } else if (javaObject instanceof int[]) { + int[] arr = (int[]) javaObject; + for (int x : arr) ret.add(x); + return ret; + } else if (javaObject instanceof long[]) { + long[] arr = (long[]) javaObject; + for (long x : arr) ret.add(x); + return ret; + } else if (javaObject instanceof float[]) { + float[] arr = (float[]) javaObject; + for (float x : arr) ret.add(x); + return ret; + } else if (javaObject instanceof double[]) { + double[] arr = (double[]) javaObject; + for (double x : arr) ret.add(x); + return ret; + } else if (javaObject instanceof boolean[]) { + boolean[] arr = (boolean[]) javaObject; + for (boolean x : arr) ret.add(x); + return ret; + } else { + throw new PythonException("Unsupported array type: " + javaObject.getClass().toString()); + } + + } else { throw new PythonException("Cannot cast object of type " + javaObject.getClass().getName() + " to List"); } @@ -327,7 +388,13 @@ public class PythonTypes { } Object v = javaObject.get(k); PythonObject pyVal; - pyVal = PythonTypes.convert(v); + if (v instanceof PythonObject) { + pyVal = (PythonObject) v; + } else if (v instanceof PyObject) { + pyVal = new PythonObject((PyObject) v); + } else { + pyVal = PythonTypes.convert(v); + } int errCode = PyDict_SetItem(pyDict, pyKey.getNativePythonObject(), pyVal.getNativePythonObject()); if (errCode != 0) { String keyStr = pyKey.toString(); @@ -341,4 +408,85 @@ public class PythonTypes { return new PythonObject(pyDict); } }; + + + public static final PythonType MEMORYVIEW = new PythonType("memoryview", BytePointer.class) { + @Override + public BytePointer toJava(PythonObject pythonObject) { + try (PythonGC gc = PythonGC.watch()) { + if (!(Python.isinstance(pythonObject, Python.memoryviewType()))) { + throw new PythonException("Expected memoryview. Received: " + pythonObject); + } + PythonObject pySize = Python.len(pythonObject); + PythonObject ctypes = Python.importModule("ctypes"); + PythonObject charType = ctypes.attr("c_char"); + PythonObject charArrayType = new PythonObject(PyNumber_Multiply(charType.getNativePythonObject(), + pySize.getNativePythonObject())); + PythonObject fromBuffer = charArrayType.attr("from_buffer"); + if (pythonObject.attr("readonly").toBoolean()) { + pythonObject = Python.bytearray(pythonObject); + } + PythonObject arr = fromBuffer.call(pythonObject); + PythonObject cast = ctypes.attr("cast"); + PythonObject voidPtrType = ctypes.attr("c_void_p"); + PythonObject voidPtr = cast.call(arr, voidPtrType); + long address = voidPtr.attr("value").toLong(); + long size = pySize.toLong(); + try { + Field addressField = Buffer.class.getDeclaredField("address"); + addressField.setAccessible(true); + Field capacityField = Buffer.class.getDeclaredField("capacity"); + capacityField.setAccessible(true); + ByteBuffer buff = ByteBuffer.allocateDirect(0).order(ByteOrder.nativeOrder()); + addressField.setLong(buff, address); + capacityField.setInt(buff, (int) size); + BytePointer ret = new BytePointer(buff); + ret.limit(size); + return ret; + + } catch (Exception e) { + throw new RuntimeException(e); + } + + } + } + + @Override + public PythonObject toPython(BytePointer javaObject) { + long address = javaObject.address(); + long size = javaObject.limit(); + try (PythonGC gc = PythonGC.watch()) { + PythonObject ctypes = Python.importModule("ctypes"); + PythonObject charType = ctypes.attr("c_char"); + PythonObject pySize = new PythonObject(size); + PythonObject charArrayType = new PythonObject(PyNumber_Multiply(charType.getNativePythonObject(), + pySize.getNativePythonObject())); + PythonObject fromAddress = charArrayType.attr("from_address"); + PythonObject arr = fromAddress.call(new PythonObject(address)); + PythonObject memoryView = Python.memoryview(arr).attr("cast").call("b"); + PythonGC.keep(memoryView); + return memoryView; + } + + } + + @Override + public boolean accepts(Object javaObject) { + return javaObject instanceof Pointer || javaObject instanceof DirectBuffer; + } + + @Override + public BytePointer adapt(Object javaObject) { + if (javaObject instanceof BytePointer) { + return (BytePointer) javaObject; + } else if (javaObject instanceof Pointer) { + return new BytePointer((Pointer) javaObject); + } else if (javaObject instanceof DirectBuffer) { + return new BytePointer((ByteBuffer) javaObject); + } else { + throw new PythonException("Cannot cast object of type " + javaObject.getClass().getName() + " to BytePointer"); + } + } + }; + } diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonVariables.java b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonVariables.java new file mode 100644 index 000000000..32ae0b2f5 --- /dev/null +++ b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonVariables.java @@ -0,0 +1,47 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.eclipse.python4j; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * Some syntax sugar for lookup by name + */ +public class PythonVariables extends ArrayList { + public PythonVariable get(String variableName) { + for (PythonVariable pyVar: this){ + if (pyVar.getName().equals(variableName)){ + return pyVar; + } + } + return null; + } + + public boolean add(String variableName, PythonType variableType, Object value){ + return this.add(new PythonVariable<>(variableName, variableType, value)); + } + + public PythonVariables(PythonVariable... variables){ + this(Arrays.asList(variables)); + } + public PythonVariables(List list){ + super(); + addAll(list); + } +} diff --git a/python4j/python4j-core/src/test/java/PythonBufferTest.java b/python4j/python4j-core/src/test/java/PythonBufferTest.java new file mode 100644 index 000000000..c59b86c15 --- /dev/null +++ b/python4j/python4j-core/src/test/java/PythonBufferTest.java @@ -0,0 +1,113 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +import org.bytedeco.javacpp.BytePointer; +import org.bytedeco.javacpp.Loader; +import org.eclipse.python4j.*; +import org.junit.Assert; +import org.junit.Test; +import sun.nio.ch.DirectBuffer; + +import javax.annotation.concurrent.NotThreadSafe; +import java.nio.Buffer; +import java.nio.ByteBuffer; +import java.util.*; + +@NotThreadSafe +public class PythonBufferTest { + + @Test + public void testBuffer() { + ByteBuffer buff = ByteBuffer.allocateDirect(3); + buff.put((byte) 97); + buff.put((byte) 98); + buff.put((byte) 99); + buff.rewind(); + + BytePointer bp = new BytePointer(buff); + + List inputs = new ArrayList<>(); + inputs.add(new PythonVariable<>("buff", PythonTypes.MEMORYVIEW, buff)); + + List outputs = new ArrayList<>(); + outputs.add(new PythonVariable<>("s1", PythonTypes.STR)); + outputs.add(new PythonVariable<>("s2", PythonTypes.STR)); + + String code = "s1 = ''.join(chr(c) for c in buff)\nbuff[2] += 2\ns2 = ''.join(chr(c) for c in buff)"; + + PythonExecutioner.exec(code, inputs, outputs); + Assert.assertEquals("abc", outputs.get(0).getValue()); + Assert.assertEquals("abe", outputs.get(1).getValue()); + Assert.assertEquals(101, buff.get(2)); + + } + @Test + public void testBuffer2() { + ByteBuffer buff = ByteBuffer.allocateDirect(3); + buff.put((byte) 97); + buff.put((byte) 98); + buff.put((byte) 99); + buff.rewind(); + + BytePointer bp = new BytePointer(buff); + + List inputs = new ArrayList<>(); + inputs.add(new PythonVariable<>("buff", PythonTypes.MEMORYVIEW, bp)); + + List outputs = new ArrayList<>(); + outputs.add(new PythonVariable<>("s1", PythonTypes.STR)); + outputs.add(new PythonVariable<>("s2", PythonTypes.STR)); + + String code = "s1 = ''.join(chr(c) for c in buff)\nbuff[2] += 2\ns2 = ''.join(chr(c) for c in buff)"; + + PythonExecutioner.exec(code, inputs, outputs); + Assert.assertEquals("abc", outputs.get(0).getValue()); + Assert.assertEquals("abe", outputs.get(1).getValue()); + Assert.assertEquals(101, buff.get(2)); + + } + + @Test + public void testBuffer3() { + ByteBuffer buff = ByteBuffer.allocateDirect(3); + buff.put((byte) 97); + buff.put((byte) 98); + buff.put((byte) 99); + buff.rewind(); + + BytePointer bp = new BytePointer(buff); + + List inputs = new ArrayList<>(); + inputs.add(new PythonVariable<>("buff", PythonTypes.MEMORYVIEW, bp)); + + List outputs = new ArrayList<>(); + outputs.add(new PythonVariable<>("s1", PythonTypes.STR)); + outputs.add(new PythonVariable<>("s2", PythonTypes.STR)); + outputs.add(new PythonVariable<>("buff2", PythonTypes.MEMORYVIEW)); + String code = "s1 = ''.join(chr(c) for c in buff)\nbuff[2] += 2\ns2 = ''.join(chr(c) for c in buff)\nbuff2=buff[1:]"; + PythonExecutioner.exec(code, inputs, outputs); + + Assert.assertEquals("abc", outputs.get(0).getValue()); + Assert.assertEquals("abe", outputs.get(1).getValue()); + Assert.assertEquals(101, buff.get(2)); + BytePointer outBuffer = (BytePointer) outputs.get(2).getValue(); + Assert.assertEquals(2, outBuffer.capacity()); + Assert.assertEquals((byte)98, outBuffer.get(0)); + Assert.assertEquals((byte)101, outBuffer.get(1)); + + } +} \ No newline at end of file diff --git a/python4j/python4j-core/src/test/java/PythonGCTest.java b/python4j/python4j-core/src/test/java/PythonGCTest.java index f8c6ecba5..80b2e7f3c 100644 --- a/python4j/python4j-core/src/test/java/PythonGCTest.java +++ b/python4j/python4j-core/src/test/java/PythonGCTest.java @@ -49,6 +49,6 @@ public class PythonGCTest { PythonObject pyObjCount3 = Python.len(getObjects.call()); long objCount3 = pyObjCount3.toLong(); diff = objCount3 - objCount2; - Assert.assertEquals(2, diff);// 2 objects created during function call + Assert.assertTrue(diff <= 2);// 2 objects created during function call } } diff --git a/python4j/python4j-core/src/test/java/PythonJobTest.java b/python4j/python4j-core/src/test/java/PythonJobTest.java index 016045a25..b0f4233c9 100644 --- a/python4j/python4j-core/src/test/java/PythonJobTest.java +++ b/python4j/python4j-core/src/test/java/PythonJobTest.java @@ -30,7 +30,7 @@ import static org.junit.Assert.assertEquals; public class PythonJobTest { @Test - public void testPythonJobBasic() throws Exception{ + public void testPythonJobBasic(){ PythonContextManager.deleteNonMainContexts(); String code = "c = a + b"; @@ -65,7 +65,7 @@ public class PythonJobTest { } @Test - public void testPythonJobReturnAllVariables()throws Exception{ + public void testPythonJobReturnAllVariables(){ PythonContextManager.deleteNonMainContexts(); String code = "c = a + b"; @@ -101,7 +101,7 @@ public class PythonJobTest { @Test - public void testMultiplePythonJobsParallel()throws Exception{ + public void testMultiplePythonJobsParallel(){ PythonContextManager.deleteNonMainContexts(); String code1 = "c = a + b"; PythonJob job1 = new PythonJob("job1", code1, false); @@ -150,7 +150,7 @@ public class PythonJobTest { @Test - public void testPythonJobSetupRun()throws Exception{ + public void testPythonJobSetupRun(){ PythonContextManager.deleteNonMainContexts(); String code = "five=None\n" + @@ -189,7 +189,7 @@ public class PythonJobTest { } @Test - public void testPythonJobSetupRunAndReturnAllVariables()throws Exception{ + public void testPythonJobSetupRunAndReturnAllVariables(){ PythonContextManager.deleteNonMainContexts(); String code = "five=None\n" + "c=None\n"+ @@ -225,7 +225,7 @@ public class PythonJobTest { } @Test - public void testMultiplePythonJobsSetupRunParallel()throws Exception{ + public void testMultiplePythonJobsSetupRunParallel(){ PythonContextManager.deleteNonMainContexts(); String code1 = "five=None\n" + diff --git a/python4j/python4j-numpy/pom.xml b/python4j/python4j-numpy/pom.xml index 527a9343f..bcce739ce 100644 --- a/python4j/python4j-numpy/pom.xml +++ b/python4j/python4j-numpy/pom.xml @@ -28,15 +28,50 @@ ${nd4j.version} test + + org.eclipse + python4j-core + 1.0.0-SNAPSHOT + test-nd4j-native + + + org.nd4j + nd4j-native + ${nd4j.version} + test + + + org.deeplearning4j + dl4j-test-resources + ${nd4j.version} + test + + + test-nd4j-cuda-10.2 + + + org.nd4j + nd4j-cuda-10.1 + ${nd4j.version} + test + + + org.deeplearning4j + dl4j-test-resources + ${nd4j.version} + test + + + \ No newline at end of file diff --git a/python4j/python4j-numpy/src/main/java/org/eclipse/python4j/NumpyArray.java b/python4j/python4j-numpy/src/main/java/org/eclipse/python4j/NumpyArray.java new file mode 100644 index 000000000..66fb76d23 --- /dev/null +++ b/python4j/python4j-numpy/src/main/java/org/eclipse/python4j/NumpyArray.java @@ -0,0 +1,303 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +package org.eclipse.python4j; + +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.lang3.ArrayUtils; +import org.bytedeco.cpython.PyObject; +import org.bytedeco.cpython.PyTypeObject; +import org.bytedeco.javacpp.Pointer; +import org.bytedeco.javacpp.SizeTPointer; +import org.bytedeco.numpy.PyArrayObject; +import org.bytedeco.numpy.global.numpy; +import org.nd4j.linalg.api.buffer.BaseDataBuffer; +import org.nd4j.linalg.api.buffer.DataBuffer; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.concurrency.AffinityManager; +import org.nd4j.linalg.api.memory.MemoryWorkspace; +import org.nd4j.linalg.api.memory.MemoryWorkspaceManager; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.shape.Shape; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.nativeblas.NativeOps; +import org.nd4j.nativeblas.NativeOpsHolder; + +import java.io.File; +import java.util.*; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.bytedeco.cpython.global.python.*; +import static org.bytedeco.cpython.global.python.Py_DecRef; +import static org.bytedeco.numpy.global.numpy.*; +import static org.bytedeco.numpy.global.numpy.NPY_ARRAY_CARRAY; +import static org.bytedeco.numpy.global.numpy.PyArray_Type; + +@Slf4j +public class NumpyArray extends PythonType { + + public static final NumpyArray INSTANCE; + private static final AtomicBoolean init = new AtomicBoolean(false); + private static final Map cache = new HashMap<>(); + + static { + new PythonExecutioner(); + INSTANCE = new NumpyArray(); + } + + @Override + public File[] packages(){ + try{ + return new File[]{numpy.cachePackage()}; + }catch(Exception e){ + throw new PythonException(e); + } + + } + + public synchronized void init() { + if (init.get()) return; + init.set(true); + if (PythonGIL.locked()) { + throw new PythonException("Can not initialize numpy - GIL already acquired."); + } + int err = numpy._import_array(); + if (err < 0){ + System.out.println("Numpy import failed!"); + throw new PythonException("Numpy import failed!"); + } + } + + public NumpyArray() { + super("numpy.ndarray", INDArray.class); + + } + + @Override + public INDArray toJava(PythonObject pythonObject) { + log.info("Converting PythonObject to INDArray..."); + PyObject np = PyImport_ImportModule("numpy"); + PyObject ndarray = PyObject_GetAttrString(np, "ndarray"); + if (PyObject_IsInstance(pythonObject.getNativePythonObject(), ndarray) != 1) { + Py_DecRef(ndarray); + Py_DecRef(np); + throw new PythonException("Object is not a numpy array! Use Python.ndarray() to convert object to a numpy array."); + } + Py_DecRef(ndarray); + Py_DecRef(np); + PyArrayObject npArr = new PyArrayObject(pythonObject.getNativePythonObject()); + long[] shape = new long[PyArray_NDIM(npArr)]; + SizeTPointer shapePtr = PyArray_SHAPE(npArr); + if (shapePtr != null) + shapePtr.get(shape, 0, shape.length); + long[] strides = new long[shape.length]; + SizeTPointer stridesPtr = PyArray_STRIDES(npArr); + if (stridesPtr != null) + stridesPtr.get(strides, 0, strides.length); + int npdtype = PyArray_TYPE(npArr); + + DataType dtype; + switch (npdtype) { + case NPY_DOUBLE: + dtype = DataType.DOUBLE; + break; + case NPY_FLOAT: + dtype = DataType.FLOAT; + break; + case NPY_SHORT: + dtype = DataType.SHORT; + break; + case NPY_INT: + dtype = DataType.INT32; + break; + case NPY_LONG: + dtype = DataType.INT64; + break; + case NPY_UINT: + dtype = DataType.UINT32; + break; + case NPY_BYTE: + dtype = DataType.INT8; + break; + case NPY_UBYTE: + dtype = DataType.UINT8; + break; + case NPY_BOOL: + dtype = DataType.BOOL; + break; + case NPY_HALF: + dtype = DataType.FLOAT16; + break; + case NPY_LONGLONG: + dtype = DataType.INT64; + break; + case NPY_USHORT: + dtype = DataType.UINT16; + break; + case NPY_ULONG: + case NPY_ULONGLONG: + dtype = DataType.UINT64; + break; + default: + throw new PythonException("Unsupported array data type: " + npdtype); + } + long size = 1; + for (int i = 0; i < shape.length; size *= shape[i++]) ; + + INDArray ret; + long address = PyArray_DATA(npArr).address(); + String key = address + "_" + size + "_" + dtype; + DataBuffer buff = cache.get(key); + if (buff == null) { + try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { + Pointer ptr = NativeOpsHolder.getInstance().getDeviceNativeOps().pointerForAddress(address); + ptr = ptr.limit(size); + ptr = ptr.capacity(size); + buff = Nd4j.createBuffer(ptr, size, dtype); + cache.put(key, buff); + } + } + int elemSize = buff.getElementSize(); + long[] nd4jStrides = new long[strides.length]; + for (int i = 0; i < strides.length; i++) { + nd4jStrides[i] = strides[i] / elemSize; + } + ret = Nd4j.create(buff, shape, nd4jStrides, 0, Shape.getOrder(shape, nd4jStrides, 1), dtype); + Nd4j.getAffinityManager().tagLocation(ret, AffinityManager.Location.HOST); + log.info("Done."); + return ret; + + + } + + @Override + public PythonObject toPython(INDArray indArray) { + log.info("Converting INDArray to PythonObject..."); + DataType dataType = indArray.dataType(); + DataBuffer buff = indArray.data(); + String key = buff.pointer().address() + "_" + buff.length() + "_" + dataType; + cache.put(key, buff); + int numpyType; + String ctype; + switch (dataType) { + case DOUBLE: + numpyType = NPY_DOUBLE; + ctype = "c_double"; + break; + case FLOAT: + case BFLOAT16: + numpyType = NPY_FLOAT; + ctype = "c_float"; + break; + case SHORT: + numpyType = NPY_SHORT; + ctype = "c_short"; + break; + case INT: + numpyType = NPY_INT; + ctype = "c_int"; + break; + case LONG: + numpyType = NPY_INT64; + ctype = "c_int64"; + break; + case UINT16: + numpyType = NPY_USHORT; + ctype = "c_uint16"; + break; + case UINT32: + numpyType = NPY_UINT; + ctype = "c_uint"; + break; + case UINT64: + numpyType = NPY_UINT64; + ctype = "c_uint64"; + break; + case BOOL: + numpyType = NPY_BOOL; + ctype = "c_bool"; + break; + case BYTE: + numpyType = NPY_BYTE; + ctype = "c_byte"; + break; + case UBYTE: + numpyType = NPY_UBYTE; + ctype = "c_ubyte"; + break; + case HALF: + numpyType = NPY_HALF; + ctype = "c_short"; + break; + default: + throw new RuntimeException("Unsupported dtype: " + dataType); + } + + long[] shape = indArray.shape(); + INDArray inputArray = indArray; + if (dataType == DataType.BFLOAT16) { + log.warn("Creating copy of array as bfloat16 is not supported by numpy."); + inputArray = indArray.castTo(DataType.FLOAT); + } + + //Sync to host memory in the case of CUDA, before passing the host memory pointer to Python + + Nd4j.getAffinityManager().ensureLocation(inputArray, AffinityManager.Location.HOST); + + // PyArray_Type() call causes jvm crash in linux cpu if GIL is acquired by non main thread. + // Using Interpreter for now: + +// try(PythonContextManager.Context context = new PythonContextManager.Context("__np_array_converter")){ +// log.info("Stringing exec..."); +// String code = "import ctypes\nimport numpy as np\n" + +// "cArr = (ctypes." + ctype + "*" + indArray.length() + ")"+ +// ".from_address(" + indArray.data().pointer().address() + ")\n"+ +// "npArr = np.frombuffer(cArr, dtype=" + ((numpyType == NPY_HALF) ? "'half'" : "ctypes." + ctype)+ +// ").reshape(" + Arrays.toString(indArray.shape()) + ")"; +// PythonExecutioner.exec(code); +// log.info("exec done."); +// PythonObject ret = PythonExecutioner.getVariable("npArr"); +// Py_IncRef(ret.getNativePythonObject()); +// return ret; +// +// } + log.info("NUMPY: PyArray_Type()"); + PyTypeObject pyTypeObject = PyArray_Type(); + + + log.info("NUMPY: PyArray_New()"); + PyObject npArr = PyArray_New(pyTypeObject, shape.length, new SizeTPointer(shape), + numpyType, null, + inputArray.data().addressPointer(), + 0, NPY_ARRAY_CARRAY, null); + log.info("Done."); + return new PythonObject(npArr); + } + + @Override + public boolean accepts(Object javaObject) { + return javaObject instanceof INDArray; + } + + @Override + public INDArray adapt(Object javaObject) { + if (javaObject instanceof INDArray) { + return (INDArray) javaObject; + } + throw new PythonException("Cannot cast object of type " + javaObject.getClass().getName() + " to INDArray"); + } +} diff --git a/python4j/python4j-numpy/src/main/resources/META-INF/services/org.eclipse.python4j.PythonType b/python4j/python4j-numpy/src/main/resources/META-INF/services/org.eclipse.python4j.PythonType new file mode 100644 index 000000000..ae4d4640b --- /dev/null +++ b/python4j/python4j-numpy/src/main/resources/META-INF/services/org.eclipse.python4j.PythonType @@ -0,0 +1 @@ +org.eclipse.python4j.NumpyArray \ No newline at end of file diff --git a/python4j/python4j-numpy/src/test/java/PythonNumpyBasicTest.java b/python4j/python4j-numpy/src/test/java/PythonNumpyBasicTest.java new file mode 100644 index 000000000..b7bd838b5 --- /dev/null +++ b/python4j/python4j-numpy/src/test/java/PythonNumpyBasicTest.java @@ -0,0 +1,170 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +import org.eclipse.python4j.*; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.concurrency.AffinityManager; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.nativeblas.OpaqueDataBuffer; + +import javax.annotation.concurrent.NotThreadSafe; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; + +@NotThreadSafe +@RunWith(Parameterized.class) +public class PythonNumpyBasicTest { + private DataType dataType; + private long[] shape; + + public PythonNumpyBasicTest(DataType dataType, long[] shape, String dummyArg) { + this.dataType = dataType; + this.shape = shape; + } + + @Parameterized.Parameters(name = "{index}: Testing with DataType={0}, shape={2}") + public static Collection params() { + DataType[] types = new DataType[] { + DataType.BOOL, + DataType.FLOAT16, + DataType.BFLOAT16, + DataType.FLOAT, + DataType.DOUBLE, + DataType.INT8, + DataType.INT16, + DataType.INT32, + DataType.INT64, + DataType.UINT8, + DataType.UINT16, + DataType.UINT32, + DataType.UINT64 + }; + + long[][] shapes = new long[][]{ + new long[]{2, 3}, + new long[]{3}, + new long[]{1}, + new long[]{} // scalar + }; + + + List ret = new ArrayList<>(); + for (DataType type: types){ + for (long[] shape: shapes){ + ret.add(new Object[]{type, shape, Arrays.toString(shape)}); + } + } + return ret; + } + + @Test + public void testConversion(){ + INDArray arr = Nd4j.zeros(dataType, shape); + PythonObject npArr = PythonTypes.convert(arr); + INDArray arr2 = PythonTypes.getPythonTypeForPythonObject(npArr).toJava(npArr); + if (dataType == DataType.BFLOAT16){ + arr = arr.castTo(DataType.FLOAT); + } + Assert.assertEquals(arr,arr2); + } + + + @Test + public void testExecution(){ + List inputs = new ArrayList<>(); + INDArray x = Nd4j.ones(dataType, shape); + INDArray y = Nd4j.zeros(dataType, shape); + INDArray z = (dataType == DataType.BOOL)?x:x.mul(y.add(2)); + z = (dataType == DataType.BFLOAT16)? z.castTo(DataType.FLOAT): z; + PythonType arrType = PythonTypes.get("numpy.ndarray"); + inputs.add(new PythonVariable<>("x", arrType, x)); + inputs.add(new PythonVariable<>("y", arrType, y)); + List outputs = new ArrayList<>(); + PythonVariable output = new PythonVariable<>("z", arrType); + outputs.add(output); + String code = (dataType == DataType.BOOL)?"z = x":"z = x * (y + 2)"; + if (shape.length == 0){ // scalar special case + code += "\nimport numpy as np\nz = np.asarray(float(z), dtype=x.dtype)"; + } + PythonExecutioner.exec(code, inputs, outputs); + INDArray z2 = output.getValue(); + + Assert.assertEquals(z.dataType(), z2.dataType()); + Assert.assertEquals(z, z2); + + } + + + @Test + public void testInplaceExecution(){ + if (dataType == DataType.BOOL || dataType == DataType.BFLOAT16)return; + if (shape.length == 0) return; + List inputs = new ArrayList<>(); + INDArray x = Nd4j.ones(dataType, shape); + INDArray y = Nd4j.zeros(dataType, shape); + INDArray z = x.mul(y.add(2)); + // Nd4j.getAffinityManager().ensureLocation(z, AffinityManager.Location.HOST); + PythonType arrType = PythonTypes.get("numpy.ndarray"); + inputs.add(new PythonVariable<>("x", arrType, x)); + inputs.add(new PythonVariable<>("y", arrType, y)); + List outputs = new ArrayList<>(); + PythonVariable output = new PythonVariable<>("x", arrType); + outputs.add(output); + String code = "x *= y + 2"; + PythonExecutioner.exec(code, inputs, outputs); + INDArray z2 = output.getValue(); + Assert.assertEquals(x.dataType(), z2.dataType()); + Assert.assertEquals(z.dataType(), z2.dataType()); + Assert.assertEquals(x, z2); + Assert.assertEquals(z, z2); + Assert.assertEquals(x.data().pointer().address(), z2.data().pointer().address()); + if("CUDA".equalsIgnoreCase(Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"))){ + Assert.assertEquals(getDeviceAddress(x), getDeviceAddress(z2)); + } + + + } + private static long getDeviceAddress(INDArray array){ + if(!"CUDA".equalsIgnoreCase(Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"))){ + throw new IllegalStateException("Cannot ge device pointer for non-CUDA device"); + } + + //Use reflection here as OpaqueDataBuffer is only available on BaseCudaDataBuffer and BaseCpuDataBuffer - not DataBuffer/BaseDataBuffer + // due to it being defined in nd4j-native-api, not nd4j-api + try { + Class c = Class.forName("org.nd4j.linalg.jcublas.buffer.BaseCudaDataBuffer"); + Method m = c.getMethod("getOpaqueDataBuffer"); + OpaqueDataBuffer db = (OpaqueDataBuffer) m.invoke(array.data()); + long address = db.specialBuffer().address(); + return address; + } catch (Throwable t){ + throw new RuntimeException("Error getting OpaqueDataBuffer", t); + } + } + + + + +} diff --git a/python4j/python4j-numpy/src/test/java/PythonNumpyCollectionsTest.java b/python4j/python4j-numpy/src/test/java/PythonNumpyCollectionsTest.java new file mode 100644 index 000000000..99a050f63 --- /dev/null +++ b/python4j/python4j-numpy/src/test/java/PythonNumpyCollectionsTest.java @@ -0,0 +1,96 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +import org.eclipse.python4j.PythonException; +import org.eclipse.python4j.PythonObject; +import org.eclipse.python4j.PythonTypes; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.factory.Nd4j; + +import javax.annotation.concurrent.NotThreadSafe; +import java.util.*; + + +@NotThreadSafe +@RunWith(Parameterized.class) +public class PythonNumpyCollectionsTest { + private DataType dataType; + + public PythonNumpyCollectionsTest(DataType dataType){ + this.dataType = dataType; + } + + @Parameterized.Parameters(name = "{index}: Testing with DataType={0}") + public static DataType[] params() { + return new DataType[]{ + DataType.BOOL, + DataType.FLOAT16, + //DataType.BFLOAT16, + DataType.FLOAT, + DataType.DOUBLE, + DataType.INT8, + DataType.INT16, + DataType.INT32, + DataType.INT64, + DataType.UINT8, + DataType.UINT16, + DataType.UINT32, + DataType.UINT64 + }; + } + @Test + public void testPythonDictFromMap() throws PythonException { + Map map = new HashMap(); + map.put("a", 1); + map.put(1, "a"); + map.put("arr", Nd4j.ones(dataType, 2, 3)); + map.put("list1", Arrays.asList(1, 2.0, 3, 4f, Nd4j.zeros(dataType,3,2))); + Map innerMap = new HashMap(); + innerMap.put("b", 2); + innerMap.put(2, "b"); + innerMap.put(5, Nd4j.ones(dataType, 5)); + map.put("innermap", innerMap); + map.put("list2", Arrays.asList(4, "5", innerMap, false, true)); + PythonObject dict = PythonTypes.convert(map); + Map map2 = PythonTypes.DICT.toJava(dict); + Assert.assertEquals(map.toString(), map2.toString()); + } + + @Test + public void testPythonListFromList() throws PythonException{ + List list = new ArrayList<>(); + list.add(1); + list.add("2"); + list.add(Nd4j.ones(dataType, 2, 3)); + list.add(Arrays.asList("a", + Nd4j.ones(dataType, 1, 2),1.0, 2f, 10, true, false, + Nd4j.zeros(dataType, 3, 2))); + Map map = new HashMap(); + map.put("a", 1); + map.put(1, "a"); + map.put(5, Nd4j.ones(dataType,4, 5)); + map.put("list1", Arrays.asList(1, 2.0, 3, 4f, Nd4j.zeros(dataType, 3, 1))); + list.add(map); + PythonObject dict = PythonTypes.convert(list); + List list2 = PythonTypes.LIST.toJava(dict); + Assert.assertEquals(list.toString(), list2.toString()); + } +} diff --git a/python4j/python4j-numpy/src/test/java/PythonNumpyGCTest.java b/python4j/python4j-numpy/src/test/java/PythonNumpyGCTest.java new file mode 100644 index 000000000..d1c5ba761 --- /dev/null +++ b/python4j/python4j-numpy/src/test/java/PythonNumpyGCTest.java @@ -0,0 +1,55 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +import org.eclipse.python4j.Python; +import org.eclipse.python4j.PythonGC; +import org.eclipse.python4j.PythonObject; +import org.junit.Assert; +import org.junit.Test; +import org.nd4j.linalg.factory.Nd4j; + +import javax.annotation.concurrent.NotThreadSafe; + + +@NotThreadSafe +public class PythonNumpyGCTest { + + @Test + public void testGC(){ + PythonObject gcModule = Python.importModule("gc"); + PythonObject getObjects = gcModule.attr("get_objects"); + PythonObject pyObjCount1 = Python.len(getObjects.call()); + long objCount1 = pyObjCount1.toLong(); + PythonObject pyList = Python.list(); + pyList.attr("append").call(new PythonObject(Nd4j.linspace(1, 10, 10))); + pyList.attr("append").call(1.0); + pyList.attr("append").call(true); + PythonObject pyObjCount2 = Python.len(getObjects.call()); + long objCount2 = pyObjCount2.toLong(); + long diff = objCount2 - objCount1; + Assert.assertTrue(diff > 2); + try(PythonGC gc = PythonGC.watch()){ + PythonObject pyList2 = Python.list(); + pyList2.attr("append").call(new PythonObject(Nd4j.linspace(1, 10, 10))); + pyList2.attr("append").call(1.0); + pyList2.attr("append").call(true); + } + PythonObject pyObjCount3 = Python.len(getObjects.call()); + long objCount3 = pyObjCount3.toLong(); + diff = objCount3 - objCount2; + Assert.assertTrue(diff <= 2);// 2 objects created during function call + } +} diff --git a/python4j/python4j-numpy/src/test/java/PythonNumpyImportTest.java b/python4j/python4j-numpy/src/test/java/PythonNumpyImportTest.java new file mode 100644 index 000000000..580f8643b --- /dev/null +++ b/python4j/python4j-numpy/src/test/java/PythonNumpyImportTest.java @@ -0,0 +1,22 @@ +import org.eclipse.python4j.NumpyArray; +import org.eclipse.python4j.Python; +import org.eclipse.python4j.PythonGC; +import org.eclipse.python4j.PythonObject; +import org.junit.Assert; +import org.junit.Test; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +public class PythonNumpyImportTest { + + @Test + public void testNumpyImport(){ + try(PythonGC gc = PythonGC.watch()){ + PythonObject np = Python.importModule("numpy"); + PythonObject zeros = np.attr("zeros").call(5); + INDArray arr = NumpyArray.INSTANCE.toJava(zeros); + Assert.assertEquals(arr, Nd4j.zeros(DataType.DOUBLE, 5)); + } + } +} diff --git a/python4j/python4j-numpy/src/test/java/PythonNumpyJobTest.java b/python4j/python4j-numpy/src/test/java/PythonNumpyJobTest.java new file mode 100644 index 000000000..399b87fb1 --- /dev/null +++ b/python4j/python4j-numpy/src/test/java/PythonNumpyJobTest.java @@ -0,0 +1,303 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +import org.eclipse.python4j.*; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.ArrayList; +import java.util.List; + +import static org.junit.Assert.assertEquals; + + +@javax.annotation.concurrent.NotThreadSafe +@RunWith(Parameterized.class) +public class PythonNumpyJobTest { + private DataType dataType; + + public PythonNumpyJobTest(DataType dataType){ + this.dataType = dataType; + } + + @Parameterized.Parameters(name = "{index}: Testing with DataType={0}") + public static DataType[] params() { + return new DataType[]{ + DataType.BOOL, + DataType.FLOAT16, + DataType.BFLOAT16, + DataType.FLOAT, + DataType.DOUBLE, + DataType.INT8, + DataType.INT16, + DataType.INT32, + DataType.INT64, + DataType.UINT8, + DataType.UINT16, + DataType.UINT32, + DataType.UINT64 + }; + } + + @Test + public void testNumpyJobBasic(){ + PythonContextManager.deleteNonMainContexts(); + List inputs = new ArrayList<>(); + INDArray x = Nd4j.ones(dataType, 2, 3); + INDArray y = Nd4j.zeros(dataType, 2, 3); + INDArray z = (dataType == DataType.BOOL)?x:x.mul(y.add(2)); + z = (dataType == DataType.BFLOAT16)? z.castTo(DataType.FLOAT): z; + PythonType arrType = PythonTypes.get("numpy.ndarray"); + inputs.add(new PythonVariable<>("x", arrType, x)); + inputs.add(new PythonVariable<>("y", arrType, y)); + List outputs = new ArrayList<>(); + PythonVariable output = new PythonVariable<>("z", arrType); + outputs.add(output); + String code = (dataType == DataType.BOOL)?"z = x":"z = x * (y + 2)"; + + PythonJob job = new PythonJob("job1", code, false); + + job.exec(inputs, outputs); + + INDArray z2 = output.getValue(); + + if (dataType == DataType.BFLOAT16){ + z2 = z2.castTo(DataType.FLOAT); + } + + Assert.assertEquals(z, z2); + + } + + @Test + public void testNumpyJobReturnAllVariables(){ + PythonContextManager.deleteNonMainContexts(); + List inputs = new ArrayList<>(); + INDArray x = Nd4j.ones(dataType, 2, 3); + INDArray y = Nd4j.zeros(dataType, 2, 3); + INDArray z = (dataType == DataType.BOOL)?x:x.mul(y.add(2)); + PythonType arrType = PythonTypes.get("numpy.ndarray"); + inputs.add(new PythonVariable<>("x", arrType, x)); + inputs.add(new PythonVariable<>("y", arrType, y)); + String code = (dataType == DataType.BOOL)?"z = x":"z = x * (y + 2)"; + + PythonJob job = new PythonJob("job1", code, false); + List outputs = job.execAndReturnAllVariables(inputs); + + INDArray x2 = (INDArray) outputs.get(0).getValue(); + INDArray y2 = (INDArray) outputs.get(1).getValue(); + INDArray z2 = (INDArray) outputs.get(2).getValue(); + + if (dataType == DataType.BFLOAT16){ + x = x.castTo(DataType.FLOAT); + y = y.castTo(DataType.FLOAT); + z = z.castTo(DataType.FLOAT); + } + Assert.assertEquals(x, x2); + Assert.assertEquals(y, y2); + Assert.assertEquals(z, z2); + + } + + + @Test + public void testMultipleNumpyJobsParallel(){ + PythonContextManager.deleteNonMainContexts(); + String code1 =(dataType == DataType.BOOL)?"z = x":"z = x + y"; + PythonJob job1 = new PythonJob("job1", code1, false); + + String code2 =(dataType == DataType.BOOL)?"z = y":"z = x - y"; + PythonJob job2 = new PythonJob("job2", code2, false); + + List inputs = new ArrayList<>(); + INDArray x = Nd4j.ones(dataType, 2, 3); + INDArray y = Nd4j.zeros(dataType, 2, 3); + INDArray z1 = (dataType == DataType.BOOL)?x:x.add(y); + z1 = (dataType == DataType.BFLOAT16)? z1.castTo(DataType.FLOAT): z1; + INDArray z2 = (dataType == DataType.BOOL)?y:x.sub(y); + z2 = (dataType == DataType.BFLOAT16)? z2.castTo(DataType.FLOAT): z2; + PythonType arrType = PythonTypes.get("numpy.ndarray"); + inputs.add(new PythonVariable<>("x", arrType, x)); + inputs.add(new PythonVariable<>("y", arrType, y)); + + + List outputs = new ArrayList<>(); + + outputs.add(new PythonVariable<>("z", arrType)); + + job1.exec(inputs, outputs); + + assertEquals(z1, outputs.get(0).getValue()); + + + job2.exec(inputs, outputs); + + assertEquals(z2, outputs.get(0).getValue()); + + } + + + @Test + public synchronized void testNumpyJobSetupRun(){ + if (dataType == DataType.BOOL)return; + PythonContextManager.deleteNonMainContexts(); + String code = "five=None\n" + + "def setup():\n" + + " global five\n"+ + " five = 5\n\n" + + "def run(a, b):\n" + + " c = a + b + five\n"+ + " return {'c':c}\n\n"; + + PythonJob job = new PythonJob("job1", code, true); + + List inputs = new ArrayList<>(); + inputs.add(new PythonVariable<>("a", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(2))); + inputs.add(new PythonVariable<>("b", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3))); + + List outputs = new ArrayList<>(); + outputs.add(new PythonVariable<>("c", NumpyArray.INSTANCE)); + job.exec(inputs, outputs); + + + assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16)? DataType.FLOAT: dataType, 2, 3).mul(10), + outputs.get(0).getValue()); + + + inputs = new ArrayList<>(); + inputs.add(new PythonVariable<>("a", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3))); + inputs.add(new PythonVariable<>("b", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(4))); + + + outputs = new ArrayList<>(); + outputs.add(new PythonVariable<>("c", NumpyArray.INSTANCE)); + + job.exec(inputs, outputs); + + assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16)? DataType.FLOAT: dataType, 2, 3).mul(12), + outputs.get(0).getValue()); + + + } + @Test + public void testNumpyJobSetupRunAndReturnAllVariables(){ + if (dataType == DataType.BOOL)return; + PythonContextManager.deleteNonMainContexts(); + String code = "five=None\n" + + "c=None\n"+ + "def setup():\n" + + " global five\n"+ + " five = 5\n\n" + + "def run(a, b):\n" + + " global c\n" + + " c = a + b + five\n"; + PythonJob job = new PythonJob("job1", code, true); + + List inputs = new ArrayList<>(); + inputs.add(new PythonVariable<>("a", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(2))); + inputs.add(new PythonVariable<>("b", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3))); + + List outputs = job.execAndReturnAllVariables(inputs); + + assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16)? DataType.FLOAT: dataType, 2, 3).mul(10), + outputs.get(1).getValue()); + + + inputs = new ArrayList<>(); + inputs.add(new PythonVariable<>("a", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3))); + inputs.add(new PythonVariable<>("b", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(4))); + + + outputs = job.execAndReturnAllVariables(inputs); + + + assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16)? DataType.FLOAT: dataType, 2, 3).mul(12), + outputs.get(1).getValue()); + + + } + + @Test + public void testMultipleNumpyJobsSetupRunParallel(){ + if (dataType == DataType.BOOL)return; + PythonContextManager.deleteNonMainContexts(); + + String code1 = "five=None\n" + + "def setup():\n" + + " global five\n"+ + " five = 5\n\n" + + "def run(a, b):\n" + + " c = a + b + five\n"+ + " return {'c':c}\n\n"; + PythonJob job1 = new PythonJob("job1", code1, true); + + String code2 = "five=None\n" + + "def setup():\n" + + " global five\n"+ + " five = 5\n\n" + + "def run(a, b):\n" + + " c = a + b - five\n"+ + " return {'c':c}\n\n"; + PythonJob job2 = new PythonJob("job2", code2, true); + + List inputs = new ArrayList<>(); + inputs.add(new PythonVariable<>("a", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(2))); + inputs.add(new PythonVariable<>("b", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3))); + + + List outputs = new ArrayList<>(); + outputs.add(new PythonVariable<>("c", NumpyArray.INSTANCE)); + + job1.exec(inputs, outputs); + + assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16)? DataType.FLOAT: dataType, 2, 3).mul(10), + outputs.get(0).getValue()); + + + job2.exec(inputs, outputs); + + assertEquals(Nd4j.zeros((dataType == DataType.BFLOAT16)? DataType.FLOAT: dataType, 2, 3), + outputs.get(0).getValue()); + + + inputs = new ArrayList<>(); + inputs.add(new PythonVariable<>("a", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3))); + inputs.add(new PythonVariable<>("b", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(4))); + + outputs = new ArrayList<>(); + outputs.add(new PythonVariable<>("c", NumpyArray.INSTANCE)); + + + job1.exec(inputs, outputs); + + assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16)? DataType.FLOAT: dataType, 2, 3).mul(12), + outputs.get(0).getValue()); + + + job2.exec(inputs, outputs); + + assertEquals(Nd4j.ones((dataType == DataType.BFLOAT16)? DataType.FLOAT: dataType, 2, 3).mul(2), + outputs.get(0).getValue()); + + + } + +} diff --git a/python4j/python4j-numpy/src/test/java/PythonNumpyMultiThreadTest.java b/python4j/python4j-numpy/src/test/java/PythonNumpyMultiThreadTest.java new file mode 100644 index 000000000..52ccd1fd0 --- /dev/null +++ b/python4j/python4j-numpy/src/test/java/PythonNumpyMultiThreadTest.java @@ -0,0 +1,194 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +import org.eclipse.python4j.*; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import javax.annotation.concurrent.NotThreadSafe; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + + +@NotThreadSafe +@RunWith(Parameterized.class) +public class PythonNumpyMultiThreadTest { + private DataType dataType; + + public PythonNumpyMultiThreadTest(DataType dataType) { + this.dataType = dataType; + } + + @Parameterized.Parameters(name = "{index}: Testing with DataType={0}") + public static DataType[] params() { + return new DataType[]{ +// DataType.BOOL, +// DataType.FLOAT16, +// DataType.BFLOAT16, + DataType.FLOAT, + DataType.DOUBLE, +// DataType.INT8, +// DataType.INT16, + DataType.INT32, + DataType.INT64, +// DataType.UINT8, +// DataType.UINT16, +// DataType.UINT32, +// DataType.UINT64 + }; + } + + + @Test + public void testMultiThreading1() throws Throwable { + final List exceptions = Collections.synchronizedList(new ArrayList()); + Runnable runnable = new Runnable() { + @Override + public void run() { + try (PythonGIL gil = PythonGIL.lock()) { + try (PythonGC gc = PythonGC.watch()) { + List inputs = new ArrayList<>(); + inputs.add(new PythonVariable<>("x", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3))); + inputs.add(new PythonVariable<>("y", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(4))); + PythonVariable out = new PythonVariable<>("z", NumpyArray.INSTANCE); + String code = "z = x + y"; + PythonExecutioner.exec(code, inputs, Collections.singletonList(out)); + Assert.assertEquals(Nd4j.ones(dataType, 2, 3).mul(7), out.getValue()); + } + } catch (Throwable e) { + exceptions.add(e); + } + } + }; + + int numThreads = 10; + Thread[] threads = new Thread[numThreads]; + for (int i = 0; i < threads.length; i++) { + threads[i] = new Thread(runnable); + } + for (int i = 0; i < threads.length; i++) { + threads[i].start(); + } + Thread.sleep(100); + for (int i = 0; i < threads.length; i++) { + threads[i].join(); + } + if (!exceptions.isEmpty()) { + throw (exceptions.get(0)); + } + + } + + @Test + public void testMultiThreading2() throws Throwable { + final List exceptions = Collections.synchronizedList(new ArrayList()); + Runnable runnable = new Runnable() { + @Override + public void run() { + try (PythonGIL gil = PythonGIL.lock()) { + try (PythonGC gc = PythonGC.watch()) { + PythonContextManager.reset(); + List inputs = new ArrayList<>(); + inputs.add(new PythonVariable<>("x", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(3))); + inputs.add(new PythonVariable<>("y", NumpyArray.INSTANCE, Nd4j.ones(dataType, 2, 3).mul(4))); + String code = "z = x + y"; + List outputs = PythonExecutioner.execAndReturnAllVariables(code, inputs); + Assert.assertEquals(Nd4j.ones(dataType, 2, 3).mul(3), outputs.get(0).getValue()); + Assert.assertEquals(Nd4j.ones(dataType, 2, 3).mul(4), outputs.get(1).getValue()); + Assert.assertEquals(Nd4j.ones(dataType, 2, 3).mul(7), outputs.get(2).getValue()); + } + } catch (Throwable e) { + exceptions.add(e); + } + } + }; + + int numThreads = 10; + Thread[] threads = new Thread[numThreads]; + for (int i = 0; i < threads.length; i++) { + threads[i] = new Thread(runnable); + } + for (int i = 0; i < threads.length; i++) { + threads[i].start(); + } + Thread.sleep(100); + for (int i = 0; i < threads.length; i++) { + threads[i].join(); + } + if (!exceptions.isEmpty()) { + throw (exceptions.get(0)); + } + } + + @Test + public void testMultiThreading3() throws Throwable { + PythonContextManager.deleteNonMainContexts(); + + String code = "c = a + b"; + final PythonJob job = new PythonJob("job1", code, false); + + final List exceptions = Collections.synchronizedList(new ArrayList()); + + class JobThread extends Thread { + private INDArray a, b, c; + + public JobThread(INDArray a, INDArray b, INDArray c) { + this.a = a; + this.b = b; + this.c = c; + } + + @Override + public void run() { + try { + PythonVariable out = new PythonVariable<>("c", NumpyArray.INSTANCE); + job.exec(Arrays.asList(new PythonVariable<>("a", NumpyArray.INSTANCE, a), + new PythonVariable<>("b", NumpyArray.INSTANCE, b)), + Collections.singletonList(out)); + Assert.assertEquals(c, out.getValue()); + } catch (Exception e) { + exceptions.add(e); + } + + } + } + int numThreads = 10; + JobThread[] threads = new JobThread[numThreads]; + for (int i = 0; i < threads.length; i++) { + threads[i] = new JobThread(Nd4j.zeros(dataType, 2, 3).add(i), Nd4j.zeros(dataType, 2, 3).add(i + 3), + Nd4j.zeros(dataType, 2, 3).add(2 * i + 3)); + } + + for (int i = 0; i < threads.length; i++) { + threads[i].start(); + } + Thread.sleep(100); + for (int i = 0; i < threads.length; i++) { + threads[i].join(); + } + + if (!exceptions.isEmpty()) { + throw (exceptions.get(0)); + } + } +} diff --git a/python4j/python4j-numpy/src/test/java/PythonNumpyServiceLoaderTest.java b/python4j/python4j-numpy/src/test/java/PythonNumpyServiceLoaderTest.java new file mode 100644 index 000000000..d3c649c8d --- /dev/null +++ b/python4j/python4j-numpy/src/test/java/PythonNumpyServiceLoaderTest.java @@ -0,0 +1,41 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +import org.eclipse.python4j.*; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import javax.annotation.concurrent.NotThreadSafe; +import java.util.ArrayList; +import java.util.List; + +@NotThreadSafe +public class PythonNumpyServiceLoaderTest { + + @Test + public void testServiceLoader(){ + Assert.assertEquals(NumpyArray.INSTANCE, PythonTypes.get("numpy.ndarray")); + Assert.assertEquals(NumpyArray.INSTANCE, PythonTypes.getPythonTypeForJavaObject(Nd4j.zeros(1))); + } + + +}