diff --git a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/train/GradientSharingTrainingTest.java b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/train/GradientSharingTrainingTest.java
index 68a012b72..c1eff1dce 100644
--- a/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/train/GradientSharingTrainingTest.java
+++ b/deeplearning4j/deeplearning4j-scaleout/spark/dl4j-spark-parameterserver/src/test/java/org/deeplearning4j/spark/parameterserver/train/GradientSharingTrainingTest.java
@@ -141,7 +141,7 @@ public class GradientSharingTrainingTest extends BaseSparkTest {
SparkComputationGraph sparkNet = new SparkComputationGraph(sc, conf, tm);
sparkNet.setCollectTrainingStats(tm.getIsCollectTrainingStats());
- System.out.println(Arrays.toString(sparkNet.getNetwork().params().get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
+// System.out.println(Arrays.toString(sparkNet.getNetwork().params().get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
File f = testDir.newFolder();
DataSetIterator iter = new MnistDataSetIterator(16, true, 12345);
int count = 0;
@@ -208,10 +208,10 @@ public class GradientSharingTrainingTest extends BaseSparkTest {
}
INDArray paramsAfter = after.params();
- System.out.println(Arrays.toString(paramsBefore.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
- System.out.println(Arrays.toString(paramsAfter.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
- System.out.println(Arrays.toString(
- Transforms.abs(paramsAfter.sub(paramsBefore)).get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
+// System.out.println(Arrays.toString(paramsBefore.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
+// System.out.println(Arrays.toString(paramsAfter.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
+// System.out.println(Arrays.toString(
+// Transforms.abs(paramsAfter.sub(paramsBefore)).get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 256)).dup().data().asFloat()));
assertNotEquals(paramsBefore, paramsAfter);
@@ -235,7 +235,7 @@ public class GradientSharingTrainingTest extends BaseSparkTest {
}
- @Test
+ @Test @Ignore //AB https://github.com/eclipse/deeplearning4j/issues/8985
public void differentNetsTrainingTest() throws Exception {
int batch = 3;
diff --git a/python4j/pom.xml b/python4j/pom.xml
index 57af8f1bb..1fe50344f 100644
--- a/python4j/pom.xml
+++ b/python4j/pom.xml
@@ -41,10 +41,14 @@
provided
+ org.slf4j
+ slf4j-api
+ 1.6.6
+
ch.qos.logback
logback-classic
${logback.version}
- test
+ test
junit
@@ -62,5 +66,10 @@
jsr305
3.0.2
+
+ org.slf4j
+ slf4j-api
+ 1.6.6
+
\ No newline at end of file
diff --git a/python4j/python4j-core/pom.xml b/python4j/python4j-core/pom.xml
index b429d8272..e74d32392 100644
--- a/python4j/python4j-core/pom.xml
+++ b/python4j/python4j-core/pom.xml
@@ -39,6 +39,5 @@
cpython-platform
${cpython-platform.version}
-
\ No newline at end of file
diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonContextManager.java b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonContextManager.java
index a34d8a239..5675d0864 100644
--- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonContextManager.java
+++ b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonContextManager.java
@@ -19,8 +19,10 @@ package org.eclipse.python4j;
import javax.lang.model.SourceVersion;
+import java.io.Closeable;
import java.util.HashSet;
import java.util.Set;
+import java.util.UUID;
import java.util.concurrent.atomic.AtomicBoolean;
/**
@@ -46,6 +48,31 @@ public class PythonContextManager {
init();
}
+
+ public static class Context implements Closeable{
+ private final String name;
+ private final String previous;
+ private final boolean temp;
+ public Context(){
+ name = "temp_" + UUID.randomUUID().toString().replace("-", "_");
+ temp = true;
+ previous = getCurrentContext();
+ setContext(name);
+ }
+ public Context(String name){
+ this.name = name;
+ temp = false;
+ previous = getCurrentContext();
+ setContext(name);
+ }
+
+ @Override
+ public void close(){
+ setContext(previous);
+ if (temp) deleteContext(name);
+ }
+ }
+
private static void init() {
if (init.get()) return;
new PythonExecutioner();
@@ -190,6 +217,7 @@ public class PythonContextManager {
setContext(tempContext);
deleteContext(currContext);
setContext(currContext);
+ deleteContext(tempContext);
}
/**
diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonExecutioner.java b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonExecutioner.java
index 57e1a22ae..542778f76 100644
--- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonExecutioner.java
+++ b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonExecutioner.java
@@ -25,6 +25,7 @@ import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Collections;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
@@ -42,7 +43,6 @@ public class PythonExecutioner {
private final static String DEFAULT_PYTHON_PATH_PROPERTY = "org.eclipse.python4j.path";
private final static String JAVACPP_PYTHON_APPEND_TYPE = "org.eclipse.python4j.path.append";
private final static String DEFAULT_APPEND_TYPE = "before";
-
static {
init();
}
@@ -55,6 +55,11 @@ public class PythonExecutioner {
initPythonPath();
PyEval_InitThreads();
Py_InitializeEx(0);
+ for (PythonType type: PythonTypes.get()){
+ type.init();
+ }
+ // Constructors of custom types may contain initialization code that should
+ // run on the main the thread.
}
/**
@@ -110,6 +115,8 @@ public class PythonExecutioner {
getVariables(Arrays.asList(pyVars));
}
+
+
/**
* Gets the variable with the given name from the interpreter.
*
@@ -205,9 +212,9 @@ public class PythonExecutioner {
*
* @return
*/
- public static List getAllVariables() {
+ public static PythonVariables getAllVariables() {
PythonGIL.assertThreadSafe();
- List ret = new ArrayList<>();
+ PythonVariables ret = new PythonVariables();
PyObject main = PyImport_ImportModule("__main__");
PyObject globals = PyModule_GetDict(main);
PyObject keys = PyDict_Keys(globals);
@@ -259,7 +266,7 @@ public class PythonExecutioner {
* @param inputs
* @return
*/
- public static List execAndReturnAllVariables(String code, List inputs) {
+ public static PythonVariables execAndReturnAllVariables(String code, List inputs) {
setVariables(inputs);
simpleExec(getWrappedCode(code));
return getAllVariables();
@@ -271,7 +278,7 @@ public class PythonExecutioner {
* @param code
* @return
*/
- public static List execAndReturnAllVariables(String code) {
+ public static PythonVariables execAndReturnAllVariables(String code) {
simpleExec(getWrappedCode(code));
return getAllVariables();
}
@@ -279,25 +286,22 @@ public class PythonExecutioner {
private static synchronized void initPythonPath() {
try {
String path = System.getProperty(DEFAULT_PYTHON_PATH_PROPERTY);
+
+ List packagesList = new ArrayList<>();
+ packagesList.addAll(Arrays.asList(cachePackages()));
+ for (PythonType type: PythonTypes.get()){
+ packagesList.addAll(Arrays.asList(type.packages()));
+ }
+ //// TODO: fix in javacpp
+ packagesList.add(new File(python.cachePackage(), "site-packages"));
+
+ File[] packages = packagesList.toArray(new File[0]);
+
if (path == null) {
- File[] packages = cachePackages();
-
- //// TODO: fix in javacpp
- File sitePackagesWindows = new File(python.cachePackage(), "site-packages");
- File[] packages2 = new File[packages.length + 1];
- for (int i = 0; i < packages.length; i++) {
- //System.out.println(packages[i].getAbsolutePath());
- packages2[i] = packages[i];
- }
- packages2[packages.length] = sitePackagesWindows;
- //System.out.println(sitePackagesWindows.getAbsolutePath());
- packages = packages2;
- //////////
-
Py_SetPath(packages);
} else {
StringBuffer sb = new StringBuffer();
- File[] packages = cachePackages();
+
JavaCppPathType pathAppendValue = JavaCppPathType.valueOf(System.getProperty(JAVACPP_PYTHON_APPEND_TYPE, DEFAULT_APPEND_TYPE).toUpperCase());
switch (pathAppendValue) {
case BEFORE:
diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonGIL.java b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonGIL.java
index 46b3db431..074be294a 100644
--- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonGIL.java
+++ b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonGIL.java
@@ -90,4 +90,8 @@ public class PythonGIL implements AutoCloseable {
PyEval_SaveThread();
PyEval_RestoreThread(mainThreadState);
}
+
+ public static boolean locked(){
+ return acquired.get();
+ }
}
diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonJob.java b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonJob.java
index cdbb1b81d..0818de890 100644
--- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonJob.java
+++ b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonJob.java
@@ -20,25 +20,29 @@ package org.eclipse.python4j;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
+import lombok.extern.slf4j.Slf4j;
import javax.annotation.Nonnull;
import java.util.List;
+import java.util.concurrent.atomic.AtomicBoolean;
-@Data
-@NoArgsConstructor
/**
* PythonJob is the right abstraction for executing multiple python scripts
* in a multi thread stateful environment. The setup-and-run mode allows your
* "setup" code (imports, model loading etc) to be executed only once.
*/
+@Data
+@Slf4j
public class PythonJob {
+
private String code;
private String name;
private String context;
- private boolean setupRunMode;
+ private final boolean setupRunMode;
private PythonObject runF;
+ private final AtomicBoolean setupDone = new AtomicBoolean(false);
static {
new PythonExecutioner();
@@ -63,7 +67,6 @@ public class PythonJob {
if (PythonContextManager.hasContext(context)) {
throw new PythonException("Unable to create python job " + name + ". Context " + context + " already exists!");
}
- if (setupRunMode) setup();
}
@@ -71,17 +74,18 @@ public class PythonJob {
* Clears all variables in current context and calls setup()
*/
public void clearState(){
- String context = this.context;
- PythonContextManager.setContext("main");
- PythonContextManager.deleteContext(context);
- this.context = context;
+ PythonContextManager.setContext(this.context);
+ PythonContextManager.reset();
+ setupDone.set(false);
setup();
}
public void setup(){
+ if (setupDone.get()) return;
try (PythonGIL gil = PythonGIL.lock()) {
PythonContextManager.setContext(context);
PythonObject runF = PythonExecutioner.getVariable("run");
+
if (runF == null || runF.isNone() || !Python.callable(runF)) {
PythonExecutioner.exec(code);
runF = PythonExecutioner.getVariable("run");
@@ -98,10 +102,12 @@ public class PythonJob {
if (!setupF.isNone()) {
setupF.call();
}
+ setupDone.set(true);
}
}
public void exec(List inputs, List outputs) {
+ if (setupRunMode)setup();
try (PythonGIL gil = PythonGIL.lock()) {
try (PythonGC _ = PythonGC.watch()) {
PythonContextManager.setContext(context);
@@ -139,6 +145,7 @@ public class PythonJob {
}
public List execAndReturnAllVariables(List inputs){
+ if (setupRunMode)setup();
try (PythonGIL gil = PythonGIL.lock()) {
try (PythonGC _ = PythonGC.watch()) {
PythonContextManager.setContext(context);
diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonObject.java b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonObject.java
index f8ec17ed9..69252a5f7 100644
--- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonObject.java
+++ b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonObject.java
@@ -147,7 +147,8 @@ public class PythonObject {
}
PythonObject pyArgs;
PythonObject pyKwargs;
- if (args == null) {
+
+ if (args == null || args.isEmpty()) {
pyArgs = new PythonObject(PyTuple_New(0));
} else {
PythonObject argsList = PythonTypes.convert(args);
@@ -158,6 +159,7 @@ public class PythonObject {
} else {
pyKwargs = PythonTypes.convert(kwargs);
}
+
PythonObject ret = new PythonObject(
PyObject_Call(
nativePythonObject,
@@ -165,7 +167,9 @@ public class PythonObject {
pyKwargs == null ? null : pyKwargs.nativePythonObject
)
);
+
PythonGC.keep(ret);
+
return ret;
}
@@ -241,4 +245,48 @@ public class PythonObject {
PyObject_SetItem(nativePythonObject, key.nativePythonObject, value.nativePythonObject);
}
+
+ public PythonObject abs(){
+ return new PythonObject(PyNumber_Absolute(nativePythonObject));
+ }
+ public PythonObject add(PythonObject pythonObject){
+ return new PythonObject(PyNumber_Add(nativePythonObject, pythonObject.nativePythonObject));
+ }
+ public PythonObject sub(PythonObject pythonObject){
+ return new PythonObject(PyNumber_Subtract(nativePythonObject, pythonObject.nativePythonObject));
+ }
+ public PythonObject mod(PythonObject pythonObject){
+ return new PythonObject(PyNumber_Divmod(nativePythonObject, pythonObject.nativePythonObject));
+ }
+ public PythonObject mul(PythonObject pythonObject){
+ return new PythonObject(PyNumber_Multiply(nativePythonObject, pythonObject.nativePythonObject));
+ }
+ public PythonObject trueDiv(PythonObject pythonObject){
+ return new PythonObject(PyNumber_TrueDivide(nativePythonObject, pythonObject.nativePythonObject));
+ }
+ public PythonObject floorDiv(PythonObject pythonObject){
+ return new PythonObject(PyNumber_FloorDivide(nativePythonObject, pythonObject.nativePythonObject));
+ }
+ public PythonObject matMul(PythonObject pythonObject){
+ return new PythonObject(PyNumber_MatrixMultiply(nativePythonObject, pythonObject.nativePythonObject));
+ }
+
+ public void addi(PythonObject pythonObject){
+ PyNumber_InPlaceAdd(nativePythonObject, pythonObject.nativePythonObject);
+ }
+ public void subi(PythonObject pythonObject){
+ PyNumber_InPlaceSubtract(nativePythonObject, pythonObject.nativePythonObject);
+ }
+ public void muli(PythonObject pythonObject){
+ PyNumber_InPlaceMultiply(nativePythonObject, pythonObject.nativePythonObject);
+ }
+ public void trueDivi(PythonObject pythonObject){
+ PyNumber_InPlaceTrueDivide(nativePythonObject, pythonObject.nativePythonObject);
+ }
+ public void floorDivi(PythonObject pythonObject){
+ PyNumber_InPlaceFloorDivide(nativePythonObject, pythonObject.nativePythonObject);
+ }
+ public void matMuli(PythonObject pythonObject){
+ PyNumber_InPlaceMatrixMultiply(nativePythonObject, pythonObject.nativePythonObject);
+ }
}
diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonProcess.java b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonProcess.java
new file mode 100644
index 000000000..0ca17fb49
--- /dev/null
+++ b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonProcess.java
@@ -0,0 +1,127 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+
+package org.eclipse.python4j;
+
+import org.apache.commons.io.IOUtils;
+import org.bytedeco.javacpp.Loader;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+
+public class PythonProcess {
+ private static String pythonExecutable = Loader.load(org.bytedeco.cpython.python.class);
+ public static String runAndReturn(String... arguments)throws IOException, InterruptedException{
+ String[] allArgs = new String[arguments.length + 1];
+ for (int i = 0; i < arguments.length; i++){
+ allArgs[i + 1] = arguments[i];
+ }
+ allArgs[0] = pythonExecutable;
+ ProcessBuilder pb = new ProcessBuilder(allArgs);
+ Process process = pb.start();
+ String out = IOUtils.toString(process.getInputStream(), StandardCharsets.UTF_8);
+ process.waitFor();
+ return out;
+
+ }
+
+ public static void run(String... arguments)throws IOException, InterruptedException{
+ String[] allArgs = new String[arguments.length + 1];
+ for (int i = 0; i < arguments.length; i++){
+ allArgs[i + 1] = arguments[i];
+ }
+ allArgs[0] = pythonExecutable;
+ ProcessBuilder pb = new ProcessBuilder(allArgs);
+ pb.inheritIO().start().waitFor();
+ }
+ public static void pipInstall(String packageName) throws PythonException{
+ try{
+ run("-m", "pip", "install", packageName);
+ }catch(Exception e){
+ throw new PythonException("Error installing package " + packageName, e);
+ }
+
+ }
+
+ public static void pipInstall(String packageName, String version){
+ pipInstall(packageName + "==" + version);
+ }
+
+ public static void pipUninstall(String packageName) throws PythonException{
+ try{
+ run("-m", "pip", "uninstall", packageName);
+ }catch(Exception e){
+ throw new PythonException("Error uninstalling package " + packageName, e);
+ }
+
+ }
+ public static void pipInstallFromGit(String gitRepoUrl){
+ if (!gitRepoUrl.contains("://")){
+ gitRepoUrl = "git://" + gitRepoUrl;
+ }
+ try{
+ run("-m", "pip", "install", "git+", gitRepoUrl);
+ }catch(Exception e){
+ throw new PythonException("Error installing package from " + gitRepoUrl, e);
+ }
+
+ }
+
+ public static String getPackageVersion(String packageName){
+ String out;
+ try{
+ out = runAndReturn("-m", "pip", "show", packageName);
+ } catch (Exception e){
+ throw new PythonException("Error finding version for package " + packageName, e);
+ }
+
+ if (!out.contains("Version: ")){
+ throw new PythonException("Can't find package " + packageName);
+ }
+ String pkgVersion = out.split("Version: ")[1].split(System.lineSeparator())[0];
+ return pkgVersion;
+ }
+
+ public static boolean isPackageInstalled(String packageName){
+ try{
+ String out = runAndReturn("-m", "pip", "show", packageName);
+ return !out.isEmpty();
+ }catch (Exception e){
+ throw new PythonException("Error checking if package is installed: " +packageName, e);
+ }
+
+ }
+
+ public static void pipInstallFromRequirementsTxt(String path){
+ try{
+ run("-m", "pip", "install","-r", path);
+ }catch (Exception e){
+ throw new PythonException("Error installing packages from " + path, e);
+ }
+ }
+
+ public static void pipInstallFromSetupScript(String path, boolean inplace){
+
+ try{
+ run(path, inplace?"develop":"install");
+ }catch (Exception e){
+ throw new PythonException("Error installing package from " + path, e);
+ }
+
+ }
+
+}
\ No newline at end of file
diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonType.java b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonType.java
index b4806aa37..47b725cd5 100644
--- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonType.java
+++ b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonType.java
@@ -17,6 +17,8 @@
package org.eclipse.python4j;
+import java.io.File;
+
public abstract class PythonType {
private final String name;
@@ -43,5 +45,25 @@ public abstract class PythonType {
return name;
}
+ @Override
+ public boolean equals(Object obj){
+ if (!(obj instanceof PythonType)){
+ return false;
+ }
+ PythonType other = (PythonType)obj;
+ return this.getClass().equals(other.getClass()) && this.name.equals(other.name);
+ }
+
+ public PythonObject pythonType(){
+ return null;
+ }
+
+ public File[] packages(){
+ return new File[0];
+ }
+
+ public void init(){ //not to be called from constructor
+
+ }
}
diff --git a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonTypes.java b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonTypes.java
index 0dc20f712..cd7ac7d7c 100644
--- a/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonTypes.java
+++ b/python4j/python4j-core/src/main/java/org/eclipse/python4j/PythonTypes.java
@@ -18,7 +18,16 @@ package org.eclipse.python4j;
import org.bytedeco.cpython.PyObject;
+import org.bytedeco.javacpp.BytePointer;
+import org.bytedeco.javacpp.Loader;
+import org.bytedeco.javacpp.Pointer;
+import sun.misc.Unsafe;
+import sun.nio.ch.DirectBuffer;
+import java.lang.reflect.Field;
+import java.nio.Buffer;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
import java.util.*;
import static org.bytedeco.cpython.global.python.*;
@@ -28,7 +37,7 @@ public class PythonTypes {
private static List getPrimitiveTypes() {
- return Arrays.asList(STR, INT, FLOAT, BOOL);
+ return Arrays.asList(STR, INT, FLOAT, BOOL, MEMORYVIEW);
}
private static List getCollectionTypes() {
@@ -36,8 +45,13 @@ public class PythonTypes {
}
private static List getExternalTypes() {
- //TODO service loader
- return new ArrayList<>();
+ List ret = new ArrayList<>();
+ ServiceLoader sl = ServiceLoader.load(PythonType.class);
+ Iterator iter = sl.iterator();
+ while (iter.hasNext()) {
+ ret.add(iter.next());
+ }
+ return ret;
}
public static List get() {
@@ -48,15 +62,17 @@ public class PythonTypes {
return ret;
}
- public static PythonType get(String name) {
+ public static PythonType get(String name) {
for (PythonType pt : get()) {
if (pt.getName().equals(name)) { // TODO use map instead?
return pt;
}
+
}
throw new PythonException("Unknown python type: " + name);
}
+
public static PythonType getPythonTypeForJavaObject(Object javaObject) {
for (PythonType pt : get()) {
if (pt.accepts(javaObject)) {
@@ -66,7 +82,7 @@ public class PythonTypes {
throw new PythonException("Unable to find python type for java type: " + javaObject.getClass());
}
- public static PythonType getPythonTypeForPythonObject(PythonObject pythonObject) {
+ public static PythonType getPythonTypeForPythonObject(PythonObject pythonObject) {
PyObject pyType = PyObject_Type(pythonObject.getNativePythonObject());
try {
String pyTypeStr = PythonTypes.STR.toJava(new PythonObject(pyType, false));
@@ -75,6 +91,14 @@ public class PythonTypes {
String pyTypeStr2 = "";
if (pyTypeStr.equals(pyTypeStr2)) {
return pt;
+ } else {
+ try (PythonGC gc = PythonGC.watch()) {
+ PythonObject pyType2 = pt.pythonType();
+ if (pyType2 != null && Python.isinstance(pythonObject, pyType2)) {
+ return pt;
+ }
+ }
+
}
}
throw new PythonException("Unable to find converter for python object of type " + pyTypeStr);
@@ -212,12 +236,49 @@ public class PythonTypes {
public static final PythonType LIST = new PythonType("list", List.class) {
+ @Override
+ public boolean accepts(Object javaObject) {
+ return (javaObject instanceof List || javaObject.getClass().isArray());
+ }
+
@Override
public List adapt(Object javaObject) {
if (javaObject instanceof List) {
return (List) javaObject;
- } else if (javaObject instanceof Object[]) {
- return Arrays.asList((Object[]) javaObject);
+ } else if (javaObject.getClass().isArray()) {
+ List