diff --git a/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonTypes.java b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonTypes.java index 089c8aefe..d23c70dde 100644 --- a/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonTypes.java +++ b/python4j/python4j-core/src/main/java/org/nd4j/python4j/PythonTypes.java @@ -258,7 +258,7 @@ public class PythonTypes { return ret; }else if (javaObject instanceof byte[]){ byte[] arr = (byte[]) javaObject; - for (int x : arr) ret.add(x); + for (int x : arr) ret.add(x & 0xff); return ret; } else if (javaObject instanceof long[]) { long[] arr = (long[]) javaObject; diff --git a/python4j/python4j-core/src/test/java/PythonPrimitiveTypesTest.java b/python4j/python4j-core/src/test/java/PythonPrimitiveTypesTest.java index 94423f7de..5080b8b35 100644 --- a/python4j/python4j-core/src/test/java/PythonPrimitiveTypesTest.java +++ b/python4j/python4j-core/src/test/java/PythonPrimitiveTypesTest.java @@ -81,16 +81,31 @@ public class PythonPrimitiveTypesTest { } @Test public void testBytes() { + byte[] bytes = new byte[256]; + for (int i = 0; i < 256; i++) { + bytes[i] = (byte) i; + } + List inputs = new ArrayList<>(); + inputs.add(new PythonVariable<>("b1", PythonTypes.BYTES, bytes)); + List outputs = new ArrayList<>(); + outputs.add(new PythonVariable<>("b2", PythonTypes.BYTES)); + String code = "b2=b1"; + PythonExecutioner.exec(code, inputs, outputs); + Assert.assertArrayEquals(bytes, (byte[]) outputs.get(0).getValue()); + } + + @Test + public void testBytes2() { byte[] bytes = new byte[]{97, 98, 99}; List inputs = new ArrayList<>(); - inputs.add(new PythonVariable<>("buff", PythonTypes.BYTES, bytes)); + inputs.add(new PythonVariable<>("b1", PythonTypes.BYTES, bytes)); List outputs = new ArrayList<>(); outputs.add(new PythonVariable<>("s1", PythonTypes.STR)); - outputs.add(new PythonVariable<>("buff2", PythonTypes.BYTES)); - String code = "s1 = ''.join(chr(c) for c in buff)\nbuff2=b'def'"; + outputs.add(new PythonVariable<>("b2", PythonTypes.BYTES)); + String code = "s1 = ''.join(chr(c) for c in b1)\nb2=b'def'"; PythonExecutioner.exec(code, inputs, outputs); Assert.assertEquals("abc", outputs.get(0).getValue()); - Assert.assertArrayEquals(new byte[]{100, 101, 102}, (byte[])outputs.get(1).getValue()); + Assert.assertArrayEquals(new byte[]{100, 101, 102}, (byte[]) outputs.get(1).getValue()); } }