From f6b3032def5f5c37eab30f04a345c0e31aeaf8f3 Mon Sep 17 00:00:00 2001 From: Fariz Rahman Date: Tue, 4 Feb 2020 13:23:59 +0400 Subject: [PATCH 1/9] Python executioner 2.0 (#134) * wrapper * builtins * context mgr * direct place * ret all var * call fix * fix ndarray serde * jobs * try-with gil management * cleanup * exec tests passing * list tests * transforms test passing * all pass * headers * dict fixes+test * python path * bool isinstance * job tests * nits * transform fix+test * transform tests * leak fixes * more mem leak fixes * more fixes * nits for adam * PythonJob lombok builder * checked exceptions * more nits * small leak fix * more nits * pythonexceptions * fix jvm crash when bad python code * Exception->PythonException * Add support for boolean types in arrow records and ability to cast from float, double to int for TypeConversion (#178) * nits for alex * update tests * fix test * all pass * refacc * rem old code * dtypes * bytes working+exception pass through+cleanup (#209) * more bytes tests * header * rem dummy test * rem bad import * alex nits + refacc * Small error fixes (wrong type in msg) + minor formatting Signed-off-by: AlexDBlack * use actual python type names (dictionary->dict, boolean->bool) Co-authored-by: Shams Ul Azeem Co-authored-by: Alex Black --- .../schema/conversion/TypeConversion.java | 2 +- .../org/datavec/arrow/ArrowConverter.java | 8 +- .../transform/TestPythonTransformProcess.java | 39 +- .../java/org/datavec/python/NumpyArray.java | 39 +- .../main/java/org/datavec/python/Python.java | 265 ++++ .../org/datavec/python/PythonCondition.java | 52 +- .../datavec/python/PythonContextManager.java | 188 +++ .../org/datavec/python/PythonException.java | 44 + .../org/datavec/python/PythonExecutioner.java | 1272 ++++------------- .../java/org/datavec/python/PythonGIL.java | 68 + .../java/org/datavec/python/PythonJob.java | 171 +++ .../java/org/datavec/python/PythonObject.java | 554 +++++++ .../org/datavec/python/PythonTransform.java | 159 +-- .../java/org/datavec/python/PythonType.java | 238 +++ .../java/org/datavec/python/PythonUtils.java | 169 ++- .../org/datavec/python/PythonVariables.java | 374 ++--- .../src/main/resources/pythonexec/__init__.py | 0 .../main/resources/pythonexec/clear_vars.py | 5 - .../main/resources/pythonexec/input_code.py | 1 - .../main/resources/pythonexec/outputcode.py | 20 - .../src/main/resources/pythonexec/patch0.py | 202 --- .../src/main/resources/pythonexec/patch1.py | 172 --- .../main/resources/pythonexec/pythonexec.py | 4 +- .../resources/pythonexec/serialize_array.py | 50 - .../python/TestPythonContextManager.java | 87 ++ .../org/datavec/python/TestPythonDict.java | 64 + .../python/TestPythonExecutionSandbox.java | 75 - .../datavec/python/TestPythonExecutioner.java | 148 +- .../org/datavec/python/TestPythonJob.java | 326 +++++ .../org/datavec/python/TestPythonList.java | 108 ++ .../datavec/python/TestPythonSetupAndRun.java | 27 - .../datavec/python/TestPythonVariables.java | 50 +- .../java/org/datavec/python/TestSerde.java | 2 +- 33 files changed, 2907 insertions(+), 2076 deletions(-) create mode 100644 datavec/datavec-python/src/main/java/org/datavec/python/Python.java create mode 100644 datavec/datavec-python/src/main/java/org/datavec/python/PythonContextManager.java create mode 100644 datavec/datavec-python/src/main/java/org/datavec/python/PythonException.java create mode 100644 datavec/datavec-python/src/main/java/org/datavec/python/PythonGIL.java create mode 100644 datavec/datavec-python/src/main/java/org/datavec/python/PythonJob.java create mode 100644 datavec/datavec-python/src/main/java/org/datavec/python/PythonObject.java create mode 100644 datavec/datavec-python/src/main/java/org/datavec/python/PythonType.java delete mode 100644 datavec/datavec-python/src/main/resources/pythonexec/__init__.py delete mode 100644 datavec/datavec-python/src/main/resources/pythonexec/clear_vars.py delete mode 100644 datavec/datavec-python/src/main/resources/pythonexec/input_code.py delete mode 100644 datavec/datavec-python/src/main/resources/pythonexec/outputcode.py delete mode 100644 datavec/datavec-python/src/main/resources/pythonexec/patch0.py delete mode 100644 datavec/datavec-python/src/main/resources/pythonexec/patch1.py delete mode 100644 datavec/datavec-python/src/main/resources/pythonexec/serialize_array.py create mode 100644 datavec/datavec-python/src/test/java/org/datavec/python/TestPythonContextManager.java create mode 100644 datavec/datavec-python/src/test/java/org/datavec/python/TestPythonDict.java delete mode 100644 datavec/datavec-python/src/test/java/org/datavec/python/TestPythonExecutionSandbox.java create mode 100644 datavec/datavec-python/src/test/java/org/datavec/python/TestPythonJob.java create mode 100644 datavec/datavec-python/src/test/java/org/datavec/python/TestPythonList.java delete mode 100644 datavec/datavec-python/src/test/java/org/datavec/python/TestPythonSetupAndRun.java diff --git a/datavec/datavec-api/src/main/java/org/datavec/api/transform/schema/conversion/TypeConversion.java b/datavec/datavec-api/src/main/java/org/datavec/api/transform/schema/conversion/TypeConversion.java index afd128669..1f7e938b6 100644 --- a/datavec/datavec-api/src/main/java/org/datavec/api/transform/schema/conversion/TypeConversion.java +++ b/datavec/datavec-api/src/main/java/org/datavec/api/transform/schema/conversion/TypeConversion.java @@ -45,7 +45,7 @@ public class TypeConversion { } public int convertInt(String o) { - return Integer.parseInt(o); + return (int) Double.parseDouble(o); } public double convertDouble(Writable writable) { diff --git a/datavec/datavec-arrow/src/main/java/org/datavec/arrow/ArrowConverter.java b/datavec/datavec-arrow/src/main/java/org/datavec/arrow/ArrowConverter.java index 48f9474d5..f66475357 100644 --- a/datavec/datavec-arrow/src/main/java/org/datavec/arrow/ArrowConverter.java +++ b/datavec/datavec-arrow/src/main/java/org/datavec/arrow/ArrowConverter.java @@ -725,7 +725,6 @@ public class ArrowConverter { case Time: ret.add(timeVectorOf(bufferAllocator,schema.getName(i),numRows)); break; case NDArray: ret.add(ndarrayVectorOf(bufferAllocator,schema.getName(i),numRows)); break; default: throw new IllegalArgumentException("Illegal type found for creation of field vectors" + schema.getType(i)); - } } @@ -802,8 +801,13 @@ public class ArrowConverter { //for proper offsets ByteBuffer byteBuffer = BinarySerde.toByteBuffer(arr.get()); nd4jArrayVector.setSafe(row,byteBuffer,0,byteBuffer.capacity()); + case Boolean: + BitVector bitVector = (BitVector) fieldVector; + if(value instanceof Boolean) + bitVector.set(row, (boolean) value ? 1 : 0); + else + bitVector.set(row, ((BooleanWritable) value).get() ? 1 : 0); break; - } }catch(Exception e) { log.warn("Unable to set value at row " + row); diff --git a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/TestPythonTransformProcess.java b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/TestPythonTransformProcess.java index 37df8ae52..7a6ab3f03 100644 --- a/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/TestPythonTransformProcess.java +++ b/datavec/datavec-local/src/test/java/org/datavec/local/transforms/transform/TestPythonTransformProcess.java @@ -315,7 +315,7 @@ public class TestPythonTransformProcess { } @Test - public void testNumpyTransform() throws Exception { + public void testNumpyTransform() { PythonTransform pythonTransform = PythonTransform.builder() .code("a += 2; b = 'hello world'") .returnAllInputs(true) @@ -334,7 +334,42 @@ public class TestPythonTransformProcess { assertFalse(execute.isEmpty()); assertNotNull(execute.get(0)); assertNotNull(execute.get(0).get(0)); - assertEquals("hello world",execute.get(0).get(0).toString()); + assertNotNull(execute.get(0).get(1)); + assertEquals(Nd4j.scalar(3).reshape(1, 1),((NDArrayWritable)execute.get(0).get(0)).get()); + assertEquals("hello world",execute.get(0).get(1).toString()); + } + + @Test + public void testWithSetupRun() throws Exception { + + PythonTransform pythonTransform = PythonTransform.builder() + .code("five=None\n" + + "def setup():\n" + + " global five\n"+ + " five = 5\n\n" + + "def run(a, b):\n" + + " c = a + b + five\n"+ + " return {'c':c}\n\n") + .returnAllInputs(true) + .setupAndRun(true) + .build(); + + List> inputs = new ArrayList<>(); + inputs.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.scalar(1).reshape(1,1)), + new NDArrayWritable(Nd4j.scalar(2).reshape(1,1)))); + Schema inputSchema = new Builder() + .addColumnNDArray("a",new long[]{1,1}) + .addColumnNDArray("b", new long[]{1, 1}) + .build(); + + TransformProcess tp = new TransformProcess.Builder(inputSchema) + .transform(pythonTransform) + .build(); + List> execute = LocalTransformExecutor.execute(inputs, tp); + assertFalse(execute.isEmpty()); + assertNotNull(execute.get(0)); + assertNotNull(execute.get(0).get(0)); + assertEquals(Nd4j.scalar(8).reshape(1, 1),((NDArrayWritable)execute.get(0).get(3)).get()); } } \ No newline at end of file diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/NumpyArray.java b/datavec/datavec-python/src/main/java/org/datavec/python/NumpyArray.java index 24a2c2e09..dd1613d0a 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/NumpyArray.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/NumpyArray.java @@ -29,6 +29,7 @@ import org.nd4j.nativeblas.NativeOps; import org.nd4j.nativeblas.NativeOpsHolder; import org.nd4j.linalg.api.buffer.DataType; +import static org.nd4j.linalg.api.buffer.DataType.FLOAT; /** @@ -46,55 +47,45 @@ public class NumpyArray { private long[] strides; private DataType dtype; private INDArray nd4jArray; + static { //initialize Nd4j.scalar(1.0); - nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); + nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); } @Builder - public NumpyArray(long address, long[] shape, long strides[], boolean copy,DataType dtype) { + public NumpyArray(long address, long[] shape, long strides[], DataType dtype, boolean copy) { this.address = address; this.shape = shape; this.strides = strides; this.dtype = dtype; setND4JArray(); - if (copy){ + if (copy) { nd4jArray = nd4jArray.dup(); Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST); this.address = nd4jArray.data().address(); - } } - public NumpyArray copy(){ + + + public NumpyArray copy() { return new NumpyArray(nd4jArray.dup()); } - public NumpyArray(long address, long[] shape, long strides[]){ - this(address, shape, strides, false,DataType.FLOAT); + public NumpyArray(long address, long[] shape, long strides[]) { + this(address, shape, strides, FLOAT, false); } - public NumpyArray(long address, long[] shape, long strides[], DataType dtype){ + public NumpyArray(long address, long[] shape, long strides[], DataType dtype) { this(address, shape, strides, dtype, false); } - public NumpyArray(long address, long[] shape, long strides[], DataType dtype, boolean copy){ - this.address = address; - this.shape = shape; - this.strides = strides; - this.dtype = dtype; - setND4JArray(); - if (copy){ - nd4jArray = nd4jArray.dup(); - Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST); - this.address = nd4jArray.data().address(); - } - } private void setND4JArray() { long size = 1; - for(long d: shape) { + for (long d : shape) { size *= d; } Pointer ptr = nativeOps.pointerForAddress(address); @@ -107,11 +98,11 @@ public class NumpyArray { nd4jStrides[i] = strides[i] / elemSize; } - nd4jArray = Nd4j.create(buff, shape, nd4jStrides, 0, Shape.getOrder(shape,nd4jStrides,1), dtype); + nd4jArray = Nd4j.create(buff, shape, nd4jStrides, 0, Shape.getOrder(shape, nd4jStrides, 1), dtype); Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST); } - public NumpyArray(INDArray nd4jArray){ + public NumpyArray(INDArray nd4jArray) { Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST); DataBuffer buff = nd4jArray.data(); address = buff.pointer().address(); @@ -119,7 +110,7 @@ public class NumpyArray { long[] nd4jStrides = nd4jArray.stride(); strides = new long[nd4jStrides.length]; int elemSize = buff.getElementSize(); - for(int i=0; i= 1,"Python code must not be empty!"); + checkNotNull("Python code must not be null!", pythonCode); + checkState(!pythonCode.isEmpty(), "Python code must not be empty!"); code = pythonCode; } - - @Override public void setInputSchema(Schema inputSchema) { this.inputSchema = inputSchema; - try{ + try { pyInputs = schemaToPythonVariables(inputSchema); PythonVariables pyOuts = new PythonVariables(); pyOuts.addInt("out"); @@ -62,17 +60,15 @@ public class PythonCondition implements Condition { .outputs(pyOuts) .build(); - } - catch (Exception e){ + } catch (Exception e) { throw new RuntimeException(e); } - } @Override - public Schema getInputSchema(){ + public Schema getInputSchema() { return inputSchema; } @@ -84,40 +80,39 @@ public class PythonCondition implements Condition { } @Override - public String outputColumnName(){ + public String outputColumnName() { return outputColumnNames()[0]; } @Override - public String[] columnNames(){ + public String[] columnNames() { return outputColumnNames(); } @Override - public String columnName(){ + public String columnName() { return outputColumnName(); } @Override - public Schema transform(Schema inputSchema){ + public Schema transform(Schema inputSchema) { return inputSchema; } @Override public boolean condition(List list) { PythonVariables inputs = getPyInputsFromWritables(list); - try{ - PythonExecutioner.exec(pythonTransform.getCode(), inputs, pythonTransform.getOutputs()); + try { + pythonTransform.getPythonJob().exec(inputs, pythonTransform.getOutputs()); boolean ret = pythonTransform.getOutputs().getIntValue("out") != 0; return ret; - } - catch (Exception e) { + } catch (Exception e) { throw new RuntimeException(e); } } @Override - public boolean condition(Object input){ + public boolean condition(Object input) { return condition(input); } @@ -135,28 +130,27 @@ public class PythonCondition implements Condition { private PythonVariables getPyInputsFromWritables(List writables) { PythonVariables ret = new PythonVariables(); - for (int i = 0; i < inputSchema.numColumns(); i++){ + for (int i = 0; i < inputSchema.numColumns(); i++) { String name = inputSchema.getName(i); Writable w = writables.get(i); - PythonVariables.Type pyType = pyInputs.getType(inputSchema.getName(i)); - switch (pyType){ + PythonType pyType = pyInputs.getType(inputSchema.getName(i)); + switch (pyType.getName()) { case INT: if (w instanceof LongWritable) { - ret.addInt(name, ((LongWritable)w).get()); - } - else { - ret.addInt(name, ((IntWritable)w).get()); + ret.addInt(name, ((LongWritable) w).get()); + } else { + ret.addInt(name, ((IntWritable) w).get()); } break; case FLOAT: - ret.addFloat(name, ((DoubleWritable)w).get()); + ret.addFloat(name, ((DoubleWritable) w).get()); break; case STR: ret.addStr(name, w.toString()); break; case NDARRAY: - ret.addNDArray(name,((NDArrayWritable)w).get()); + ret.addNDArray(name, ((NDArrayWritable) w).get()); break; } } diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonContextManager.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonContextManager.java new file mode 100644 index 000000000..c3563bfc2 --- /dev/null +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonContextManager.java @@ -0,0 +1,188 @@ + +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +package org.datavec.python; + + +import java.util.HashSet; +import java.util.Set; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * Emulates multiples interpreters in a single interpreter. + * This works by simply obfuscating/de-obfuscating variable names + * such that only the required subset of the global namespace is "visible" + * at any given time. + * By default, there exists a "main" context emulating the default interpreter + * and cannot be deleted. + * @author Fariz Rahman + */ + + +public class PythonContextManager { + + private static Set contexts = new HashSet<>(); + private static AtomicBoolean init = new AtomicBoolean(false); + private static String currentContext; + private static final String MAIN_CONTEXT = "main"; + static { + init(); + } + + private static void init() { + if (init.get()) return; + new PythonExecutioner(); + init.set(true); + currentContext = MAIN_CONTEXT; + contexts.add(currentContext); + } + + + public static void addContext(String contextName) throws PythonException { + if (!validateContextName(contextName)) { + throw new PythonException("Invalid context name: " + contextName); + } + contexts.add(contextName); + } + + public static boolean hasContext(String contextName) { + return contexts.contains(contextName); + } + + + public static boolean validateContextName(String s) { + if (s.length() == 0) return false; + if (!Character.isJavaIdentifierStart(s.charAt(0))) return false; + for (int i = 1; i < s.length(); i++) + if (!Character.isJavaIdentifierPart(s.charAt(i))) + return false; + return true; + } + + private static String getContextPrefix(String contextName) { + return "__collapsed__" + contextName + "__"; + } + + private static String getCollapsedVarNameForContext(String varName, String contextName) { + return getContextPrefix(contextName) + varName; + } + + private static String expandCollapsedVarName(String varName, String contextName) { + String prefix = "__collapsed__" + contextName + "__"; + return varName.substring(prefix.length()); + + } + + private static void collapseContext(String contextName) { + PythonObject globals = Python.globals(); + PythonObject keysList = Python.list(globals.attr("keys").call()); + int numKeys = Python.len(keysList).toInt(); + for (int i = 0; i < numKeys; i++) { + PythonObject key = keysList.get(i); + String keyStr = key.toString(); + if (!((keyStr.startsWith("__") && keyStr.endsWith("__")) || keyStr.startsWith("__collapsed_"))) { + String collapsedKey = getCollapsedVarNameForContext(keyStr, contextName); + PythonObject val = globals.attr("pop").call(key); + globals.set(new PythonObject(collapsedKey), val); + } + } + } + + private static void expandContext(String contextName) { + String prefix = getContextPrefix(contextName); + PythonObject globals = Python.globals(); + PythonObject keysList = Python.list(globals.attr("keys").call()); + int numKeys = Python.len(keysList).toInt(); + for (int i = 0; i < numKeys; i++) { + PythonObject key = keysList.get(i); + String keyStr = key.toString(); + if (keyStr.startsWith(prefix)) { + String expandedKey = expandCollapsedVarName(keyStr, contextName); + PythonObject val = globals.attr("pop").call(key); + globals.set(new PythonObject(expandedKey), val); + } + } + + } + + public static void setContext(String contextName) throws PythonException{ + if (contextName.equals(currentContext)) { + return; + } + if (!hasContext(contextName)) { + addContext(contextName); + } + collapseContext(currentContext); + expandContext(contextName); + currentContext = contextName; + + } + + public static void setMainContext() { + try{ + setContext(MAIN_CONTEXT); + } + catch (PythonException pe){ + throw new RuntimeException(pe); + } + + } + + public static String getCurrentContext() { + return currentContext; + } + + public static void deleteContext(String contextName) throws PythonException { + if (contextName.equals(MAIN_CONTEXT)) { + throw new PythonException("Can not delete main context!"); + } + if (contextName.equals(currentContext)) { + throw new PythonException("Can not delete current context!"); + } + String prefix = getContextPrefix(contextName); + PythonObject globals = Python.globals(); + PythonObject keysList = Python.list(globals.attr("keys").call()); + int numKeys = Python.len(keysList).toInt(); + for (int i = 0; i < numKeys; i++) { + PythonObject key = keysList.get(i); + String keyStr = key.toString(); + if (keyStr.startsWith(prefix)) { + globals.attr("__delitem__").call(key); + } + } + contexts.remove(contextName); + } + + public static void deleteNonMainContexts() { + try{ + setContext(MAIN_CONTEXT); // will never fail + for (String c : contexts.toArray(new String[0])) { + if (!c.equals(MAIN_CONTEXT)) { + deleteContext(c); // will never fail + } + } + }catch(Exception e){ + throw new RuntimeException(e); + } + } + + public String[] getContexts() { + return contexts.toArray(new String[0]); + } + +} diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonException.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonException.java new file mode 100644 index 000000000..d66c67a32 --- /dev/null +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonException.java @@ -0,0 +1,44 @@ + +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.datavec.python; + +/** + * Thrown when an exception occurs in python land + */ +public class PythonException extends Exception { + public PythonException(String message){ + super(message); + } + private static String getExceptionString(PythonObject exception){ + if (Python.isinstance(exception, Python.ExceptionType())){ + String exceptionClass = Python.type(exception).attr("__name__").toString(); + String message = exception.toString(); + return exceptionClass + ": " + message; + } + return exception.toString(); + } + public PythonException(PythonObject exception){ + this(getExceptionString(exception)); + } + public PythonException(String message, Throwable cause){ + super(message, cause); + } + public PythonException(Throwable cause){ + super(cause); + } +} diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java index 8afec3b5a..a06e60e98 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonExecutioner.java @@ -1,5 +1,6 @@ + /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -14,33 +15,29 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ + package org.datavec.python; - -import java.io.*; -import java.nio.charset.Charset; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; - - import lombok.extern.slf4j.Slf4j; -import org.apache.commons.io.FileUtils; import org.apache.commons.io.IOUtils; -import org.json.JSONObject; -import org.json.JSONArray; -import org.bytedeco.javacpp.*; -import org.bytedeco.cpython.*; -import static org.bytedeco.cpython.global.python.*; -import org.bytedeco.numpy.global.numpy; - -import static org.datavec.python.PythonUtils.*; - -import org.nd4j.base.Preconditions; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.factory.Nd4j; +import org.bytedeco.cpython.PyThreadState; +import org.bytedeco.javacpp.BytePointer; +import org.bytedeco.numpy.global.numpy; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.io.ClassPathResource; +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.nio.ByteBuffer; +import java.nio.charset.Charset; +import java.util.Arrays; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.bytedeco.cpython.global.python.*; +import static org.bytedeco.cpython.global.python.PyThreadState_Get; +import static org.datavec.python.Python.*; /** * Allows execution of python scripts managed by @@ -58,18 +55,14 @@ import org.nd4j.linalg.io.ClassPathResource; * whether it is appended, prepended, or not used. This behavior is useful when you need to use an external * python distribution such as anaconda. * - * 3. A main interpreter: This is the default interpreter to be used with the main thread. - * We may initialize one or more relative to the thread invoking the python code. - * - * 4. A proper numpy import for use with javacpp: We call numpy import ourselves to ensure proper loading of + * 3. A proper numpy import for use with javacpp: We call numpy import ourselves to ensure proper loading of * native libraries needed by numpy are allowed to load in the proper order. If we don't do this, - * it causes a variety of issues with running numpy. + * it causes a variety of issues with running numpy. (User must still include "import numpy as np" in their scripts). * - * 5. Various python scripts pre defined on the classpath included right with the java code. + * 4. Various python scripts pre defined on the classpath included right with the java code. * These are auxillary python scripts used for loading classes, pre defining certain kinds of behavior * in order for us to manipulate values within the python memory, as well as pulling them out of memory - * for integration within the internal python executioner. You can see this behavior in {@link #_readOutputs(PythonVariables)} - * as an example. + * for integration within the internal python executioner. * * For more information on how this works, please take a look at the {@link #init()} * method. @@ -94,80 +87,242 @@ import org.nd4j.linalg.io.ClassPathResource; * * * @author Fariz Rahman - * @author Adam Gibson + * @author Adam Gibson */ -/** - * Allows execution of python scripts managed by - * an internal interpreter. - * An end user may specify a python script to run - * via any of the execution methods available in this class. - * - * At static initialization time (when the class is first initialized) - * a number of components are setup: - * 1. The python path. A user may over ride this with the system property {@link #DEFAULT_PYTHON_PATH_PROPERTY} - * - * 2. Since this executioner uses javacpp to manage and run python interpreters underneath the covers, - * a user may also over ride the system property {@link #JAVACPP_PYTHON_APPEND_TYPE} with one of the {@link JavaCppPathType} - * values. This will allow the user to determine whether the javacpp default python path is used at all, and if so - * whether it is appended, prepended, or not used. This behavior is useful when you need to use an external - * python distribution such as anaconda. - * - * 3. A main interpreter: This is the default interpreter to be used with the main thread. - * We may initialize one or more relative to the thread invoking the python code. - * - * 4. A proper numpy import for use with javacpp: We call numpy import ourselves to ensure proper loading of - * native libraries needed by numpy are allowed to load in the proper order. If we don't do this, - * it causes a variety of issues with running numpy. - * - * 5. Various python scripts pre defined on the classpath included right with the java code. - * These are auxillary python scripts used for loading classes, pre defining certain kinds of behavior - * in order for us to manipulate values within the python memory, as well as pulling them out of memory - * for integration within the internal python executioner. You can see this behavior in {@link #_readOutputs(PythonVariables)} - * as an example. - * - * For more information on how this works, please take a look at the {@link #init()} - * method. - * - * Generally, a user defining a python script for use by the python executioner - * will have a set of defined target input values and output values. - * These values should not be present when actually running the script, but just referenced. - * In order to test your python script for execution outside the engine, - * we recommend commenting out a few default values as dummy input values. - * This will allow an end user to test their script before trying to use the server. - * - * In order to get output values out of a python script, all a user has to do - * is define the output variables they want being used in the final output in the actual pipeline. - * For example, if a user wants to return a dictionary, they just have to create a dictionary with that name - * and based on the configured {@link PythonVariables} passed as outputs - * to one of the execution methods, we can pull the values out automatically. - * - * For input definitions, it is similar. You just define the values you want used in - * {@link PythonVariables} and we will automatically generate code for defining those values - * as desired for running. This allows the user to customize values dynamically - * at runtime but reference them by name in a python script. - * - * - * @author Fariz Rahman - * @author Adam Gibson - */ @Slf4j public class PythonExecutioner { - private final static String fileVarName = "_f" + Nd4j.getRandom().nextInt(); - private static boolean init; + + private static AtomicBoolean init = new AtomicBoolean(false); public final static String DEFAULT_PYTHON_PATH_PROPERTY = "org.datavec.python.path"; public final static String JAVACPP_PYTHON_APPEND_TYPE = "org.datavec.python.javacpp.path.append"; public final static String DEFAULT_APPEND_TYPE = "before"; - private static Map interpreters = new java.util.concurrent.ConcurrentHashMap<>(); - private static PyThreadState currentThreadState; - private static PyThreadState mainThreadState; - public final static String ALL_VARIABLES_KEY = "allVariables"; - public final static String MAIN_INTERPRETER_NAME = "main"; - private static String clearVarsCode; + private final static String PYTHON_EXCEPTION_KEY = "__python_exception__"; - private static String currentInterpreter = MAIN_INTERPRETER_NAME; + static { + init(); + } + + private static synchronized void init() { + if (init.get()) { + return; + } + initPythonPath(); + init.set(true); + log.info("CPython: PyEval_InitThreads()"); + PyEval_InitThreads(); + log.info("CPython: Py_InitializeEx()"); + Py_InitializeEx(0); + numpy._import_array(); + } + + private static synchronized void simpleExec(String code) throws PythonException{ + log.debug(code); + log.info("CPython: PyRun_SimpleStringFlag()"); + + int result = PyRun_SimpleStringFlags(code, null); + if (result != 0) { + throw new PythonException("Execution failed, unable to retrieve python exception."); + } + } + + public static boolean validateVariableName(String s) { + if (s.isEmpty()) return false; + if (!Character.isJavaIdentifierStart(s.charAt(0))) return false; + for (int i = 1; i < s.length(); i++) + if (!Character.isJavaIdentifierPart(s.charAt(i))) + return false; + return true; + } + + + /** + * Sets a variable in the global scope of the current context (See @PythonContextManager). + * This is equivalent to `exec("a = b");` where a is the variable name + * and b is the variable value. + * @param varName Name of the python variable being set. Should be a valid python identifier string + * @param pythonObject Value for the python variable + * @throws Exception + */ + public static void setVariable(String varName, PythonObject pythonObject) throws PythonException{ + if (!validateVariableName(varName)){ + throw new PythonException("Invalid variable name: " + varName); + } + Python.globals().set(new PythonObject(varName), pythonObject); + } + + public static void setVariable(String varName, PythonType varType, Object value) throws PythonException { + PythonObject pythonObject; + switch (varType.getName()) { + case STR: + pythonObject = new PythonObject(PythonType.STR.convert(value)); + break; + case INT: + pythonObject = new PythonObject(PythonType.INT.convert(value)); + break; + case FLOAT: + pythonObject = new PythonObject(PythonType.FLOAT.convert(value)); + break; + case BOOL: + pythonObject = new PythonObject(PythonType.BOOL.convert(value)); + break; + case NDARRAY: + pythonObject = new PythonObject(PythonType.NDARRAY.convert(value)); + break; + case LIST: + pythonObject = new PythonObject(PythonType.LIST.convert(value)); + break; + case DICT: + pythonObject = new PythonObject(PythonType.DICT.convert(value)); + break; + case BYTES: + pythonObject = new PythonObject(PythonType.BYTES.convert(value)); + break; + default: + throw new PythonException("Unsupported type: " + varType); + + } + setVariable(varName, pythonObject); + } + + public static void setVariables(PythonVariables pyVars) throws PythonException{ + if (pyVars == null) return; + for (String varName : pyVars.getVariables()) { + setVariable(varName, pyVars.getType(varName), pyVars.getValue(varName)); + } + } + + public static PythonObject getVariable(String varName) { + return Python.globals().attr("get").call(varName); + } + + public static T getVariable(String varName, PythonType varType) throws PythonException{ + PythonObject pythonObject = getVariable(varName); + return varType.toJava(pythonObject); + } + + public static void getVariables(PythonVariables pyVars) throws PythonException { + for (String varName : pyVars.getVariables()) { + pyVars.setValue(varName, getVariable(varName, pyVars.getType(varName))); + } + } + + + private static String getWrappedCode(String code) { + try (InputStream is = new ClassPathResource("pythonexec/pythonexec.py").getInputStream()) { + String base = IOUtils.toString(is, Charset.defaultCharset()); + StringBuffer indentedCode = new StringBuffer(); + for (String split : code.split("\n")) { + indentedCode.append(" " + split + "\n"); + + } + + String out = base.replace(" pass", indentedCode); + return out; + } catch (IOException e) { + throw new IllegalStateException("Unable to read python code!", e); + } + + } + + private static void throwIfExecutionFailed() throws PythonException{ + PythonObject ex = getVariable(PYTHON_EXCEPTION_KEY); + if (ex != null && !ex.toString().isEmpty()){ + setVariable(PYTHON_EXCEPTION_KEY, new PythonObject("")); + throw new PythonException(ex); + } + } + + public static void exec(String code) throws PythonException { + simpleExec(getWrappedCode(code)); + throwIfExecutionFailed(); + } + + public static void exec(String code, PythonVariables outputVariables)throws PythonException { + simpleExec(getWrappedCode(code)); + throwIfExecutionFailed(); + getVariables(outputVariables); + } + + public static void exec(String code, PythonVariables inputVariables, PythonVariables outputVariables) throws PythonException { + setVariables(inputVariables); + simpleExec(getWrappedCode(code)); + throwIfExecutionFailed(); + getVariables(outputVariables); + } + + public static PythonVariables execAndReturnAllVariables(String code) throws PythonException { + simpleExec(getWrappedCode(code)); + throwIfExecutionFailed(); + PythonVariables out = new PythonVariables(); + PythonObject globals = Python.globals(); + PythonObject keysList = Python.list(globals.attr("keys")); + int numKeys = Python.len(keysList).toInt(); + for (int i = 0; i < numKeys; i++) { + PythonObject key = keysList.get(i); + String keyStr = key.toString(); + if (!keyStr.startsWith("_")) { + PythonObject val = globals.get(key); + if (Python.isinstance(val, intType())) { + out.addInt(keyStr, val.toInt()); + } else if (Python.isinstance(val, floatType())) { + out.addFloat(keyStr, val.toDouble()); + } else if (Python.isinstance(val, strType())) { + out.addStr(keyStr, val.toString()); + } else if (Python.isinstance(val, boolType())) { + out.addBool(keyStr, val.toBoolean()); + } else if (Python.isinstance(val, listType())) { + out.addList(keyStr, val.toList().toArray(new Object[0])); + } else if (Python.isinstance(val, dictType())) { + out.addDict(keyStr, val.toMap()); + } + } + } + return out; + + } + + public static PythonVariables getAllVariables() throws PythonException{ + PythonVariables out = new PythonVariables(); + PythonObject globals = Python.globals(); + PythonObject keysList = Python.list(globals.attr("keys").call()); + int numKeys = Python.len(keysList).toInt(); + for (int i = 0; i < numKeys; i++) { + PythonObject key = keysList.get(i); + String keyStr = key.toString(); + if (!keyStr.startsWith("_")) { + PythonObject val = globals.get(key); + if (Python.isinstance(val, intType())) { + out.addInt(keyStr, val.toInt()); + } else if (Python.isinstance(val, floatType())) { + out.addFloat(keyStr, val.toDouble()); + } else if (Python.isinstance(val, strType())) { + out.addStr(keyStr, val.toString()); + } else if (Python.isinstance(val, boolType())) { + out.addBool(keyStr, val.toBoolean()); + } else if (Python.isinstance(val, listType())) { + out.addList(keyStr, val.toList().toArray(new Object[0])); + } else if (Python.isinstance(val, dictType())) { + out.addDict(keyStr, val.toMap()); + } else { + PythonObject np = importModule("numpy"); + if (Python.isinstance(val, np.attr("ndarray"), np.attr("generic"))) { + out.addNDArray(keyStr, val.toNumpy()); + } + } + + } + } + return out; + } + + public static PythonVariables execAndReturnAllVariables(String code, PythonVariables inputs) throws Exception{ + setVariables(inputs); + simpleExec(getWrappedCode(code)); + return getAllVariables(); + } /** * One of a few desired values @@ -178,7 +333,7 @@ public class PythonExecutioner { * NONE: Don't use javacpp's python path at all */ public enum JavaCppPathType { - BEFORE,AFTER,NONE + BEFORE, AFTER, NONE } /** @@ -186,36 +341,36 @@ public class PythonExecutioner { * Generally you can just use the PYTHONPATH environment variable, * but if you need to set it from code, this can work as well. */ - public static synchronized void setPythonPath() { - if(!init) { + + public static synchronized void initPythonPath() { + if (!init.get()) { try { String path = System.getProperty(DEFAULT_PYTHON_PATH_PROPERTY); - if(path == null) { + if (path == null) { log.info("Setting python default path"); File[] packages = numpy.cachePackages(); Py_SetPath(packages); - } - else { + } else { log.info("Setting python path " + path); StringBuffer sb = new StringBuffer(); File[] packages = numpy.cachePackages(); - JavaCppPathType pathAppendValue = JavaCppPathType.valueOf(System.getProperty(JAVACPP_PYTHON_APPEND_TYPE,DEFAULT_APPEND_TYPE).toUpperCase()); - switch(pathAppendValue) { + JavaCppPathType pathAppendValue = JavaCppPathType.valueOf(System.getProperty(JAVACPP_PYTHON_APPEND_TYPE, DEFAULT_APPEND_TYPE).toUpperCase()); + switch (pathAppendValue) { case BEFORE: - for(File cacheDir : packages) { + for (File cacheDir : packages) { sb.append(cacheDir); sb.append(java.io.File.pathSeparator); } sb.append(path); - log.info("Prepending javacpp python path " + sb.toString()); + log.info("Prepending javacpp python path: {}", sb.toString()); break; case AFTER: sb.append(path); - for(File cacheDir : packages) { + for (File cacheDir : packages) { sb.append(cacheDir); sb.append(java.io.File.pathSeparator); } @@ -229,916 +384,15 @@ public class PythonExecutioner { } //prepend the javacpp packages - log.info("Final python path " + sb.toString()); + log.info("Final python path: {}", sb.toString()); Py_SetPath(sb.toString()); } } catch (IOException e) { log.error("Failed to set python path.", e); } - } - else { + } else { throw new IllegalStateException("Unable to reset python path. Already initialized."); } } - - /** - * Initialize the name space and the python execution - * Calling this method more than once will be a no op - */ - public static synchronized void init() { - if(init) { - return; - } - - try(InputStream is = new org.nd4j.linalg.io.ClassPathResource("pythonexec/clear_vars.py").getInputStream()) { - clearVarsCode = IOUtils.toString(new java.io.InputStreamReader(is)); - } catch (java.io.IOException e) { - throw new IllegalStateException("Unable to read pythonexec/clear_vars.py"); - } - - log.info("CPython: PyEval_InitThreads()"); - PyEval_InitThreads(); - log.info("CPython: Py_InitializeEx()"); - Py_InitializeEx(0); - log.info("CPython: PyGILState_Release()"); - init = true; - interpreters.put(MAIN_INTERPRETER_NAME, PyThreadState_Get()); - numpy._import_array(); - applyPatches(); - } - - - /** - * Run {@link #resetInterpreter(String)} - * on all interpreters. - */ - public static void resetAllInterpreters() { - for(String interpreter : interpreters.keySet()) { - resetInterpreter(interpreter); - } - } - - /** - * Reset the main interpreter. - * For more information see {@link #resetInterpreter(String)} - */ - public static void resetMainInterpreter() { - resetInterpreter(MAIN_INTERPRETER_NAME); - } - - /** - * Reset the interpreter with the given name. - * Runs pythonexec/clear_vars.py - * For more information see: - * https://stackoverflow.com/questions/3543833/how-do-i-clear-all-variables-in-the-middle-of-a-python-script - * @param interpreterName the interpreter name to - * reset - */ - public static synchronized void resetInterpreter(String interpreterName) { - Preconditions.checkState(hasInterpreter(interpreterName)); - log.info("Resetting interpreter " + interpreterName); - String oldInterpreter = currentInterpreter; - setInterpreter(interpreterName); - exec("pass"); - //exec(interpreterName); // ?? - setInterpreter(oldInterpreter); - } - - /** - * Clear the non main intrepreters. - */ - public static void clearNonMainInterpreters() { - for(String key : interpreters.keySet()) { - if(!key.equals(MAIN_INTERPRETER_NAME)) { - deleteInterpreter(key); - } - } - } - - public static PythonVariables defaultPythonVariableOutput() { - PythonVariables ret = new PythonVariables(); - ret.add(ALL_VARIABLES_KEY, PythonVariables.Type.DICT); - return ret; - } - - /** - * Return the python path being used. - * @return a string specifying the python path in use - */ - public static String getPythonPath() { - return new BytePointer(Py_GetPath()).getString(); - } - - - static { - setPythonPath(); - init(); - } - - - /* ---------sub-interpreter and gil management-----------*/ - public static void setInterpreter(String interpreterName) { - if (!hasInterpreter(interpreterName)){ - PyThreadState main = PyThreadState_Get(); - PyThreadState ts = Py_NewInterpreter(); - - interpreters.put(interpreterName, ts); - PyThreadState_Swap(main); - } - - currentInterpreter = interpreterName; - } - - /** - * Returns the current interpreter. - * @return - */ - public static String getInterpreter() { - return currentInterpreter; - } - - - public static boolean hasInterpreter(String interpreterName){ - return interpreters.containsKey(interpreterName); - } - - public static void deleteInterpreter(String interpreterName) { - if (interpreterName.equals("main")){ - throw new IllegalArgumentException("Can not delete main interpreter"); - } - - Py_EndInterpreter(interpreters.remove(interpreterName)); - } - - private static synchronized void acquireGIL() { - log.info("acquireGIL()"); - log.info("CPython: PyEval_SaveThread()"); - mainThreadState = PyEval_SaveThread(); - log.info("CPython: PyThreadState_New()"); - currentThreadState = PyThreadState_New(interpreters.get(currentInterpreter).interp()); - log.info("CPython: PyEval_RestoreThread()"); - PyEval_RestoreThread(currentThreadState); - log.info("CPython: PyThreadState_Swap()"); - PyThreadState_Swap(currentThreadState); - - } - - private static synchronized void releaseGIL() { - log.info("CPython: PyEval_SaveThread()"); - PyEval_SaveThread(); - log.info("CPython: PyEval_RestoreThread()"); - PyEval_RestoreThread(mainThreadState); - } - - /* -------------------*/ - /** - * Print the python version to standard out. - */ - public static void printPythonVersion() { - exec("import sys; print(sys.version) sys.stdout.flush();"); - } - - - - private static String inputCode(PythonVariables pyInputs)throws Exception { - String inputCode = ""; - if (pyInputs == null){ - return inputCode; - } - - Map strInputs = pyInputs.getStrVariables(); - Map intInputs = pyInputs.getIntVariables(); - Map floatInputs = pyInputs.getFloatVariables(); - Map ndInputs = pyInputs.getNdVars(); - Map listInputs = pyInputs.getListVariables(); - Map fileInputs = pyInputs.getFileVariables(); - Map> dictInputs = pyInputs.getDictVariables(); - - String[] varNames; - - - varNames = strInputs.keySet().toArray(new String[strInputs.size()]); - for(String varName: varNames) { - Preconditions.checkNotNull(varName,"Var name is null!"); - Preconditions.checkNotNull(varName.isEmpty(),"Var name can not be empty!"); - String varValue = strInputs.get(varName); - //inputCode += varName + "= {}\n"; - if(varValue != null) - inputCode += varName + " = \"\"\"" + escapeStr(varValue) + "\"\"\"\n"; - else { - inputCode += varName + " = ''\n"; - } - } - - varNames = intInputs.keySet().toArray(new String[intInputs.size()]); - for(String varName: varNames) { - Long varValue = intInputs.get(varName); - if(varValue != null) - inputCode += varName + " = " + varValue.toString() + "\n"; - else { - inputCode += " = 0\n"; - } - } - - varNames = dictInputs.keySet().toArray(new String[dictInputs.size()]); - for(String varName: varNames) { - Map varValue = dictInputs.get(varName); - if(varValue != null) { - throw new IllegalArgumentException("Unable to generate input code for dictionaries."); - } - else { - inputCode += " = {}\n"; - } - } - - varNames = floatInputs.keySet().toArray(new String[floatInputs.size()]); - for(String varName: varNames){ - Double varValue = floatInputs.get(varName); - if(varValue != null) - inputCode += varName + " = " + varValue.toString() + "\n"; - else { - inputCode += varName + " = 0.0\n"; - } - } - - varNames = listInputs.keySet().toArray(new String[listInputs.size()]); - for (String varName: varNames) { - Object[] varValue = listInputs.get(varName); - if(varValue != null) { - String listStr = jArrayToPyString(varValue); - inputCode += varName + " = " + listStr + "\n"; - } - else { - inputCode += varName + " = []\n"; - } - - } - - varNames = fileInputs.keySet().toArray(new String[fileInputs.size()]); - for(String varName: varNames) { - String varValue = fileInputs.get(varName); - if(varValue != null) - inputCode += varName + " = \"\"\"" + escapeStr(varValue) + "\"\"\"\n"; - else { - inputCode += varName + " = ''\n"; - } - } - - if (!ndInputs.isEmpty()) { - inputCode += "import ctypes\n\nimport sys\nimport numpy as np\n"; - varNames = ndInputs.keySet().toArray(new String[ndInputs.size()]); - - String converter = "__arr_converter = lambda addr, shape, type: np.ctypeslib.as_array(ctypes.cast(addr, ctypes.POINTER(type)), shape)\n"; - inputCode += converter; - for(String varName: varNames) { - NumpyArray npArr = ndInputs.get(varName); - if(npArr == null) - continue; - - npArr = npArr.copy(); - String shapeStr = "("; - for (long d: npArr.getShape()){ - shapeStr += d + ","; - } - shapeStr += ")"; - String code; - String ctype; - if (npArr.getDtype() == DataType.FLOAT) { - - ctype = "ctypes.c_float"; - } - else if (npArr.getDtype() == DataType.DOUBLE) { - ctype = "ctypes.c_double"; - } - else if (npArr.getDtype() == DataType.SHORT) { - ctype = "ctypes.c_int16"; - } - else if (npArr.getDtype() == DataType.INT) { - ctype = "ctypes.c_int32"; - } - else if (npArr.getDtype() == DataType.LONG){ - ctype = "ctypes.c_int64"; - } - else{ - throw new Exception("Unsupported data type: " + npArr.getDtype().toString() + "."); - } - - code = "__arr_converter(" + npArr.getAddress() + "," + shapeStr + "," + ctype + ")"; - code = varName + "=" + code + "\n"; - inputCode += code; - } - - } - return inputCode; - } - - - private static synchronized void _readOutputs(PythonVariables pyOutputs) throws IOException { - File f = new File(getTempFile()); - Preconditions.checkState(f.exists(),"File " + f.getAbsolutePath() + " failed to get written for reading outputs!"); - String json = FileUtils.readFileToString(f, Charset.defaultCharset()); - log.info("Executioner output: "); - log.info(json); - f.delete(); - - if(json.isEmpty()) { - log.warn("No json found fore reading outputs. Returning."); - return; - } - - try { - JSONObject jobj = new JSONObject(json); - for (String varName: pyOutputs.getVariables()) { - PythonVariables.Type type = pyOutputs.getType(varName); - if (type == PythonVariables.Type.NDARRAY) { - JSONObject varValue = (JSONObject)jobj.get(varName); - long address = (Long) varValue.getLong("address"); - JSONArray shapeJson = (JSONArray) varValue.get("shape"); - JSONArray stridesJson = (JSONArray) varValue.get("strides"); - long[] shape = jsonArrayToLongArray(shapeJson); - long[] strides = jsonArrayToLongArray(stridesJson); - String dtypeName = (String)varValue.get("dtype"); - DataType dtype; - if (dtypeName.equals("float64")) { - dtype = DataType.DOUBLE; - } - else if (dtypeName.equals("float32")) { - dtype = DataType.FLOAT; - } - else if (dtypeName.equals("int16")) { - dtype = DataType.SHORT; - } - else if (dtypeName.equals("int32")) { - dtype = DataType.INT; - } - else if (dtypeName.equals("int64")) { - dtype = DataType.LONG; - } - else{ - throw new Exception("Unsupported array type " + dtypeName + "."); - } - - pyOutputs.setValue(varName, new NumpyArray(address, shape, strides, dtype, true)); - - } - else if (type == PythonVariables.Type.LIST) { - JSONArray varValue = (JSONArray) jobj.get(varName); - pyOutputs.setValue(varName, varValue); - } - else if (type == PythonVariables.Type.DICT) { - Map map = toMap((JSONObject) jobj.get(varName)); - pyOutputs.setValue(varName, map); - - } - else{ - pyOutputs.setValue(varName, jobj.get(varName)); - } - } - } - catch (Exception e){ - throw new RuntimeException(e); - } - } - - - - - private static synchronized void _exec(String code) { - log.debug(code); - log.info("CPython: PyRun_SimpleStringFlag()"); - - int result = PyRun_SimpleStringFlags(code, null); - if (result != 0) { - log.info("CPython: PyErr_Print"); - PyErr_Print(); - throw new RuntimeException("exec failed"); - } - } - - private static synchronized void _exec_wrapped(String code) { - _exec(getWrappedCode(code)); - } - - /** - * Executes python code. Also manages python thread state. - * @param code the code to run - */ - - public static void exec(String code) { - code = getWrappedCode(code); - if(code.contains("import numpy") && !getInterpreter().equals("main")) {// FIXME - throw new IllegalArgumentException("Unable to execute numpy on sub interpreter. See https://mail.python.org/pipermail/python-dev/2019-January/156095.html for the reasons."); - } - - acquireGIL(); - _exec(code); - log.info("Exec done"); - releaseGIL(); - } - - private static boolean _hasGlobalVariable(String varName){ - PyObject mainModule = PyImport_AddModule("__main__"); - PyObject var = PyObject_GetAttrString(mainModule, varName); - boolean hasVar = var != null; - Py_DecRef(var); - return hasVar; - } - - /** - * Executes python code and looks for methods setup() and run() - * If both setup() and run() are found, both are executed for the first - * time and for subsequent calls only run() is executed. - */ - public static void execWithSetupAndRun(String code) { - code = getWrappedCode(code); - if(code.contains("import numpy") && !getInterpreter().equals("main")) { // FIXME - throw new IllegalArgumentException("Unable to execute numpy on sub interpreter. See https://mail.python.org/pipermail/python-dev/2019-January/156095.html for the reasons."); - } - - acquireGIL(); - _exec(code); - if (_hasGlobalVariable("setup") && _hasGlobalVariable("run")){ - log.debug("setup() and run() methods found."); - if (!_hasGlobalVariable("__setup_done__")){ - log.debug("Calling setup()..."); - _exec("setup()"); - _exec("__setup_done__ = True"); - } - log.debug("Calling run()..."); - _exec("run()"); - } - log.info("Exec done"); - releaseGIL(); - } - - /** - * Executes python code and looks for methods setup() and run() - * If both setup() and run() are found, both are executed for the first - * time and for subsequent calls only run() is executed. - */ - public static void execWithSetupAndRun(String code, PythonVariables pyOutputs) { - code = getWrappedCode(code); - if(code.contains("import numpy") && !getInterpreter().equals("main")) { // FIXME - throw new IllegalArgumentException("Unable to execute numpy on sub interpreter. See https://mail.python.org/pipermail/python-dev/2019-January/156095.html for the reasons."); - } - - acquireGIL(); - _exec(code); - if (_hasGlobalVariable("setup") && _hasGlobalVariable("run")){ - log.debug("setup() and run() methods found."); - if (!_hasGlobalVariable("__setup_done__")){ - log.debug("Calling setup()..."); - _exec("setup()"); - _exec("__setup_done__ = True"); - } - log.debug("Calling run()..."); - _exec("__out = run();for (k,v) in __out.items(): globals()[k]=v"); - } - log.info("Exec done"); - try { - - _readOutputs(pyOutputs); - - } catch (IOException e) { - log.error("Failed to read outputs", e); - } - - releaseGIL(); - } - - /** - * Run the given code with the given python outputs - * @param code the code to run - * @param pyOutputs the outputs to run - */ - public static void exec(String code, PythonVariables pyOutputs) { - - exec(code + '\n' + outputCode(pyOutputs)); - try { - - _readOutputs(pyOutputs); - - } catch (IOException e) { - log.error("Failed to read outputs", e); - } - - releaseGIL(); - } - - - /** - * Execute the given python code with the given - * {@link PythonVariables} as inputs and outputs - * @param code the code to run - * @param pyInputs the inputs to the code - * @param pyOutputs the outputs to the code - * @throws Exception - */ - public static void exec(String code, PythonVariables pyInputs, PythonVariables pyOutputs) throws Exception { - String inputCode = inputCode(pyInputs); - exec(inputCode + code, pyOutputs); - } - - /** - * Execute the given python code - * with the {@link PythonVariables} - * inputs and outputs for storing the values - * specified by the user and needed by the user - * as output - * @param code the python code to execute - * @param pyInputs the python variables input in to the python script - * @param pyOutputs the python variables output returned by the python script - * @throws Exception - */ - public static void execWithSetupAndRun(String code, PythonVariables pyInputs, PythonVariables pyOutputs) throws Exception { - String inputCode = inputCode(pyInputs); - code = inputCode +code; - code = getWrappedCode(code); - if(code.contains("import numpy") && !getInterpreter().equals("main")) { // FIXME - throw new IllegalArgumentException("Unable to execute numpy on sub interpreter. See https://mail.python.org/pipermail/python-dev/2019-January/156095.html for the reasons."); - } - acquireGIL(); - _exec(code); - if (_hasGlobalVariable("setup") && _hasGlobalVariable("run")){ - log.debug("setup() and run() methods found."); - if (!_hasGlobalVariable("__setup_done__")){ - releaseGIL(); // required - acquireGIL(); - log.debug("Calling setup()..."); - _exec("setup()"); - _exec("__setup_done__ = True"); - }else{ - log.debug("setup() already called once."); - } - log.debug("Calling run()..."); - releaseGIL(); // required - acquireGIL(); - _exec("import inspect\n"+ - "__out = run(**{k:globals()[k]for k in inspect.getfullargspec(run).args})\n"+ - "globals().update(__out)"); - } - releaseGIL(); // required - acquireGIL(); - _exec(outputCode(pyOutputs)); - log.info("Exec done"); - try { - - _readOutputs(pyOutputs); - - } catch (IOException e) { - log.error("Failed to read outputs", e); - } - - releaseGIL(); - } - - - - private static String interpreterNameFromTransform(PythonTransform transform){ - return transform.getName().replace("-", "_"); - } - - - /** - * Run a {@link PythonTransform} with the given inputs - * @param transform the transform to run - * @param inputs the inputs to the transform - * @return the output variables - * @throws Exception - */ - public static PythonVariables exec(PythonTransform transform, PythonVariables inputs)throws Exception { - String name = interpreterNameFromTransform(transform); - setInterpreter(name); - Preconditions.checkNotNull(transform.getOutputs(),"Transform outputs were null!"); - exec(transform.getCode(), inputs, transform.getOutputs()); - return transform.getOutputs(); - } - public static PythonVariables execWithSetupAndRun(PythonTransform transform, PythonVariables inputs)throws Exception { - String name = interpreterNameFromTransform(transform); - setInterpreter(name); - Preconditions.checkNotNull(transform.getOutputs(),"Transform outputs were null!"); - execWithSetupAndRun(transform.getCode(), inputs, transform.getOutputs()); - return transform.getOutputs(); - } - - - /** - * Run the code and return the outputs - * @param code the code to run - * @return all python variables - */ - public static PythonVariables execAndReturnAllVariables(String code) { - exec(code + '\n' + outputCodeForAllVariables()); - PythonVariables allVars = new PythonVariables(); - allVars.addDict(ALL_VARIABLES_KEY); - try { - _readOutputs(allVars); - }catch (IOException e) { - log.error("Failed to read outputs", e); - } - - return expandInnerDict(allVars, ALL_VARIABLES_KEY); - } - public static PythonVariables execWithSetupRunAndReturnAllVariables(String code) { - execWithSetupAndRun(code + '\n' + outputCodeForAllVariables()); - PythonVariables allVars = new PythonVariables(); - allVars.addDict(ALL_VARIABLES_KEY); - try { - _readOutputs(allVars); - }catch (IOException e) { - log.error("Failed to read outputs", e); - } - - return expandInnerDict(allVars, ALL_VARIABLES_KEY); - } - - /** - * - * @param code code string to run - * @param pyInputs python input variables - * @return all python variables - * @throws Exception throws when there's an issue while execution of python code - */ - public static PythonVariables execAndReturnAllVariables(String code, PythonVariables pyInputs) throws Exception { - String inputCode = inputCode(pyInputs); - return execAndReturnAllVariables(inputCode + code); - } - public static PythonVariables execWithSetupRunAndReturnAllVariables(String code, PythonVariables pyInputs) throws Exception { - String inputCode = inputCode(pyInputs); - return execWithSetupRunAndReturnAllVariables(inputCode + code); - } - - - /** - * Evaluate a string based on the - * current variable name. - * This variable named needs to be present - * or defined earlier in python code - * in order to pull out the values. - * - * @param varName the variable name to evaluate - * @return the evaluated value - */ - public static String evalString(String varName) { - PythonVariables vars = new PythonVariables(); - vars.addStr(varName); - exec("print('')", vars); - return vars.getStrValue(varName); - } - - - - /** - * Evaluate a string based on the - * current variable name. - * This variable named needs to be present - * or defined earlier in python code - * in order to pull out the values. - * - * @param varName the variable name to evaluate - * @return the evaluated value - */ - public static long evalInteger(String varName) { - PythonVariables vars = new PythonVariables(); - vars.addInt(varName); - exec("print('')", vars); - return vars.getIntValue(varName); - } - - - /** - * Evaluate a string based on the - * current variable name. - * This variable named needs to be present - * or defined earlier in python code - * in order to pull out the values. - * - * @param varName the variable name to evaluate - * @return the evaluated value - */ - public static Double evalFloat(String varName) { - PythonVariables vars = new PythonVariables(); - vars.addFloat(varName); - exec("print('')", vars); - return vars.getFloatValue(varName); - } - - - /** - * Evaluate a string based on the - * current variable name. - * This variable named needs to be present - * or defined earlier in python code - * in order to pull out the values. - * - * @param varName the variable name to evaluate - * @return the evaluated value - */ - public static Object[] evalList(String varName) { - PythonVariables vars = new PythonVariables(); - vars.addList(varName); - exec("pass", vars); - return vars.getListValue(varName); - } - - - /** - * Evaluate a string based on the - * current variable name. - * This variable named needs to be present - * or defined earlier in python code - * in order to pull out the values. - * - * @param varName the variable name to evaluate - * @return the evaluated value - */ - public static Map evalDict(String varName) { - PythonVariables vars = new PythonVariables(); - vars.addDict(varName); - exec("pass", vars); - return vars.getDictValue(varName); - } - - - /** - * Evaluate a string based on the - * current variable name. - * This variable named needs to be present - * or defined earlier in python code - * in order to pull out the values. - * - * @param varName the variable name to evaluate - * @return the evaluated value - */ - public static NumpyArray evalNdArray(String varName) { - PythonVariables vars = new PythonVariables(); - vars.addNDArray(varName); - exec("pass", vars); - return vars.getNDArrayValue(varName); - } - - private static String outputVarName() { - return "_" + Thread.currentThread().getId() + "_" + currentInterpreter + "_out"; - } - - private static String outputCode(PythonVariables pyOutputs) { - if (pyOutputs == null){ - return ""; - } - - String outputCode = "import json\n"; - String outputFunctions; - try(BufferedInputStream bufferedInputStream = new BufferedInputStream(new ClassPathResource("pythonexec/serialize_array.py").getInputStream())) { - outputFunctions= IOUtils.toString(bufferedInputStream,Charset.defaultCharset()); - outputCode += outputFunctions; - outputCode += "\n"; - } catch (IOException e) { - throw new IllegalStateException("Unable to read python file pythonexec/serialize_arrays.py from classpath"); - } - - outputCode += outputVarName() + " = __serialize_dict({"; - String[] varNames = pyOutputs.getVariables(); - for (String varName: varNames) { - outputCode += "\"" + varName + "\": " + varName + ","; - } - - - if (varNames.length > 0) - outputCode = outputCode.substring(0, outputCode.length() - 1); - outputCode += "})"; - outputCode += "\nwith open('" + getTempFile() + "', 'w') as " + fileVarName + ":" + fileVarName + ".write(" + outputVarName() + ")"; - - - return outputCode; - - } - - private static String jArrayToPyString(Object[] array) { - String str = "["; - for (int i = 0; i < array.length; i++){ - Object obj = array[i]; - if (obj instanceof Object[]){ - str += jArrayToPyString((Object[])obj); - } - else if (obj instanceof String){ - str += "\"" + obj + "\""; - } - else{ - str += obj.toString().replace("\"", "\\\""); - } - if (i < array.length - 1){ - str += ","; - } - - } - str += "]"; - return str; - } - - private static String escapeStr(String str) { - if(str == null) - return null; - str = str.replace("\\", "\\\\"); - str = str.replace("\"\"\"", "\\\"\\\"\\\""); - return str; - } - - private static String getWrappedCode(String code) { - try(InputStream is = new ClassPathResource("pythonexec/pythonexec.py").getInputStream()) { - String base = IOUtils.toString(is, Charset.defaultCharset()); - StringBuffer indentedCode = new StringBuffer(); - for(String split : code.split("\n")) { - indentedCode.append(" " + split + "\n"); - - } - - String out = base.replace(" pass",indentedCode); - return out; - } catch (IOException e) { - throw new IllegalStateException("Unable to read python code!",e); - } - - } - - - - private static String getTempFile() { - String ret = "temp_" + Thread.currentThread().getId() + "_" + currentInterpreter + ".json"; - log.info(ret); - return ret; - } - - - private static String outputCodeForAllVariables() { - String outputCode = ""; - try(BufferedInputStream bufferedInputStream = new BufferedInputStream(new ClassPathResource("pythonexec/outputcode.py").getInputStream())) { - outputCode += IOUtils.toString(bufferedInputStream,Charset.defaultCharset()).replace("f2",fileVarName); - outputCode += "\n"; - } catch (IOException e) { - throw new IllegalStateException("Unable to read python file pythonexec/outputcode.py from classpath"); - } - - outputCode += String.format("vars = {key:value for (key,value) in locals().items() if not key.startswith('_') and key is not '%s' and key is not 'loc' and type(value) in (list, dict, str, int, float, bool, type(None))}\n",fileVarName); - outputCode += String.format("with open('" + getTempFile() + "', 'w') as %s:json.dump({",fileVarName); - outputCode +=String.format( "\"" + ALL_VARIABLES_KEY + "\"" + ": vars}, %s)\n",fileVarName); - return outputCode; - } - - - /*-----monkey patch for numpy-----*/ - private static List _getPatches() { - exec("import numpy as np"); - exec( "__overrides_path = np.core.overrides.__file__"); - exec("__random_path = np.random.__file__"); - - List patches = new ArrayList<>(); - - patches.add(new String[]{ - "pythonexec/patch0.py", - evalString("__overrides_path") - }); - patches.add(new String[]{ - "pythonexec/patch1.py", - evalString("__random_path") - }); - - return patches; - } - - private static void _applyPatch(String src, String dest){ - try(InputStream is = new ClassPathResource(src).getInputStream()) { - String patch = IOUtils.toString(is, Charset.defaultCharset()); - FileUtils.write(new File(dest), patch, "utf-8"); - } - catch(IOException e){ - log.warn("Error patching numpy: " + e); - } - } - - private static boolean _checkPatchApplied(String dest) { - try { - return FileUtils.readFileToString(new File(dest), "utf-8").startsWith("#patch"); - } catch (IOException e) { - return false; - } - } - - private static void applyPatches() { - // We patch numpy for partial support of multiple interpreters - for (String[] patch : _getPatches()){ - if (_checkPatchApplied(patch[1])){ - log.info("Patch already applied for " + patch[1]); - } - else{ - _applyPatch(patch[0], patch[1]); - log.info("Applied patch for " + patch[1]); - } - } - for (String[] patch: _getPatches()){ - if (!_checkPatchApplied(patch[1])){ - log.warn("Error patching numpy"); - } - } - } -} \ No newline at end of file +} diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonGIL.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonGIL.java new file mode 100644 index 000000000..d8afa2836 --- /dev/null +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonGIL.java @@ -0,0 +1,68 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +package org.datavec.python; + +import lombok.extern.slf4j.Slf4j; +import org.bytedeco.cpython.PyThreadState; + +import static org.bytedeco.cpython.global.python.*; +import static org.bytedeco.cpython.global.python.PyEval_RestoreThread; +import static org.bytedeco.cpython.global.python.PyEval_SaveThread; + + +@Slf4j +public class PythonGIL implements AutoCloseable { + private static PyThreadState mainThreadState; + + static { + log.debug("CPython: PyThreadState_Get()"); + mainThreadState = PyThreadState_Get(); + } + + private PythonGIL() { + acquire(); + } + + @Override + public void close() { + release(); + } + + public static PythonGIL lock() { + return new PythonGIL(); + } + + private static synchronized void acquire() { + log.debug("acquireGIL()"); + log.debug("CPython: PyEval_SaveThread()"); + mainThreadState = PyEval_SaveThread(); + log.debug("CPython: PyThreadState_New()"); + PyThreadState ts = PyThreadState_New(mainThreadState.interp()); + log.debug("CPython: PyEval_RestoreThread()"); + PyEval_RestoreThread(ts); + log.debug("CPython: PyThreadState_Swap()"); + PyThreadState_Swap(ts); + } + + private static synchronized void release() { + log.debug("CPython: PyEval_SaveThread()"); + PyEval_SaveThread(); + log.debug("CPython: PyEval_RestoreThread()"); + PyEval_RestoreThread(mainThreadState); + } +} diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonJob.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonJob.java new file mode 100644 index 000000000..c50c9bb9e --- /dev/null +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonJob.java @@ -0,0 +1,171 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +package org.datavec.python; + +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import javax.annotation.Nonnull; +import java.util.HashMap; +import java.util.Map; + + +@Data +@NoArgsConstructor +/** + * PythonJob is the right abstraction for executing multiple python scripts + * in a multi thread stateful environment. The setup-and-run mode allows your + * "setup" code (imports, model loading etc) to be executed only once. + */ +public class PythonJob { + + private String code; + private String name; + private String context; + private boolean setupRunMode; + private PythonObject runF; + + static { + new PythonExecutioner(); + } + + @Builder + /** + * @param name Name for the python job. + * @param code Python code. + * @param setupRunMode If true, the python code is expected to have two methods: setup(), which takes no arguments, + * and run() which takes some or no arguments. setup() method is executed once, + * and the run() method is called with the inputs(if any) per transaction, and is expected to return a dictionary + * mapping from output variable names (str) to output values. + * If false, the full script is run on each transaction and the output variables are obtained from the global namespace + * after execution. + */ + public PythonJob(@Nonnull String name, @Nonnull String code, boolean setupRunMode) throws Exception { + this.name = name; + this.code = code; + this.setupRunMode = setupRunMode; + context = "__job_" + name; + if (PythonContextManager.hasContext(context)) { + throw new PythonException("Unable to create python job " + name + ". Context " + context + " already exists!"); + } + if (setupRunMode) setup(); + } + + + /** + * Clears all variables in current context and calls setup() + */ + public void clearState() throws Exception { + String context = this.context; + PythonContextManager.setContext("main"); + PythonContextManager.deleteContext(context); + this.context = context; + setup(); + } + + public void setup() throws Exception { + try (PythonGIL gil = PythonGIL.lock()) { + PythonContextManager.setContext(context); + PythonObject runF = PythonExecutioner.getVariable("run"); + if (runF.isNone() || !Python.callable(runF)) { + PythonExecutioner.exec(code); + runF = PythonExecutioner.getVariable("run"); + } + if (runF.isNone() || !Python.callable(runF)) { + throw new PythonException("run() method not found! " + + "If a PythonJob is created with 'setup and run' " + + "mode enabled, the associated python code is " + + "expected to contain a run() method " + + "(with or without arguments)."); + } + this.runF = runF; + PythonObject setupF = PythonExecutioner.getVariable("setup"); + if (!setupF.isNone()) { + setupF.call(); + } + } + } + + public void exec(PythonVariables inputs, PythonVariables outputs) throws Exception { + try (PythonGIL gil = PythonGIL.lock()) { + PythonContextManager.setContext(context); + if (!setupRunMode) { + PythonExecutioner.exec(code, inputs, outputs); + return; + } + PythonExecutioner.setVariables(inputs); + + PythonObject inspect = Python.importModule("inspect"); + PythonObject getfullargspec = inspect.attr("getfullargspec"); + PythonObject argspec = getfullargspec.call(runF); + PythonObject argsList = argspec.attr("args"); + PythonObject runargs = Python.dict(); + int argsCount = Python.len(argsList).toInt(); + for (int i = 0; i < argsCount; i++) { + PythonObject arg = argsList.get(i); + PythonObject val = Python.globals().get(arg); + if (val.isNone()) { + throw new PythonException("Input value not received for run() argument: " + arg.toString()); + } + runargs.set(arg, val); + } + PythonObject outDict = runF.callWithKwargs(runargs); + Python.globals().attr("update").call(outDict); + + PythonExecutioner.getVariables(outputs); + inspect.del(); + getfullargspec.del(); + argspec.del(); + runargs.del(); + } + } + + public PythonVariables execAndReturnAllVariables(PythonVariables inputs) throws Exception { + try (PythonGIL gil = PythonGIL.lock()) { + PythonContextManager.setContext(context); + if (!setupRunMode) { + return PythonExecutioner.execAndReturnAllVariables(code, inputs); + } + PythonExecutioner.setVariables(inputs); + PythonObject inspect = Python.importModule("inspect"); + PythonObject getfullargspec = inspect.attr("getfullargspec"); + PythonObject argspec = getfullargspec.call(runF); + PythonObject argsList = argspec.attr("args"); + PythonObject runargs = Python.dict(); + int argsCount = Python.len(argsList).toInt(); + for (int i = 0; i < argsCount; i++) { + PythonObject arg = argsList.get(i); + PythonObject val = Python.globals().get(arg); + if (val.isNone()) { + throw new PythonException("Input value not received for run() argument: " + arg.toString()); + } + runargs.set(arg, val); + } + PythonObject outDict = runF.callWithKwargs(runargs); + Python.globals().attr("update").call(outDict); + inspect.del(); + getfullargspec.del(); + argspec.del(); + runargs.del(); + return PythonExecutioner.getAllVariables(); + } + } + + +} diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonObject.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonObject.java new file mode 100644 index 000000000..f1d54168b --- /dev/null +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonObject.java @@ -0,0 +1,554 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +package org.datavec.python; + + +import org.bytedeco.cpython.PyObject; +import org.bytedeco.javacpp.BytePointer; +import org.bytedeco.javacpp.Pointer; +import org.json.JSONArray; +import org.json.JSONObject; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.nativeblas.NativeOpsHolder; + +import java.util.*; + +import static org.bytedeco.cpython.global.python.*; +import static org.bytedeco.cpython.global.python.PyObject_SetItem; + +/** + * Swift like python wrapper for J + * + * @author Fariz Rahman + */ + +public class PythonObject { + private PyObject nativePythonObject; + + static { + new PythonExecutioner(); + } + + private static Map _getNDArraySerializer() { + Map ndarraySerializer = new HashMap<>(); + PythonObject lambda = Python.eval( + "lambda x: " + + "{'address':" + + "x.__array_interface__['data'][0]," + + "'shape':x.shape,'strides':x.strides," + + "'dtype': str(x.dtype),'_is_numpy_array': True}" + + " if str(type(x))== \"\" else x"); + ndarraySerializer.put("default", + lambda); + return ndarraySerializer; + + } + + public PythonObject(PyObject pyObject) { + nativePythonObject = pyObject; + } + + public PythonObject(INDArray npArray) { + this(new NumpyArray(npArray)); + } + + public PythonObject(BytePointer bp){ + nativePythonObject = PyByteArray_FromStringAndSize(bp, bp.capacity()); + } + + public PythonObject(NumpyArray npArray) { + PyObject ctypes = PyImport_ImportModule("ctypes"); + PyObject np = PyImport_ImportModule("numpy"); + PyObject ctype; + switch (npArray.getDtype()) { + case DOUBLE: + ctype = PyObject_GetAttrString(ctypes, "c_double"); + break; + case FLOAT: + ctype = PyObject_GetAttrString(ctypes, "c_float"); + break; + case LONG: + ctype = PyObject_GetAttrString(ctypes, "c_int64"); + break; + case INT: + ctype = PyObject_GetAttrString(ctypes, "c_int32"); + break; + case SHORT: + ctype = PyObject_GetAttrString(ctypes, "c_int16"); + break; + case UINT16: + ctype = PyObject_GetAttrString(ctypes, "c_uint16"); + break; + case UINT32: + ctype = PyObject_GetAttrString(ctypes, "c_uint32"); + break; + case UINT64: + ctype = PyObject_GetAttrString(ctypes, "c_uint64"); + break; + case BOOL: + ctype = PyObject_GetAttrString(ctypes, "c_bool"); + break; + case BYTE: + ctype = PyObject_GetAttrString(ctypes, "c_byte"); + break; + case UBYTE: + ctype = PyObject_GetAttrString(ctypes, "c_ubyte"); + break; + default: + throw new RuntimeException("Unsupported dtype: " + npArray.getDtype()); + } + + PyObject ctypesPointer = PyObject_GetAttrString(ctypes, "POINTER"); + PyObject argsTuple = PyTuple_New(1); + PyTuple_SetItem(argsTuple, 0, ctype); + PyObject ptrType = PyObject_Call(ctypesPointer, argsTuple, null); + + PyObject cast = PyObject_GetAttrString(ctypes, "cast"); + PyObject address = PyLong_FromLong(npArray.getAddress()); + PyObject argsTuple2 = PyTuple_New(2); + PyTuple_SetItem(argsTuple2, 0, address); + PyTuple_SetItem(argsTuple2, 1, ptrType); + PyObject ptr = PyObject_Call(cast, argsTuple2, null); + PyObject shapeTuple = PyTuple_New(npArray.getShape().length); + for (int i = 0; i < npArray.getShape().length; i++) { + PyObject dim = PyLong_FromLong(npArray.getShape()[i]); + PyTuple_SetItem(shapeTuple, i, dim); + Py_DecRef(dim); + } + PyObject ctypesLib = PyObject_GetAttrString(np, "ctypeslib"); + PyObject asArray = PyObject_GetAttrString(ctypesLib, "as_array"); + PyObject argsTuple3 = PyTuple_New(2); + PyTuple_SetItem(argsTuple3, 0, ptr); + PyTuple_SetItem(argsTuple3, 1, shapeTuple); + nativePythonObject = PyObject_Call(asArray, argsTuple3, null); + + Py_DecRef(ctypesPointer); + Py_DecRef(ctypesLib); + Py_DecRef(argsTuple); + Py_DecRef(argsTuple2); + Py_DecRef(argsTuple3); + Py_DecRef(cast); + Py_DecRef(asArray); + + } + + /*---primitve constructors---*/ + public PyObject getNativePythonObject() { + return nativePythonObject; + } + + public PythonObject(String data) { + nativePythonObject = PyUnicode_FromString(data); + } + + public PythonObject(int data) { + nativePythonObject = PyLong_FromLong((long) data); + } + + public PythonObject(long data) { + nativePythonObject = PyLong_FromLong(data); + } + + public PythonObject(double data) { + nativePythonObject = PyFloat_FromDouble(data); + } + + public PythonObject(boolean data) { + nativePythonObject = PyBool_FromLong(data ? 1 : 0); + } + + private static PythonObject j2pyObject(Object item) { + if (item instanceof PythonObject) { + return (PythonObject) item; + } else if (item instanceof PyObject) { + return new PythonObject((PyObject) item); + } else if (item instanceof INDArray) { + return new PythonObject((INDArray) item); + } else if (item instanceof NumpyArray) { + return new PythonObject((NumpyArray) item); + } else if (item instanceof List) { + return new PythonObject((List) item); + } else if (item instanceof Object[]) { + return new PythonObject((Object[]) item); + } else if (item instanceof Map) { + return new PythonObject((Map) item); + } else if (item instanceof String) { + return new PythonObject((String) item); + } else if (item instanceof Double) { + return new PythonObject((Double) item); + } else if (item instanceof Float) { + return new PythonObject((Float) item); + } else if (item instanceof Long) { + return new PythonObject((Long) item); + } else if (item instanceof Integer) { + return new PythonObject((Integer) item); + } else if (item instanceof Boolean) { + return new PythonObject((Boolean) item); + } else if (item instanceof Pointer){ + return new PythonObject(new BytePointer((Pointer)item)); + } else { + throw new RuntimeException("Unsupported item in list: " + item); + } + } + + public PythonObject(Object[] data) { + PyObject pyList = PyList_New((long) data.length); + for (int i = 0; i < data.length; i++) { + PyList_SetItem(pyList, i, j2pyObject(data[i]).nativePythonObject); + } + nativePythonObject = pyList; + } + + public PythonObject(List data) { + PyObject pyList = PyList_New((long) data.size()); + for (int i = 0; i < data.size(); i++) { + PyList_SetItem(pyList, i, j2pyObject(data.get(i)).nativePythonObject); + } + nativePythonObject = pyList; + } + + public PythonObject(Map data) { + PyObject pyDict = PyDict_New(); + for (Object k : data.keySet()) { + PythonObject pyKey; + if (k instanceof PythonObject) { + pyKey = (PythonObject) k; + } else if (k instanceof String) { + pyKey = new PythonObject((String) k); + } else if (k instanceof Double) { + pyKey = new PythonObject((Double) k); + } else if (k instanceof Float) { + pyKey = new PythonObject((Float) k); + } else if (k instanceof Long) { + pyKey = new PythonObject((Long) k); + } else if (k instanceof Integer) { + pyKey = new PythonObject((Integer) k); + } else if (k instanceof Boolean) { + pyKey = new PythonObject((Boolean) k); + } else { + throw new RuntimeException("Unsupported key in map: " + k.getClass()); + } + Object v = data.get(k); + PythonObject pyVal; + if (v instanceof PythonObject) { + pyVal = (PythonObject) v; + } else if (v instanceof PyObject) { + pyVal = new PythonObject((PyObject) v); + } else if (v instanceof INDArray) { + pyVal = new PythonObject((INDArray) v); + } else if (v instanceof NumpyArray) { + pyVal = new PythonObject((NumpyArray) v); + } else if (v instanceof Map) { + pyVal = new PythonObject((Map) v); + } else if (v instanceof List) { + pyVal = new PythonObject((List) v); + } else if (v instanceof String) { + pyVal = new PythonObject((String) v); + } else if (v instanceof Double) { + pyVal = new PythonObject((Double) v); + } else if (v instanceof Float) { + pyVal = new PythonObject((Float) v); + } else if (v instanceof Long) { + pyVal = new PythonObject((Long) v); + } else if (v instanceof Integer) { + pyVal = new PythonObject((Integer) v); + } else if (v instanceof Boolean) { + pyVal = new PythonObject((Boolean) v); + } else { + throw new RuntimeException("Unsupported value in map: " + k.getClass()); + } + + PyDict_SetItem(pyDict, pyKey.nativePythonObject, pyVal.nativePythonObject); + + } + nativePythonObject = pyDict; + } + + + /*------*/ + + private static String pyObjectToString(PyObject pyObject) { + PyObject repr = PyObject_Str(pyObject); + PyObject str = PyUnicode_AsEncodedString(repr, "utf-8", "~E~"); + String jstr = PyBytes_AsString(str).getString(); + Py_DecRef(repr); + Py_DecRef(str); + return jstr; + } + + public String toString() { + return pyObjectToString(nativePythonObject); + } + + public double toDouble() { + return PyFloat_AsDouble(nativePythonObject); + } + + public float toFloat() { + return (float) PyFloat_AsDouble(nativePythonObject); + } + + public int toInt() { + return (int) PyLong_AsLong(nativePythonObject); + } + + public long toLong() { + return PyLong_AsLong(nativePythonObject); + } + + public boolean toBoolean() { + if (isNone()) return false; + return toInt() != 0; + } + + public NumpyArray toNumpy() { + PyObject arrInterface = PyObject_GetAttrString(nativePythonObject, "__array_interface__"); // borrowed reference; DO NOT Py_DecRef() ! + PyObject data = PyDict_GetItemString(arrInterface, "data"); + PyObject pyAddress = PyTuple_GetItem(data, 0); + long address = PyLong_AsLong(pyAddress); + PyObject pyDtype = PyObject_GetAttrString(nativePythonObject, "dtype"); + PyObject pyDtypeName = PyObject_GetAttrString(pyDtype, "name"); + String dtypeName = pyObjectToString(pyDtypeName); + Py_DecRef(pyDtype); + Py_DecRef(pyDtypeName); + PyObject shape = PyObject_GetAttrString(nativePythonObject, "shape"); + PyObject strides = PyObject_GetAttrString(nativePythonObject, "strides"); + int ndim = (int) PyObject_Size(shape); + long[] jshape = new long[ndim]; + long[] jstrides = new long[ndim]; + for (int i = 0; i < ndim; i++) { + jshape[i] = PyLong_AsLong(PyTuple_GetItem(shape, i)); + jstrides[i] = PyLong_AsLong(PyTuple_GetItem(strides, i)); + } + Py_DecRef(shape); + Py_DecRef(strides); + DataType dtype; + if (dtypeName.equals("float64")) { + dtype = DataType.DOUBLE; + } else if (dtypeName.equals("float32")) { + dtype = DataType.FLOAT; + } else if (dtypeName.equals("int16")) { + dtype = DataType.SHORT; + } else if (dtypeName.equals("int32")) { + dtype = DataType.INT; + } else if (dtypeName.equals("int64")) { + dtype = DataType.LONG; + } else { + throw new RuntimeException("Unsupported array type " + dtypeName + "."); + } + return new NumpyArray(address, jshape, jstrides, dtype); + + } + + public PythonObject attr(String attr) { + + return new PythonObject(PyObject_GetAttrString(nativePythonObject, attr)); + } + + public PythonObject call(Object... args) { + if (args.length > 0 && args[args.length - 1] instanceof Map) { + List args2 = new ArrayList<>(); + for (int i = 0; i < args.length - 1; i++) { + args2.add(args[i]); + } + return call(args2, (Map) args[args.length - 1]); + } + if (args.length == 0) { + return new PythonObject(PyObject_CallObject(nativePythonObject, null)); + } + PyObject tuple = PyTuple_New(args.length); // leaky; tuple may contain borrowed references, so can not be de-allocated. + for (int i = 0; i < args.length; i++) { + PyTuple_SetItem(tuple, i, j2pyObject(args[i]).nativePythonObject); + } + PythonObject ret = new PythonObject(PyObject_Call(nativePythonObject, tuple, null)); + return ret; + } + + public PythonObject callWithArgs(PythonObject args) { + PyObject tuple = PyList_AsTuple(args.nativePythonObject); + return new PythonObject(PyObject_Call(nativePythonObject, tuple, null)); + } + + public PythonObject callWithKwargs(PythonObject kwargs) { + PyObject tuple = PyTuple_New(0); + return new PythonObject(PyObject_Call(nativePythonObject, tuple, kwargs.nativePythonObject)); + } + + public PythonObject callWithArgsAndKwargs(PythonObject args, PythonObject kwargs) { + PyObject tuple = PyList_AsTuple(args.nativePythonObject); + PyObject dict = kwargs.nativePythonObject; + return new PythonObject(PyObject_Call(nativePythonObject, tuple, dict)); + } + + public PythonObject call(Map kwargs) { + PyObject dict = new PythonObject(kwargs).nativePythonObject; + PyObject tuple = PyTuple_New(0); + return new PythonObject(PyObject_Call(nativePythonObject, tuple, dict)); + } + + public PythonObject call(List args) { + PyObject tuple = PyList_AsTuple(new PythonObject(args).nativePythonObject); + return new PythonObject(PyObject_Call(nativePythonObject, tuple, null)); + } + + public PythonObject call(List args, Map kwargs) { + PyObject tuple = PyList_AsTuple(new PythonObject(args).nativePythonObject); + PyObject dict = new PythonObject(kwargs).nativePythonObject; + return new PythonObject(PyObject_Call(nativePythonObject, tuple, dict)); + } + + private PythonObject get(PyObject key) { + return new PythonObject( + PyObject_GetItem(nativePythonObject, key) + ); + } + + public PythonObject get(PythonObject key) { + return get(key.nativePythonObject); + } + + + public PythonObject get(int key) { + return get(PyLong_FromLong((long) key)); + } + + public PythonObject get(long key) { + return new PythonObject( + PyObject_GetItem(nativePythonObject, PyLong_FromLong(key)) + + ); + } + + public PythonObject get(double key) { + return new PythonObject( + PyObject_GetItem(nativePythonObject, PyFloat_FromDouble(key)) + + ); + } + + public PythonObject get(String key) { + return get(new PythonObject(key)); + } + + public void set(PythonObject key, PythonObject value) { + PyObject_SetItem(nativePythonObject, key.nativePythonObject, value.nativePythonObject); + } + + public void del() { + Py_DecRef(nativePythonObject); + nativePythonObject = null; + } + + public JSONArray toJSONArray() throws PythonException { + PythonObject json = Python.importModule("json"); + PythonObject serialized = json.attr("dumps").call(this, _getNDArraySerializer()); + String jsonString = serialized.toString(); + return new JSONArray(jsonString); + + } + + public JSONObject toJSONObject() throws PythonException { + PythonObject json = Python.importModule("json"); + PythonObject serialized = json.attr("dumps").call(this, _getNDArraySerializer()); + String jsonString = serialized.toString(); + return new JSONObject(jsonString); + } + + public List toList() throws PythonException{ + List list = new ArrayList(); + int n = Python.len(this).toInt(); + for (int i = 0; i < n; i++) { + PythonObject o = get(i); + if (Python.isinstance(o, Python.strType())) { + list.add(o.toString()); + } else if (Python.isinstance(o, Python.intType())) { + list.add(o.toLong()); + } else if (Python.isinstance(o, Python.floatType())) { + list.add(o.toDouble()); + } else if (Python.isinstance(o, Python.boolType())) { + list.add(o); + } else if (Python.isinstance(o, Python.listType(), Python.tupleType())) { + list.add(o.toList()); + } else if (Python.isinstance(o, Python.importModule("numpy").attr("ndarray"))) { + list.add(o.toNumpy().getNd4jArray()); + } else if (Python.isinstance(o, Python.dictType())) { + list.add(o.toMap()); + } else { + throw new RuntimeException("Error while converting python" + + " list to java List: Unable to serialize python " + + "object of type " + Python.type(this).toString()); + } + } + + return list; + } + + public Map toMap() throws PythonException{ + Map map = new HashMap(); + List keys = Python.list(attr("keys").call()).toList(); + List values = Python.list(attr("values").call()).toList(); + for (int i = 0; i < keys.size(); i++) { + map.put(keys.get(i), values.get(i)); + } + return map; + } + + public BytePointer toBytePointer() throws PythonException{ + if (Python.isinstance(this, Python.bytesType())){ + PyObject byteArray = PyByteArray_FromObject(nativePythonObject); + return PyByteArray_AsString(byteArray); + + } + else if (Python.isinstance(this, Python.bytearrayType())){ + return PyByteArray_AsString(nativePythonObject); + } + else{ + PyObject ctypes = PyImport_ImportModule("ctypes"); + PyObject cArrType = PyObject_GetAttrString(ctypes, "Array"); + if (PyObject_IsInstance(nativePythonObject, cArrType) != 0){ + PyObject cVoidP = PyObject_GetAttrString(ctypes, "c_void_p"); + PyObject cast = PyObject_GetAttrString(ctypes, "cast"); + PyObject argsTuple = PyTuple_New(2); + PyTuple_SetItem(argsTuple, 0, nativePythonObject); + PyTuple_SetItem(argsTuple, 1, cVoidP); + PyObject voidPtr = PyObject_Call(cast, argsTuple, null); + PyObject pyAddress = PyObject_GetAttrString(voidPtr, "value"); + long address = PyLong_AsLong(pyAddress); + long size = PyObject_Size(nativePythonObject); + Py_DecRef(ctypes); + Py_DecRef(cArrType); + Py_DecRef(argsTuple); + Py_DecRef(voidPtr); + Py_DecRef(pyAddress); + Pointer ptr = NativeOpsHolder.getInstance().getDeviceNativeOps().pointerForAddress(address); + ptr = ptr.limit(size); + ptr = ptr.capacity(size); + return new BytePointer(ptr); + } + else{ + throw new PythonException("Expected bytes, bytearray or ctypesArray. Received " + Python.type(this).toString()); + } + + } + } + public boolean isNone() { + return nativePythonObject == null; + } + +} diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonTransform.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonTransform.java index 8f2460035..ab67adc46 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/PythonTransform.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonTransform.java @@ -24,10 +24,13 @@ import org.datavec.api.transform.ColumnType; import org.datavec.api.transform.Transform; import org.datavec.api.transform.schema.Schema; import org.datavec.api.writable.*; +import org.json.JSONPropertyIgnore; import org.nd4j.base.Preconditions; import org.nd4j.jackson.objectmapper.holder.ObjectMapperHolder; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.shade.jackson.core.JsonProcessingException; +import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties; import java.io.IOException; import java.io.InputStream; @@ -52,12 +55,13 @@ public class PythonTransform implements Transform { private String code; private PythonVariables inputs; private PythonVariables outputs; - private String name = UUID.randomUUID().toString(); + private String name = UUID.randomUUID().toString(); private Schema inputSchema; private Schema outputSchema; private String outputDict; private boolean returnAllVariables; private boolean setupAndRun = false; + private PythonJob pythonJob; @Builder @@ -70,71 +74,70 @@ public class PythonTransform implements Transform { String outputDict, boolean returnAllInputs, boolean setupAndRun) { - Preconditions.checkNotNull(code,"No code found to run!"); + Preconditions.checkNotNull(code, "No code found to run!"); this.code = code; this.returnAllVariables = returnAllInputs; this.setupAndRun = setupAndRun; - if(inputs != null) + if (inputs != null) this.inputs = inputs; - if(outputs != null) + if (outputs != null) this.outputs = outputs; - - if(name != null) + if (name != null) this.name = name; if (outputDict != null) { this.outputDict = outputDict; this.outputs = new PythonVariables(); this.outputs.addDict(outputDict); - - String helpers; - try(InputStream is = new ClassPathResource("pythonexec/serialize_array.py").getInputStream()) { - helpers = IOUtils.toString(is, Charset.defaultCharset()); - - }catch (IOException e){ - throw new RuntimeException("Error reading python code"); - } - this.code += "\n\n" + helpers; - this.code += "\n" + outputDict + " = __recursive_serialize_dict(" + outputDict + ")"; } try { - if(inputSchema != null) { + if (inputSchema != null) { this.inputSchema = inputSchema; - if(inputs == null || inputs.isEmpty()) { + if (inputs == null || inputs.isEmpty()) { this.inputs = schemaToPythonVariables(inputSchema); } } - if(outputSchema != null) { + if (outputSchema != null) { this.outputSchema = outputSchema; - if(outputs == null || outputs.isEmpty()) { + if (outputs == null || outputs.isEmpty()) { this.outputs = schemaToPythonVariables(outputSchema); } } - }catch(Exception e) { + } catch (Exception e) { throw new IllegalStateException(e); } + try{ + pythonJob = PythonJob.builder() + .name("a" + UUID.randomUUID().toString().replace("-", "_")) + .code(code) + .setupRunMode(setupAndRun) + .build(); + } + catch(Exception e){ + throw new IllegalStateException("Error creating python job: " + e); + } } @Override public void setInputSchema(Schema inputSchema) { - Preconditions.checkNotNull(inputSchema,"No input schema found!"); + Preconditions.checkNotNull(inputSchema, "No input schema found!"); this.inputSchema = inputSchema; - try{ + try { inputs = schemaToPythonVariables(inputSchema); - }catch (Exception e){ + } catch (Exception e) { throw new RuntimeException(e); } - if (outputSchema == null && outputDict == null){ + if (outputSchema == null && outputDict == null) { outputSchema = inputSchema; } } @Override - public Schema getInputSchema(){ + public Schema getInputSchema() { return inputSchema; } @@ -158,67 +161,51 @@ public class PythonTransform implements Transform { } - - @Override public List map(List writables) { PythonVariables pyInputs = getPyInputsFromWritables(writables); - Preconditions.checkNotNull(pyInputs,"Inputs must not be null!"); - - - try{ + Preconditions.checkNotNull(pyInputs, "Inputs must not be null!"); + try { if (returnAllVariables) { - if (setupAndRun){ - return getWritablesFromPyOutputs(PythonExecutioner.execWithSetupRunAndReturnAllVariables(code, pyInputs)); - } - return getWritablesFromPyOutputs(PythonExecutioner.execAndReturnAllVariables(code, pyInputs)); + return getWritablesFromPyOutputs(pythonJob.execAndReturnAllVariables(pyInputs)); } if (outputDict != null) { - if (setupAndRun) { - PythonExecutioner.execWithSetupAndRun(this, pyInputs); - }else{ - PythonExecutioner.exec(this, pyInputs); - } + pythonJob.exec(pyInputs, outputs); PythonVariables out = PythonUtils.expandInnerDict(outputs, outputDict); return getWritablesFromPyOutputs(out); - } - else { - if (setupAndRun) { - PythonExecutioner.execWithSetupAndRun(code, pyInputs, outputs); - }else{ - PythonExecutioner.exec(code, pyInputs, outputs); - } + } else { + pythonJob.exec(pyInputs, outputs); return getWritablesFromPyOutputs(outputs); } - } - catch (Exception e){ + } catch (Exception e) { throw new RuntimeException(e); } } @Override - public String[] outputColumnNames(){ + public String[] outputColumnNames() { return outputs.getVariables(); } @Override - public String outputColumnName(){ + public String outputColumnName() { return outputColumnNames()[0]; } + @Override - public String[] columnNames(){ + public String[] columnNames() { return outputs.getVariables(); } @Override - public String columnName(){ + public String columnName() { return columnNames()[0]; } - public Schema transform(Schema inputSchema){ + public Schema transform(Schema inputSchema) { return outputSchema; } @@ -226,33 +213,33 @@ public class PythonTransform implements Transform { private PythonVariables getPyInputsFromWritables(List writables) { PythonVariables ret = new PythonVariables(); - for (String name: inputs.getVariables()) { + for (String name : inputs.getVariables()) { int colIdx = inputSchema.getIndexOfColumn(name); Writable w = writables.get(colIdx); - PythonVariables.Type pyType = inputs.getType(name); - switch (pyType){ + PythonType pyType = inputs.getType(name); + switch (pyType.getName()) { case INT: - if (w instanceof LongWritable){ - ret.addInt(name, ((LongWritable)w).get()); + if (w instanceof LongWritable) { + ret.addInt(name, ((LongWritable) w).get()); + } else { + ret.addInt(name, ((IntWritable) w).get()); } - else{ - ret.addInt(name, ((IntWritable)w).get()); - } - break; case FLOAT: if (w instanceof DoubleWritable) { - ret.addFloat(name, ((DoubleWritable)w).get()); - } - else{ - ret.addFloat(name, ((FloatWritable)w).get()); + ret.addFloat(name, ((DoubleWritable) w).get()); + } else { + ret.addFloat(name, ((FloatWritable) w).get()); } break; case STR: ret.addStr(name, w.toString()); break; case NDARRAY: - ret.addNDArray(name,((NDArrayWritable)w).get()); + ret.addNDArray(name, ((NDArrayWritable) w).get()); + break; + case BOOL: + ret.addBool(name, ((BooleanWritable) w).get()); break; default: throw new RuntimeException("Unsupported input type:" + pyType); @@ -269,8 +256,8 @@ public class PythonTransform implements Transform { Schema.Builder schemaBuilder = new Schema.Builder(); for (int i = 0; i < varNames.length; i++) { String name = varNames[i]; - PythonVariables.Type pyType = pyOuts.getType(name); - switch (pyType){ + PythonType pyType = pyOuts.getType(name); + switch (pyType.getName()) { case INT: schemaBuilder.addColumnLong(name); break; @@ -283,11 +270,14 @@ public class PythonTransform implements Transform { schemaBuilder.addColumnString(name); break; case NDARRAY: - NumpyArray arr = pyOuts.getNDArrayValue(name); - schemaBuilder.addColumnNDArray(name, arr.getShape()); + INDArray arr = pyOuts.getNDArrayValue(name); + schemaBuilder.addColumnNDArray(name, arr.shape()); + break; + case BOOL: + schemaBuilder.addColumnBoolean(name); break; default: - throw new IllegalStateException("Unable to support type " + pyType.name()); + throw new IllegalStateException("Unable to support type " + pyType.getName()); } } this.outputSchema = schemaBuilder.build(); @@ -295,9 +285,9 @@ public class PythonTransform implements Transform { for (int i = 0; i < varNames.length; i++) { String name = varNames[i]; - PythonVariables.Type pyType = pyOuts.getType(name); + PythonType pyType = pyOuts.getType(name); - switch (pyType){ + switch (pyType.getName()) { case INT: out.add(new LongWritable(pyOuts.getIntValue(name))); break; @@ -308,14 +298,14 @@ public class PythonTransform implements Transform { out.add(new Text(pyOuts.getStrValue(name))); break; case NDARRAY: - NumpyArray arr = pyOuts.getNDArrayValue(name); - out.add(new NDArrayWritable(arr.getNd4jArray())); + INDArray arr = pyOuts.getNDArrayValue(name); + out.add(new NDArrayWritable(arr)); break; case DICT: Map dictValue = pyOuts.getDictValue(name); Map noNullValues = new java.util.HashMap<>(); - for(Map.Entry entry : dictValue.entrySet()) { - if(entry.getValue() != org.json.JSONObject.NULL) { + for (Map.Entry entry : dictValue.entrySet()) { + if (entry.getValue() != org.json.JSONObject.NULL) { noNullValues.put(entry.getKey(), entry.getValue()); } } @@ -327,21 +317,22 @@ public class PythonTransform implements Transform { } break; case LIST: - Object[] listValue = pyOuts.getListValue(name); + Object[] listValue = pyOuts.getListValue(name).toArray(); try { out.add(new Text(ObjectMapperHolder.getJsonMapper().writeValueAsString(listValue))); } catch (JsonProcessingException e) { throw new IllegalStateException("Unable to serialize list vlaue " + name + " to json!"); } break; + case BOOL: + out.add(new BooleanWritable(pyOuts.getBooleanValue(name))); + break; default: - throw new IllegalStateException("Unable to support type " + pyType.name()); + throw new IllegalStateException("Unable to support type " + pyType.getName()); } } return out; } - - } \ No newline at end of file diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonType.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonType.java new file mode 100644 index 000000000..60603a8e3 --- /dev/null +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonType.java @@ -0,0 +1,238 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.datavec.python; + +import org.bytedeco.javacpp.BytePointer; +import org.bytedeco.javacpp.Pointer; +import org.nd4j.linalg.api.ndarray.INDArray; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import static org.datavec.python.Python.importModule; + + +/** + * + * @param Corresponding Java type for the Python type + */ +public abstract class PythonType { + + public abstract T toJava(PythonObject pythonObject) throws PythonException; + private final TypeName typeName; + + enum TypeName{ + STR, + INT, + FLOAT, + BOOL, + LIST, + DICT, + NDARRAY, + BYTES + } + private PythonType(TypeName typeName){ + this.typeName = typeName; + } + public TypeName getName(){return typeName;} + public String toString(){ + return getName().name(); + } + public static PythonType valueOf(String typeName) throws PythonException{ + try{ + typeName.valueOf(typeName); + } catch (IllegalArgumentException iae){ + throw new PythonException("Invalid python type: " + typeName, iae); + } + try{ + return (PythonType)PythonType.class.getField(typeName).get(null); // shouldn't fail + } catch (Exception e){ + throw new RuntimeException(e); + } + + } + public static PythonType valueOf(TypeName typeName){ + try{ + return valueOf(typeName.name()); // shouldn't fail + }catch (PythonException pe){ + throw new RuntimeException(pe); + } + } + + /** + * Since multiple java types can map to the same python type, + * this method "normalizes" all supported incoming objects to T + * + * @param object object to be converted to type T + * @return + */ + public T convert(Object object) throws PythonException { + return (T) object; + } + + public static final PythonType STR = new PythonType(TypeName.STR) { + @Override + public String toJava(PythonObject pythonObject) throws PythonException { + if (!Python.isinstance(pythonObject, Python.strType())) { + throw new PythonException("Expected variable to be str, but was " + Python.type(pythonObject)); + } + return pythonObject.toString(); + } + + @Override + public String convert(Object object) { + return object.toString(); + } + }; + + public static final PythonType INT = new PythonType(TypeName.INT) { + @Override + public Long toJava(PythonObject pythonObject) throws PythonException { + if (!Python.isinstance(pythonObject, Python.intType())) { + throw new PythonException("Expected variable to be int, but was " + Python.type(pythonObject)); + } + return pythonObject.toLong(); + } + + @Override + public Long convert(Object object) throws PythonException { + if (object instanceof Number) { + return ((Number) object).longValue(); + } + throw new PythonException("Unable to cast " + object + " to Long."); + } + }; + + public static final PythonType FLOAT = new PythonType(TypeName.FLOAT) { + @Override + public Double toJava(PythonObject pythonObject) throws PythonException { + if (!Python.isinstance(pythonObject, Python.floatType())) { + throw new PythonException("Expected variable to be float, but was " + Python.type(pythonObject)); + } + return pythonObject.toDouble(); + } + + @Override + public Double convert(Object object) throws PythonException { + if (object instanceof Number) { + return ((Number) object).doubleValue(); + } + throw new PythonException("Unable to cast " + object + " to Double."); + } + }; + + public static final PythonType BOOL = new PythonType(TypeName.BOOL) { + @Override + public Boolean toJava(PythonObject pythonObject) throws PythonException { + if (!Python.isinstance(pythonObject, Python.boolType())) { + throw new PythonException("Expected variable to be bool, but was " + Python.type(pythonObject)); + } + return pythonObject.toBoolean(); + } + + @Override + public Boolean convert(Object object) throws PythonException { + if (object instanceof Number) { + return ((Number) object).intValue() != 0; + } else if (object instanceof Boolean) { + return (Boolean) object; + } + throw new PythonException("Unable to cast " + object + " to Boolean."); + } + }; + + public static final PythonType LIST = new PythonType(TypeName.LIST) { + @Override + public List toJava(PythonObject pythonObject) throws PythonException { + if (!Python.isinstance(pythonObject, Python.listType())) { + throw new PythonException("Expected variable to be list, but was " + Python.type(pythonObject)); + } + return pythonObject.toList(); + } + + @Override + public List convert(Object object) throws PythonException { + if (object instanceof java.util.List) { + return (List) object; + } else if (object instanceof org.json.JSONArray) { + org.json.JSONArray jsonArray = (org.json.JSONArray) object; + return jsonArray.toList(); + + } else if (object instanceof Object[]) { + return Arrays.asList((Object[]) object); + } + throw new PythonException("Unable to cast " + object + " to List."); + } + }; + + public static final PythonType DICT = new PythonType(TypeName.DICT) { + @Override + public Map toJava(PythonObject pythonObject) throws PythonException { + if (!Python.isinstance(pythonObject, Python.dictType())) { + throw new PythonException("Expected variable to be dict, but was " + Python.type(pythonObject)); + } + return pythonObject.toMap(); + } + + @Override + public Map convert(Object object) throws PythonException { + if (object instanceof Map) { + return (Map) object; + } + throw new PythonException("Unable to cast " + object + " to Map."); + } + }; + + public static final PythonType NDARRAY = new PythonType(TypeName.NDARRAY) { + @Override + public INDArray toJava(PythonObject pythonObject) throws PythonException { + PythonObject np = importModule("numpy"); + if (!Python.isinstance(pythonObject, np.attr("ndarray"), np.attr("generic"))) { + throw new PythonException("Expected variable to be numpy.ndarray, but was " + Python.type(pythonObject)); + } + return pythonObject.toNumpy().getNd4jArray(); + } + + @Override + public INDArray convert(Object object) throws PythonException { + if (object instanceof INDArray) { + return (INDArray) object; + } else if (object instanceof NumpyArray) { + return ((NumpyArray) object).getNd4jArray(); + } + throw new PythonException("Unable to cast " + object + " to INDArray."); + } + }; + + public static final PythonType BYTES = new PythonType(TypeName.BYTES) { + @Override + public BytePointer toJava(PythonObject pythonObject) throws PythonException { + return pythonObject.toBytePointer(); + } + + @Override + public BytePointer convert(Object object) throws PythonException { + if (object instanceof BytePointer) { + return (BytePointer) object; + } else if (object instanceof Pointer) { + return new BytePointer((Pointer) object); + } + throw new PythonException("Unable to cast " + object + " to BytePointer."); + } + }; +} diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonUtils.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonUtils.java index a8334cbc5..c510d8130 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/PythonUtils.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonUtils.java @@ -24,28 +24,30 @@ public class PythonUtils { * Create a {@link Schema} * from {@link PythonVariables}. * Types are mapped to types of the same name. + * * @param input the input {@link PythonVariables} * @return the output {@link Schema} */ public static Schema fromPythonVariables(PythonVariables input) { Schema.Builder schemaBuilder = new Schema.Builder(); - Preconditions.checkState(input.getVariables() != null && input.getVariables().length > 0,"Input must have variables. Found none."); - for(Map.Entry entry : input.getVars().entrySet()) { - switch(entry.getValue()) { + Preconditions.checkState(input.getVariables() != null && input.getVariables().length > 0, "Input must have variables. Found none."); + for (String varName: input.getVariables()) { + + switch (input.getType(varName).getName()) { case INT: - schemaBuilder.addColumnInteger(entry.getKey()); + schemaBuilder.addColumnInteger(varName); break; case STR: - schemaBuilder.addColumnString(entry.getKey()); + schemaBuilder.addColumnString(varName); break; case FLOAT: - schemaBuilder.addColumnFloat(entry.getKey()); + schemaBuilder.addColumnFloat(varName); break; case NDARRAY: - schemaBuilder.addColumnNDArray(entry.getKey(),null); + schemaBuilder.addColumnNDArray(varName, null); break; case BOOL: - schemaBuilder.addColumn(new BooleanMetaData(entry.getKey())); + schemaBuilder.addColumn(new BooleanMetaData(varName)); } } @@ -56,34 +58,36 @@ public class PythonUtils { * Create a {@link Schema} from an input * {@link PythonVariables} * Types are mapped to types of the same name + * * @param input the input schema * @return the output python variables. */ public static PythonVariables fromSchema(Schema input) { PythonVariables ret = new PythonVariables(); - for(int i = 0; i < input.numColumns(); i++) { + for (int i = 0; i < input.numColumns(); i++) { String currColumnName = input.getName(i); ColumnType columnType = input.getType(i); - switch(columnType) { + switch (columnType) { case NDArray: - ret.add(currColumnName, PythonVariables.Type.NDARRAY); + ret.add(currColumnName, PythonType.NDARRAY); break; case Boolean: - ret.add(currColumnName, PythonVariables.Type.BOOL); + ret.add(currColumnName, PythonType.BOOL); break; case Categorical: case String: - ret.add(currColumnName, PythonVariables.Type.STR); + ret.add(currColumnName, PythonType.STR); break; case Double: case Float: - ret.add(currColumnName, PythonVariables.Type.FLOAT); + ret.add(currColumnName, PythonType.FLOAT); break; case Integer: case Long: - ret.add(currColumnName, PythonVariables.Type.INT); + ret.add(currColumnName, PythonType.INT); break; case Bytes: + ret.add(currColumnName, PythonType.BYTES); break; case Time: throw new UnsupportedOperationException("Unable to process dates with python yet."); @@ -92,9 +96,11 @@ public class PythonUtils { return ret; } + /** * Convert a {@link Schema} * to {@link PythonVariables} + * * @param schema the input schema * @return the output {@link PythonVariables} where each * name in the map is associated with a column name in the schema. @@ -107,7 +113,7 @@ public class PythonUtils { for (int i = 0; i < numCols; i++) { String colName = schema.getName(i); ColumnType colType = schema.getType(i); - switch (colType){ + switch (colType) { case Long: case Integer: pyVars.addInt(colName); @@ -122,6 +128,9 @@ public class PythonUtils { case NDArray: pyVars.addNDArray(colName); break; + case Boolean: + pyVars.addBool(colName); + break; default: throw new Exception("Unsupported python input type: " + colType.toString()); } @@ -131,117 +140,104 @@ public class PythonUtils { } - public static NumpyArray mapToNumpyArray(Map map){ - String dtypeName = (String)map.get("dtype"); + public static NumpyArray mapToNumpyArray(Map map) { + String dtypeName = (String) map.get("dtype"); DataType dtype; - if (dtypeName.equals("float64")){ + if (dtypeName.equals("float64")) { dtype = DataType.DOUBLE; - } - else if (dtypeName.equals("float32")){ + } else if (dtypeName.equals("float32")) { dtype = DataType.FLOAT; - } - else if (dtypeName.equals("int16")){ + } else if (dtypeName.equals("int16")) { dtype = DataType.SHORT; - } - else if (dtypeName.equals("int32")){ + } else if (dtypeName.equals("int32")) { dtype = DataType.INT; - } - else if (dtypeName.equals("int64")){ + } else if (dtypeName.equals("int64")) { dtype = DataType.LONG; - } - else{ + } else { throw new RuntimeException("Unsupported array type " + dtypeName + "."); } - List shapeList = (List)map.get("shape"); + List shapeList = (List) map.get("shape"); long[] shape = new long[shapeList.size()]; for (int i = 0; i < shape.length; i++) { - shape[i] = (Long)shapeList.get(i); + shape[i] = (Long) shapeList.get(i); } - List strideList = (List)map.get("shape"); + List strideList = (List) map.get("shape"); long[] stride = new long[strideList.size()]; for (int i = 0; i < stride.length; i++) { - stride[i] = (Long)strideList.get(i); + stride[i] = (Long) strideList.get(i); } - long address = (Long)map.get("address"); - NumpyArray numpyArray = new NumpyArray(address, shape, stride, true,dtype); + long address = (Long) map.get("address"); + NumpyArray numpyArray = new NumpyArray(address, shape, stride, dtype, true); return numpyArray; } - public static PythonVariables expandInnerDict(PythonVariables pyvars, String key){ + public static PythonVariables expandInnerDict(PythonVariables pyvars, String key) { Map dict = pyvars.getDictValue(key); - String[] keys = (String[])dict.keySet().toArray(new String[dict.keySet().size()]); + String[] keys = (String[]) dict.keySet().toArray(new String[dict.keySet().size()]); PythonVariables pyvars2 = new PythonVariables(); - for (String subkey: keys){ + for (String subkey : keys) { Object value = dict.get(subkey); - if (value instanceof Map){ - Map map = (Map)value; - if (map.containsKey("_is_numpy_array")){ + if (value instanceof Map) { + Map map = (Map) value; + if (map.containsKey("_is_numpy_array")) { pyvars2.addNDArray(subkey, mapToNumpyArray(map)); - } - else{ - pyvars2.addDict(subkey, (Map)value); + } else { + pyvars2.addDict(subkey, (Map) value); } - } - else if (value instanceof List){ + } else if (value instanceof List) { pyvars2.addList(subkey, ((List) value).toArray()); - } - else if (value instanceof String){ - System.out.println((String)value); + } else if (value instanceof String) { + System.out.println((String) value); pyvars2.addStr(subkey, (String) value); - } - else if (value instanceof Integer || value instanceof Long) { + } else if (value instanceof Integer || value instanceof Long) { Number number = (Number) value; pyvars2.addInt(subkey, number.intValue()); - } - else if (value instanceof Float || value instanceof Double) { + } else if (value instanceof Float || value instanceof Double) { Number number = (Number) value; pyvars2.addFloat(subkey, number.doubleValue()); - } - else if (value instanceof NumpyArray){ - pyvars2.addNDArray(subkey, (NumpyArray)value); - } - else if (value == null){ + } else if (value instanceof NumpyArray) { + pyvars2.addNDArray(subkey, (NumpyArray) value); + } else if (value == null) { pyvars2.addStr(subkey, "None"); // FixMe - } - else{ + } else { throw new RuntimeException("Unsupported type!" + value); } } return pyvars2; } - public static long[] jsonArrayToLongArray(JSONArray jsonArray){ + public static long[] jsonArrayToLongArray(JSONArray jsonArray) { long[] longs = new long[jsonArray.length()]; - for (int i=0; i toMap(JSONObject jsonobj) { + public static Map toMap(JSONObject jsonobj) { Map map = new HashMap<>(); - String[] keys = (String[])jsonobj.keySet().toArray(new String[jsonobj.keySet().size()]); - for (String key: keys){ + String[] keys = (String[]) jsonobj.keySet().toArray(new String[jsonobj.keySet().size()]); + for (String key : keys) { Object value = jsonobj.get(key); if (value instanceof JSONArray) { value = toList((JSONArray) value); } else if (value instanceof JSONObject) { - JSONObject jsonobj2 = (JSONObject)value; - if (jsonobj2.has("_is_numpy_array")){ + JSONObject jsonobj2 = (JSONObject) value; + if (jsonobj2.has("_is_numpy_array")) { value = jsonToNumpyArray(jsonobj2); - } - else{ + } else { value = toMap(jsonobj2); } } map.put(key, value); - } return map; + } + return map; } @@ -265,40 +261,35 @@ public class PythonUtils { } - private static NumpyArray jsonToNumpyArray(JSONObject map){ - String dtypeName = (String)map.get("dtype"); + private static NumpyArray jsonToNumpyArray(JSONObject map) { + String dtypeName = (String) map.get("dtype"); DataType dtype; - if (dtypeName.equals("float64")){ + if (dtypeName.equals("float64")) { dtype = DataType.DOUBLE; - } - else if (dtypeName.equals("float32")){ + } else if (dtypeName.equals("float32")) { dtype = DataType.FLOAT; - } - else if (dtypeName.equals("int16")){ + } else if (dtypeName.equals("int16")) { dtype = DataType.SHORT; - } - else if (dtypeName.equals("int32")){ + } else if (dtypeName.equals("int32")) { dtype = DataType.INT; - } - else if (dtypeName.equals("int64")){ + } else if (dtypeName.equals("int64")) { dtype = DataType.LONG; - } - else{ + } else { throw new RuntimeException("Unsupported array type " + dtypeName + "."); } - List shapeList = (List)map.get("shape"); + List shapeList = map.getJSONArray("shape").toList(); long[] shape = new long[shapeList.size()]; for (int i = 0; i < shape.length; i++) { - shape[i] = (Long)shapeList.get(i); + shape[i] = ((Number) shapeList.get(i)).longValue(); } - List strideList = (List)map.get("shape"); + List strideList = map.getJSONArray("shape").toList(); long[] stride = new long[strideList.size()]; for (int i = 0; i < stride.length; i++) { - stride[i] = (Long)strideList.get(i); + stride[i] = ((Number) strideList.get(i)).longValue(); } - long address = (Long)map.get("address"); - NumpyArray numpyArray = new NumpyArray(address, shape, stride, true,dtype); + long address = ((Number) map.get("address")).longValue(); + NumpyArray numpyArray = new NumpyArray(address, shape, stride, dtype, true); return numpyArray; } diff --git a/datavec/datavec-python/src/main/java/org/datavec/python/PythonVariables.java b/datavec/datavec-python/src/main/java/org/datavec/python/PythonVariables.java index 4d04f1d87..9d8b5c2a1 100644 --- a/datavec/datavec-python/src/main/java/org/datavec/python/PythonVariables.java +++ b/datavec/datavec-python/src/main/java/org/datavec/python/PythonVariables.java @@ -17,13 +17,19 @@ package org.datavec.python; import lombok.Data; +import org.bytedeco.javacpp.BytePointer; +import org.bytedeco.javacpp.Pointer; import org.json.JSONObject; import org.json.JSONArray; import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.nativeblas.NativeOpsHolder; import java.io.Serializable; +import java.nio.ByteBuffer; import java.util.*; + + /** * Holds python variable names, types and values. * Also handles mapping from java types to python types. @@ -33,41 +39,31 @@ import java.util.*; @lombok.Data public class PythonVariables implements java.io.Serializable { - - public enum Type{ - BOOL, - STR, - INT, - FLOAT, - NDARRAY, - LIST, - FILE, - DICT - - } + private java.util.Map strVariables = new java.util.LinkedHashMap<>(); private java.util.Map intVariables = new java.util.LinkedHashMap<>(); private java.util.Map floatVariables = new java.util.LinkedHashMap<>(); private java.util.Map boolVariables = new java.util.LinkedHashMap<>(); - private java.util.Map ndVars = new java.util.LinkedHashMap<>(); - private java.util.Map listVariables = new java.util.LinkedHashMap<>(); - private java.util.Map fileVariables = new java.util.LinkedHashMap<>(); - private java.util.Map> dictVariables = new java.util.LinkedHashMap<>(); - private java.util.Map vars = new java.util.LinkedHashMap<>(); - private java.util.Map maps = new java.util.LinkedHashMap<>(); + private java.util.Map ndVars = new java.util.LinkedHashMap<>(); + private java.util.Map listVariables = new java.util.LinkedHashMap<>(); + private java.util.Map bytesVariables = new java.util.LinkedHashMap<>(); + private java.util.Map> dictVariables = new java.util.LinkedHashMap<>(); + private java.util.Map vars = new java.util.LinkedHashMap<>(); + private java.util.Map maps = new java.util.LinkedHashMap<>(); /** * Returns a copy of the variable * schema in this array without the values + * * @return an empty variables clone * with no values */ - public PythonVariables copySchema(){ + public PythonVariables copySchema() { PythonVariables ret = new PythonVariables(); - for (String varName: getVariables()){ - Type type = getType(varName); + for (String varName : getVariables()) { + PythonType type = getType(varName); ret.add(varName, type); } return ret; @@ -77,21 +73,19 @@ public class PythonVariables implements java.io.Serializable { * */ public PythonVariables() { - maps.put(PythonVariables.Type.BOOL, boolVariables); - maps.put(PythonVariables.Type.STR, strVariables); - maps.put(PythonVariables.Type.INT, intVariables); - maps.put(PythonVariables.Type.FLOAT, floatVariables); - maps.put(PythonVariables.Type.NDARRAY, ndVars); - maps.put(PythonVariables.Type.LIST, listVariables); - maps.put(PythonVariables.Type.FILE, fileVariables); - maps.put(PythonVariables.Type.DICT, dictVariables); + maps.put(PythonType.TypeName.BOOL, boolVariables); + maps.put(PythonType.TypeName.STR, strVariables); + maps.put(PythonType.TypeName.INT, intVariables); + maps.put(PythonType.TypeName.FLOAT, floatVariables); + maps.put(PythonType.TypeName.NDARRAY, ndVars); + maps.put(PythonType.TypeName.LIST, listVariables); + maps.put(PythonType.TypeName.DICT, dictVariables); + maps.put(PythonType.TypeName.BYTES, bytesVariables); } - /** - * * @return true if there are no variables. */ public boolean isEmpty() { @@ -100,12 +94,11 @@ public class PythonVariables implements java.io.Serializable { /** - * * @param name Name of the variable * @param type Type of the variable */ - public void add(String name, Type type){ - switch (type){ + public void add(String name, PythonType type) { + switch (type.getName()) { case BOOL: addBool(name); break; @@ -124,21 +117,21 @@ public class PythonVariables implements java.io.Serializable { case LIST: addList(name); break; - case FILE: - addFile(name); - break; case DICT: addDict(name); + break; + case BYTES: + addBytes(name); + break; } } /** - * - * @param name name of the variable - * @param type type of the variable + * @param name name of the variable + * @param type type of the variable * @param value value of the variable (must be instance of expected type) */ - public void add(String name, Type type, Object value) { + public void add(String name, PythonType type, Object value) throws PythonException { add(name, type); setValue(name, value); } @@ -148,21 +141,23 @@ public class PythonVariables implements java.io.Serializable { * Add a null variable to * the set of variables * to describe the type but no value + * * @param name the field to add */ public void addDict(String name) { - vars.put(name, PythonVariables.Type.DICT); - dictVariables.put(name,null); + vars.put(name, PythonType.TypeName.DICT); + dictVariables.put(name, null); } /** * Add a null variable to * the set of variables * to describe the type but no value + * * @param name the field to add */ - public void addBool(String name){ - vars.put(name, PythonVariables.Type.BOOL); + public void addBool(String name) { + vars.put(name, PythonType.TypeName.BOOL); boolVariables.put(name, null); } @@ -170,10 +165,11 @@ public class PythonVariables implements java.io.Serializable { * Add a null variable to * the set of variables * to describe the type but no value + * * @param name the field to add */ - public void addStr(String name){ - vars.put(name, PythonVariables.Type.STR); + public void addStr(String name) { + vars.put(name, PythonType.TypeName.STR); strVariables.put(name, null); } @@ -181,10 +177,11 @@ public class PythonVariables implements java.io.Serializable { * Add a null variable to * the set of variables * to describe the type but no value + * * @param name the field to add */ - public void addInt(String name){ - vars.put(name, PythonVariables.Type.INT); + public void addInt(String name) { + vars.put(name, PythonType.TypeName.INT); intVariables.put(name, null); } @@ -192,10 +189,11 @@ public class PythonVariables implements java.io.Serializable { * Add a null variable to * the set of variables * to describe the type but no value + * * @param name the field to add */ - public void addFloat(String name){ - vars.put(name, PythonVariables.Type.FLOAT); + public void addFloat(String name) { + vars.put(name, PythonType.TypeName.FLOAT); floatVariables.put(name, null); } @@ -203,10 +201,11 @@ public class PythonVariables implements java.io.Serializable { * Add a null variable to * the set of variables * to describe the type but no value + * * @param name the field to add */ - public void addNDArray(String name){ - vars.put(name, PythonVariables.Type.NDARRAY); + public void addNDArray(String name) { + vars.put(name, PythonType.TypeName.NDARRAY); ndVars.put(name, null); } @@ -214,99 +213,109 @@ public class PythonVariables implements java.io.Serializable { * Add a null variable to * the set of variables * to describe the type but no value + * * @param name the field to add */ - public void addList(String name){ - vars.put(name, PythonVariables.Type.LIST); + public void addList(String name) { + vars.put(name, PythonType.TypeName.LIST); listVariables.put(name, null); } - /** - * Add a null variable to - * the set of variables - * to describe the type but no value - * @param name the field to add - */ - public void addFile(String name){ - vars.put(name, PythonVariables.Type.FILE); - fileVariables.put(name, null); - } - /** * Add a boolean variable to * the set of variables - * @param name the field to add + * + * @param name the field to add * @param value the value to add */ public void addBool(String name, boolean value) { - vars.put(name, PythonVariables.Type.BOOL); + vars.put(name, PythonType.TypeName.BOOL); boolVariables.put(name, value); } /** * Add a string variable to * the set of variables - * @param name the field to add + * + * @param name the field to add * @param value the value to add */ public void addStr(String name, String value) { - vars.put(name, PythonVariables.Type.STR); + vars.put(name, PythonType.TypeName.STR); strVariables.put(name, value); } /** * Add an int variable to * the set of variables - * @param name the field to add + * + * @param name the field to add * @param value the value to add */ public void addInt(String name, int value) { - vars.put(name, PythonVariables.Type.INT); - intVariables.put(name, (long)value); + vars.put(name, PythonType.TypeName.INT); + intVariables.put(name, (long) value); } /** * Add a long variable to * the set of variables - * @param name the field to add + * + * @param name the field to add * @param value the value to add */ public void addInt(String name, long value) { - vars.put(name, PythonVariables.Type.INT); + vars.put(name, PythonType.TypeName.INT); intVariables.put(name, value); } /** * Add a double variable to * the set of variables - * @param name the field to add + * + * @param name the field to add * @param value the value to add */ public void addFloat(String name, double value) { - vars.put(name, PythonVariables.Type.FLOAT); + vars.put(name, PythonType.TypeName.FLOAT); floatVariables.put(name, value); } /** * Add a float variable to * the set of variables - * @param name the field to add + * + * @param name the field to add * @param value the value to add */ public void addFloat(String name, float value) { - vars.put(name, PythonVariables.Type.FLOAT); - floatVariables.put(name, (double)value); + vars.put(name, PythonType.TypeName.FLOAT); + floatVariables.put(name, (double) value); } /** * Add a null variable to * the set of variables * to describe the type but no value - * @param name the field to add + * + * @param name the field to add * @param value the value to add */ public void addNDArray(String name, NumpyArray value) { - vars.put(name, PythonVariables.Type.NDARRAY); + vars.put(name, PythonType.TypeName.NDARRAY); + ndVars.put(name, value.getNd4jArray()); + } + + /** + * Add a null variable to + * the set of variables + * to describe the type but no value + * + * @param name the field to add + * @param value the value to add + */ + public void addNDArray(String name, org.nd4j.linalg.api.ndarray.INDArray value) { + vars.put(name, PythonType.TypeName.NDARRAY); ndVars.put(name, value); } @@ -314,117 +323,63 @@ public class PythonVariables implements java.io.Serializable { * Add a null variable to * the set of variables * to describe the type but no value - * @param name the field to add - * @param value the value to add - */ - public void addNDArray(String name, org.nd4j.linalg.api.ndarray.INDArray value) { - vars.put(name, PythonVariables.Type.NDARRAY); - ndVars.put(name, new NumpyArray(value)); - } - - /** - * Add a null variable to - * the set of variables - * to describe the type but no value - * @param name the field to add + * + * @param name the field to add * @param value the value to add */ public void addList(String name, Object[] value) { - vars.put(name, PythonVariables.Type.LIST); - listVariables.put(name, value); + vars.put(name, PythonType.TypeName.LIST); + listVariables.put(name, Arrays.asList(value)); } - + /** * Add a null variable to * the set of variables * to describe the type but no value - * @param name the field to add - * @param value the value to add - */ - public void addFile(String name, String value) { - vars.put(name, PythonVariables.Type.FILE); - fileVariables.put(name, value); - } - - - /** - * Add a null variable to - * the set of variables - * to describe the type but no value - * @param name the field to add + * + * @param name the field to add * @param value the value to add */ public void addDict(String name, java.util.Map value) { - vars.put(name, PythonVariables.Type.DICT); + vars.put(name, PythonType.TypeName.DICT); dictVariables.put(name, value); } + + + public void addBytes(String name){ + vars.put(name, PythonType.TypeName.BYTES); + bytesVariables.put(name, null); + } + + public void addBytes(String name, BytePointer value){ + vars.put(name, PythonType.TypeName.BYTES); + bytesVariables.put(name, value); + } + +// public void addBytes(String name, ByteBuffer value){ +// Pointer ptr = NativeOpsHolder.getInstance().getDeviceNativeOps().pointerForAddress((value.address()); +// BytePointer bp = new BytePointer(ptr); +// addBytes(name, bp); +// } /** - * - * @param name name of the variable + * @param name name of the variable * @param value new value for the variable */ - public void setValue(String name, Object value) { - Type type = vars.get(name); - if (type == PythonVariables.Type.BOOL){ - boolVariables.put(name, (Boolean)value); - } - else if (type == PythonVariables.Type.INT){ - Number number = (Number) value; - intVariables.put(name, number.longValue()); - } - else if (type == PythonVariables.Type.FLOAT){ - Number number = (Number) value; - floatVariables.put(name, number.doubleValue()); - } - else if (type == PythonVariables.Type.NDARRAY){ - if (value instanceof NumpyArray){ - ndVars.put(name, (NumpyArray)value); - } - else if (value instanceof org.nd4j.linalg.api.ndarray.INDArray) { - ndVars.put(name, new NumpyArray((org.nd4j.linalg.api.ndarray.INDArray) value)); - } - else{ - throw new RuntimeException("Unsupported type: " + value.getClass().toString()); - } - } - else if (type == PythonVariables.Type.LIST) { - if (value instanceof java.util.List) { - value = ((java.util.List) value).toArray(); - listVariables.put(name, (Object[]) value); - } - else if(value instanceof org.json.JSONArray) { - org.json.JSONArray jsonArray = (org.json.JSONArray) value; - Object[] copyArr = new Object[jsonArray.length()]; - for(int i = 0; i < copyArr.length; i++) { - copyArr[i] = jsonArray.get(i); - } - listVariables.put(name, copyArr); - - } - else { - listVariables.put(name, (Object[]) value); - } - } - else if(type == PythonVariables.Type.DICT) { - dictVariables.put(name,(java.util.Map) value); - } - else if (type == PythonVariables.Type.FILE){ - fileVariables.put(name, (String)value); - } - else{ - strVariables.put(name, (String)value); - } + public void setValue(String name, Object value) throws PythonException { + PythonType.TypeName type = vars.get(name); + maps.get(type).put(name, PythonType.valueOf(type).convert(value)); } /** * Do a general object lookup. - * The look up will happen relative to the {@link Type} + * The look up will happen relative to the {@link PythonType} * of variable is described in the + * * @param name the name of the variable to get * @return teh value for the variable with the given name */ public Object getValue(String name) { - Type type = vars.get(name); + PythonType.TypeName type = vars.get(name); java.util.Map map = maps.get(type); return map.get(name); } @@ -432,6 +387,7 @@ public class PythonVariables implements java.io.Serializable { /** * Returns a boolean variable with the given name. + * * @param name the variable name to get the value for * @return the retrieved boolean value */ @@ -440,80 +396,78 @@ public class PythonVariables implements java.io.Serializable { } /** - * * @param name the variable name * @return the dictionary value */ - public java.util.Map getDictValue(String name) { + public java.util.Map getDictValue(String name) { return dictVariables.get(name); } /** - /** + * /** * * @param name the variable name * @return the string value */ - public String getStrValue(String name){ + public String getStrValue(String name) { return strVariables.get(name); } /** - * * @param name the variable name * @return the long value */ - public Long getIntValue(String name){ + public Long getIntValue(String name) { return intVariables.get(name); } /** - * * @param name the variable name * @return the float value */ - public Double getFloatValue(String name){ + public Double getFloatValue(String name) { return floatVariables.get(name); } /** - * * @param name the variable name * @return the numpy array value */ - public NumpyArray getNDArrayValue(String name){ + public INDArray getNDArrayValue(String name) { return ndVars.get(name); } /** - * * @param name the variable name * @return the list value as an object array */ - public Object[] getListValue(String name){ + public List getListValue(String name) { return listVariables.get(name); } /** - * * @param name the variable name - * @return the value of the given file name + * @return the bytes value as a BytePointer */ - public String getFileValue(String name){ - return fileVariables.get(name); - } - + public BytePointer getBytesValue(String name){return bytesVariables.get(name);} /** * Returns the type for the given variable name + * * @param name the name of the variable to get the type for * @return the type for the given variable */ - public Type getType(String name){ - return vars.get(name); + public PythonType getType(String name){ + try{ + return PythonType.valueOf(vars.get(name)); // will never fail + }catch (Exception e) + { + throw new RuntimeException(e); + } } /** * Get all the variables present as a string array + * * @return the variable names for this variable sset */ public String[] getVariables() { @@ -524,11 +478,12 @@ public class PythonVariables implements java.io.Serializable { /** * This variables set as its json representation (an array of json objects) + * * @return the json array output */ - public org.json.JSONArray toJSON(){ + public org.json.JSONArray toJSON() { org.json.JSONArray arr = new org.json.JSONArray(); - for (String varName: getVariables()){ + for (String varName : getVariables()) { org.json.JSONObject var = new org.json.JSONObject(); var.put("name", varName); String varType = getType(varName).toString(); @@ -542,13 +497,14 @@ public class PythonVariables implements java.io.Serializable { * Create a schema from a map. * This is an empty PythonVariables * that just contains names and types with no values + * * @param inputTypes the input types to convert * @return the schema from the given map */ - public static PythonVariables schemaFromMap(java.util.Map inputTypes) { + public static PythonVariables schemaFromMap(java.util.Map inputTypes) throws Exception{ PythonVariables ret = new PythonVariables(); - for(java.util.Map.Entry entry : inputTypes.entrySet()) { - ret.add(entry.getKey(), PythonVariables.Type.valueOf(entry.getValue())); + for (java.util.Map.Entry entry : inputTypes.entrySet()) { + ret.add(entry.getKey(), PythonType.valueOf(entry.getValue())); } return ret; @@ -557,39 +513,17 @@ public class PythonVariables implements java.io.Serializable { /** * Get the python variable state relative to the * input json array + * * @param jsonArray the input json array * @return the python variables based on the input json array */ - public static PythonVariables fromJSON(org.json.JSONArray jsonArray){ + public static PythonVariables fromJSON(org.json.JSONArray jsonArray) { PythonVariables pyvars = new PythonVariables(); for (int i = 0; i < jsonArray.length(); i++) { org.json.JSONObject input = (org.json.JSONObject) jsonArray.get(i); - String varName = (String)input.get("name"); - String varType = (String)input.get("type"); - if (varType.equals("BOOL")) { - pyvars.addBool(varName); - } - else if (varType.equals("INT")) { - pyvars.addInt(varName); - } - else if (varType.equals("FlOAT")){ - pyvars.addFloat(varName); - } - else if (varType.equals("STR")) { - pyvars.addStr(varName); - } - else if (varType.equals("LIST")) { - pyvars.addList(varName); - } - else if (varType.equals("FILE")){ - pyvars.addFile(varName); - } - else if (varType.equals("NDARRAY")) { - pyvars.addNDArray(varName); - } - else if(varType.equals("DICT")) { - pyvars.addDict(varName); - } + String varName = (String) input.get("name"); + String varType = (String) input.get("type"); + pyvars.maps.get(PythonType.TypeName.valueOf(varType)).put(varName, null); } return pyvars; diff --git a/datavec/datavec-python/src/main/resources/pythonexec/__init__.py b/datavec/datavec-python/src/main/resources/pythonexec/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/datavec/datavec-python/src/main/resources/pythonexec/clear_vars.py b/datavec/datavec-python/src/main/resources/pythonexec/clear_vars.py deleted file mode 100644 index 239ad694f..000000000 --- a/datavec/datavec-python/src/main/resources/pythonexec/clear_vars.py +++ /dev/null @@ -1,5 +0,0 @@ -#See: https://stackoverflow.com/questions/3543833/how-do-i-clear-all-variables-in-the-middle-of-a-python-script -import sys -this = sys.modules[__name__] -for n in dir(): - if n[0]!='_': delattr(this, n) \ No newline at end of file diff --git a/datavec/datavec-python/src/main/resources/pythonexec/input_code.py b/datavec/datavec-python/src/main/resources/pythonexec/input_code.py deleted file mode 100644 index 92ea40ac5..000000000 --- a/datavec/datavec-python/src/main/resources/pythonexec/input_code.py +++ /dev/null @@ -1 +0,0 @@ -loc = {} diff --git a/datavec/datavec-python/src/main/resources/pythonexec/outputcode.py b/datavec/datavec-python/src/main/resources/pythonexec/outputcode.py deleted file mode 100644 index 7f70a0e93..000000000 --- a/datavec/datavec-python/src/main/resources/pythonexec/outputcode.py +++ /dev/null @@ -1,20 +0,0 @@ - -def __is_numpy_array(x): - return str(type(x))== "" - -def maybe_serialize_ndarray_metadata(x): - return serialize_ndarray_metadata(x) if __is_numpy_array(x) else x - - -def serialize_ndarray_metadata(x): - return {"address": x.__array_interface__['data'][0], - "shape": x.shape, - "strides": x.strides, - "dtype": str(x.dtype), - "_is_numpy_array": True} if __is_numpy_array(x) else x - - -def is_json_ready(key, value): - return key is not 'f2' and not inspect.ismodule(value) \ - and not hasattr(value, '__call__') - diff --git a/datavec/datavec-python/src/main/resources/pythonexec/patch0.py b/datavec/datavec-python/src/main/resources/pythonexec/patch0.py deleted file mode 100644 index d2ed3d5e5..000000000 --- a/datavec/datavec-python/src/main/resources/pythonexec/patch0.py +++ /dev/null @@ -1,202 +0,0 @@ -#patch - -"""Implementation of __array_function__ overrides from NEP-18.""" -import collections -import functools -import os - -from numpy.core._multiarray_umath import ( - add_docstring, implement_array_function, _get_implementing_args) -from numpy.compat._inspect import getargspec - - -ENABLE_ARRAY_FUNCTION = bool( - int(os.environ.get('NUMPY_EXPERIMENTAL_ARRAY_FUNCTION', 0))) - - -ARRAY_FUNCTION_ENABLED = ENABLE_ARRAY_FUNCTION # backward compat - - -_add_docstring = add_docstring - - -def add_docstring(*args): - try: - _add_docstring(*args) - except: - pass - - -add_docstring( - implement_array_function, - """ - Implement a function with checks for __array_function__ overrides. - - All arguments are required, and can only be passed by position. - - Arguments - --------- - implementation : function - Function that implements the operation on NumPy array without - overrides when called like ``implementation(*args, **kwargs)``. - public_api : function - Function exposed by NumPy's public API originally called like - ``public_api(*args, **kwargs)`` on which arguments are now being - checked. - relevant_args : iterable - Iterable of arguments to check for __array_function__ methods. - args : tuple - Arbitrary positional arguments originally passed into ``public_api``. - kwargs : dict - Arbitrary keyword arguments originally passed into ``public_api``. - - Returns - ------- - Result from calling ``implementation()`` or an ``__array_function__`` - method, as appropriate. - - Raises - ------ - TypeError : if no implementation is found. - """) - - -# exposed for testing purposes; used internally by implement_array_function -add_docstring( - _get_implementing_args, - """ - Collect arguments on which to call __array_function__. - - Parameters - ---------- - relevant_args : iterable of array-like - Iterable of possibly array-like arguments to check for - __array_function__ methods. - - Returns - ------- - Sequence of arguments with __array_function__ methods, in the order in - which they should be called. - """) - - -ArgSpec = collections.namedtuple('ArgSpec', 'args varargs keywords defaults') - - -def verify_matching_signatures(implementation, dispatcher): - """Verify that a dispatcher function has the right signature.""" - implementation_spec = ArgSpec(*getargspec(implementation)) - dispatcher_spec = ArgSpec(*getargspec(dispatcher)) - - if (implementation_spec.args != dispatcher_spec.args or - implementation_spec.varargs != dispatcher_spec.varargs or - implementation_spec.keywords != dispatcher_spec.keywords or - (bool(implementation_spec.defaults) != - bool(dispatcher_spec.defaults)) or - (implementation_spec.defaults is not None and - len(implementation_spec.defaults) != - len(dispatcher_spec.defaults))): - raise RuntimeError('implementation and dispatcher for %s have ' - 'different function signatures' % implementation) - - if implementation_spec.defaults is not None: - if dispatcher_spec.defaults != (None,) * len(dispatcher_spec.defaults): - raise RuntimeError('dispatcher functions can only use None for ' - 'default argument values') - - -def set_module(module): - """Decorator for overriding __module__ on a function or class. - - Example usage:: - - @set_module('numpy') - def example(): - pass - - assert example.__module__ == 'numpy' - """ - def decorator(func): - if module is not None: - func.__module__ = module - return func - return decorator - - -def array_function_dispatch(dispatcher, module=None, verify=True, - docs_from_dispatcher=False): - """Decorator for adding dispatch with the __array_function__ protocol. - - See NEP-18 for example usage. - - Parameters - ---------- - dispatcher : callable - Function that when called like ``dispatcher(*args, **kwargs)`` with - arguments from the NumPy function call returns an iterable of - array-like arguments to check for ``__array_function__``. - module : str, optional - __module__ attribute to set on new function, e.g., ``module='numpy'``. - By default, module is copied from the decorated function. - verify : bool, optional - If True, verify the that the signature of the dispatcher and decorated - function signatures match exactly: all required and optional arguments - should appear in order with the same names, but the default values for - all optional arguments should be ``None``. Only disable verification - if the dispatcher's signature needs to deviate for some particular - reason, e.g., because the function has a signature like - ``func(*args, **kwargs)``. - docs_from_dispatcher : bool, optional - If True, copy docs from the dispatcher function onto the dispatched - function, rather than from the implementation. This is useful for - functions defined in C, which otherwise don't have docstrings. - - Returns - ------- - Function suitable for decorating the implementation of a NumPy function. - """ - - if not ENABLE_ARRAY_FUNCTION: - # __array_function__ requires an explicit opt-in for now - def decorator(implementation): - if module is not None: - implementation.__module__ = module - if docs_from_dispatcher: - add_docstring(implementation, dispatcher.__doc__) - return implementation - return decorator - - def decorator(implementation): - if verify: - verify_matching_signatures(implementation, dispatcher) - - if docs_from_dispatcher: - add_docstring(implementation, dispatcher.__doc__) - - @functools.wraps(implementation) - def public_api(*args, **kwargs): - relevant_args = dispatcher(*args, **kwargs) - return implement_array_function( - implementation, public_api, relevant_args, args, kwargs) - - if module is not None: - public_api.__module__ = module - - # TODO: remove this when we drop Python 2 support (functools.wraps) - # adds __wrapped__ automatically in later versions) - public_api.__wrapped__ = implementation - - return public_api - - return decorator - - -def array_function_from_dispatcher( - implementation, module=None, verify=True, docs_from_dispatcher=True): - """Like array_function_dispatcher, but with function arguments flipped.""" - - def decorator(dispatcher): - return array_function_dispatch( - dispatcher, module, verify=verify, - docs_from_dispatcher=docs_from_dispatcher)(implementation) - return decorator diff --git a/datavec/datavec-python/src/main/resources/pythonexec/patch1.py b/datavec/datavec-python/src/main/resources/pythonexec/patch1.py deleted file mode 100644 index 890852bbc..000000000 --- a/datavec/datavec-python/src/main/resources/pythonexec/patch1.py +++ /dev/null @@ -1,172 +0,0 @@ -#patch 1 - -""" -======================== -Random Number Generation -======================== - -==================== ========================================================= -Utility functions -============================================================================== -random_sample Uniformly distributed floats over ``[0, 1)``. -random Alias for `random_sample`. -bytes Uniformly distributed random bytes. -random_integers Uniformly distributed integers in a given range. -permutation Randomly permute a sequence / generate a random sequence. -shuffle Randomly permute a sequence in place. -seed Seed the random number generator. -choice Random sample from 1-D array. - -==================== ========================================================= - -==================== ========================================================= -Compatibility functions -============================================================================== -rand Uniformly distributed values. -randn Normally distributed values. -ranf Uniformly distributed floating point numbers. -randint Uniformly distributed integers in a given range. -==================== ========================================================= - -==================== ========================================================= -Univariate distributions -============================================================================== -beta Beta distribution over ``[0, 1]``. -binomial Binomial distribution. -chisquare :math:`\\chi^2` distribution. -exponential Exponential distribution. -f F (Fisher-Snedecor) distribution. -gamma Gamma distribution. -geometric Geometric distribution. -gumbel Gumbel distribution. -hypergeometric Hypergeometric distribution. -laplace Laplace distribution. -logistic Logistic distribution. -lognormal Log-normal distribution. -logseries Logarithmic series distribution. -negative_binomial Negative binomial distribution. -noncentral_chisquare Non-central chi-square distribution. -noncentral_f Non-central F distribution. -normal Normal / Gaussian distribution. -pareto Pareto distribution. -poisson Poisson distribution. -power Power distribution. -rayleigh Rayleigh distribution. -triangular Triangular distribution. -uniform Uniform distribution. -vonmises Von Mises circular distribution. -wald Wald (inverse Gaussian) distribution. -weibull Weibull distribution. -zipf Zipf's distribution over ranked data. -==================== ========================================================= - -==================== ========================================================= -Multivariate distributions -============================================================================== -dirichlet Multivariate generalization of Beta distribution. -multinomial Multivariate generalization of the binomial distribution. -multivariate_normal Multivariate generalization of the normal distribution. -==================== ========================================================= - -==================== ========================================================= -Standard distributions -============================================================================== -standard_cauchy Standard Cauchy-Lorentz distribution. -standard_exponential Standard exponential distribution. -standard_gamma Standard Gamma distribution. -standard_normal Standard normal distribution. -standard_t Standard Student's t-distribution. -==================== ========================================================= - -==================== ========================================================= -Internal functions -============================================================================== -get_state Get tuple representing internal state of generator. -set_state Set state of generator. -==================== ========================================================= - -""" -from __future__ import division, absolute_import, print_function - -import warnings - -__all__ = [ - 'beta', - 'binomial', - 'bytes', - 'chisquare', - 'choice', - 'dirichlet', - 'exponential', - 'f', - 'gamma', - 'geometric', - 'get_state', - 'gumbel', - 'hypergeometric', - 'laplace', - 'logistic', - 'lognormal', - 'logseries', - 'multinomial', - 'multivariate_normal', - 'negative_binomial', - 'noncentral_chisquare', - 'noncentral_f', - 'normal', - 'pareto', - 'permutation', - 'poisson', - 'power', - 'rand', - 'randint', - 'randn', - 'random_integers', - 'random_sample', - 'rayleigh', - 'seed', - 'set_state', - 'shuffle', - 'standard_cauchy', - 'standard_exponential', - 'standard_gamma', - 'standard_normal', - 'standard_t', - 'triangular', - 'uniform', - 'vonmises', - 'wald', - 'weibull', - 'zipf' -] - -with warnings.catch_warnings(): - warnings.filterwarnings("ignore", message="numpy.ndarray size changed") - try: - from .mtrand import * - # Some aliases: - ranf = random = sample = random_sample - __all__.extend(['ranf', 'random', 'sample']) - except: - warnings.warn("numpy.random is not available when using multiple interpreters!") - - - -def __RandomState_ctor(): - """Return a RandomState instance. - - This function exists solely to assist (un)pickling. - - Note that the state of the RandomState returned here is irrelevant, as this function's - entire purpose is to return a newly allocated RandomState whose state pickle can set. - Consequently the RandomState returned by this function is a freshly allocated copy - with a seed=0. - - See https://github.com/numpy/numpy/issues/4763 for a detailed discussion - - """ - return RandomState(seed=0) - -from numpy._pytesttester import PytestTester -test = PytestTester(__name__) -del PytestTester diff --git a/datavec/datavec-python/src/main/resources/pythonexec/pythonexec.py b/datavec/datavec-python/src/main/resources/pythonexec/pythonexec.py index dbdceff0e..1509610c7 100644 --- a/datavec/datavec-python/src/main/resources/pythonexec/pythonexec.py +++ b/datavec/datavec-python/src/main/resources/pythonexec/pythonexec.py @@ -3,13 +3,13 @@ import traceback import json import inspect - +__python_exception__ = "" try: - pass sys.stdout.flush() sys.stderr.flush() except Exception as ex: + __python_exception__ = ex try: exc_info = sys.exc_info() finally: diff --git a/datavec/datavec-python/src/main/resources/pythonexec/serialize_array.py b/datavec/datavec-python/src/main/resources/pythonexec/serialize_array.py deleted file mode 100644 index ac6f5b1c1..000000000 --- a/datavec/datavec-python/src/main/resources/pythonexec/serialize_array.py +++ /dev/null @@ -1,50 +0,0 @@ -def __is_numpy_array(x): - return str(type(x))== "" - -def __maybe_serialize_ndarray_metadata(x): - return __serialize_ndarray_metadata(x) if __is_numpy_array(x) else x - - -def __serialize_ndarray_metadata(x): - return {"address": x.__array_interface__['data'][0], - "shape": x.shape, - "strides": x.strides, - "dtype": str(x.dtype), - "_is_numpy_array": True} if __is_numpy_array(x) else x - - -def __serialize_list(x): - import json - return json.dumps(__recursive_serialize_list(x)) - - -def __serialize_dict(x): - import json - return json.dumps(__recursive_serialize_dict(x)) - -def __recursive_serialize_list(x): - out = [] - for i in x: - if __is_numpy_array(i): - out.append(__serialize_ndarray_metadata(i)) - elif isinstance(i, (list, tuple)): - out.append(__recursive_serialize_list(i)) - elif isinstance(i, dict): - out.append(__recursive_serialize_dict(i)) - else: - out.append(i) - return out - -def __recursive_serialize_dict(x): - out = {} - for k in x: - v = x[k] - if __is_numpy_array(v): - out[k] = __serialize_ndarray_metadata(v) - elif isinstance(v, (list, tuple)): - out[k] = __recursive_serialize_list(v) - elif isinstance(v, dict): - out[k] = __recursive_serialize_dict(v) - else: - out[k] = v - return out \ No newline at end of file diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonContextManager.java b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonContextManager.java new file mode 100644 index 000000000..4abad139f --- /dev/null +++ b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonContextManager.java @@ -0,0 +1,87 @@ + +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.datavec.python; + + +import org.junit.Assert; +import org.junit.Test; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; +import javax.annotation.concurrent.NotThreadSafe; + +@NotThreadSafe +public class TestPythonContextManager { + + @Test + public void testInt() throws Exception{ + Python.setContext("context1"); + Python.exec("a = 1"); + Python.setContext("context2"); + Python.exec("a = 2"); + Python.setContext("context3"); + Python.exec("a = 3"); + + + Python.setContext("context1"); + Assert.assertEquals(1, PythonExecutioner.getVariable("a").toInt()); + + Python.setContext("context2"); + Assert.assertEquals(2, PythonExecutioner.getVariable("a").toInt()); + + Python.setContext("context3"); + Assert.assertEquals(3, PythonExecutioner.getVariable("a").toInt()); + + PythonContextManager.deleteNonMainContexts(); + } + + @Test + public void testNDArray() throws Exception{ + Python.setContext("context1"); + Python.exec("import numpy as np"); + Python.exec("a = np.zeros((3,2)) + 1"); + + Python.setContext("context2"); + Python.exec("import numpy as np"); + Python.exec("a = np.zeros((3,2)) + 2"); + + Python.setContext("context3"); + Python.exec("import numpy as np"); + Python.exec("a = np.zeros((3,2)) + 3"); + + Python.setContext("context1"); + Python.exec("a += 1"); + + Python.setContext("context2"); + Python.exec("a += 2"); + + Python.setContext("context3"); + Python.exec("a += 3"); + + INDArray arr = Nd4j.create(DataType.DOUBLE, 3, 2); + Python.setContext("context1"); + Assert.assertEquals(arr.add(2), PythonExecutioner.getVariable("a").toNumpy().getNd4jArray()); + + Python.setContext("context2"); + Assert.assertEquals(arr.add(4), PythonExecutioner.getVariable("a").toNumpy().getNd4jArray()); + + Python.setContext("context3"); + Assert.assertEquals(arr.add(6), PythonExecutioner.getVariable("a").toNumpy().getNd4jArray()); + } + +} diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonDict.java b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonDict.java new file mode 100644 index 000000000..bc06269f1 --- /dev/null +++ b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonDict.java @@ -0,0 +1,64 @@ + +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.datavec.python; + + +import lombok.var; +import org.json.JSONArray; +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +@javax.annotation.concurrent.NotThreadSafe +public class TestPythonDict { + + @Test + public void testPythonDictFromMap() throws Exception{ + Map map = new HashMap<>(); + map.put("a", 1); + map.put("b", "a"); + map.put("1", Arrays.asList(1, 2, 3, "4", Arrays.asList("x", 2.3))); + Map innerMap = new HashMap<>(); + innerMap.put("k", 32); + map.put("inner", innerMap); + map.put("ndarray", Nd4j.linspace(1, 4, 4)); + innerMap.put("ndarray", Nd4j.linspace(5, 8, 4)); + PythonObject dict = new PythonObject(map); + assertEquals(map.size(), Python.len(dict).toInt()); + assertEquals("{'a': 1, '1': [1, 2, 3, '4', ['" + + "x', 2.3]], 'b': 'a', 'inner': {'k': 32," + + " 'ndarray': array([5., 6., 7., 8.], dty" + + "pe=float32)}, 'ndarray': array([1., 2., " + + "3., 4.], dtype=float32)}", + dict.toString()); + Map map2 = dict.toMap(); + PythonObject dict2 = new PythonObject(map2); + assertEquals(dict.toString(), dict2.toString()); + + + } + +} diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonExecutionSandbox.java b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonExecutionSandbox.java deleted file mode 100644 index 435babf7c..000000000 --- a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonExecutionSandbox.java +++ /dev/null @@ -1,75 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.python; - - -import org.junit.Assert; -import org.junit.Test; - -@javax.annotation.concurrent.NotThreadSafe -public class TestPythonExecutionSandbox { - - @Test - public void testInt(){ - PythonExecutioner.setInterpreter("interp1"); - PythonExecutioner.exec("a = 1"); - PythonExecutioner.setInterpreter("interp2"); - PythonExecutioner.exec("a = 2"); - PythonExecutioner.setInterpreter("interp3"); - PythonExecutioner.exec("a = 3"); - - - PythonExecutioner.setInterpreter("interp1"); - Assert.assertEquals(1, PythonExecutioner.evalInteger("a")); - - PythonExecutioner.setInterpreter("interp2"); - Assert.assertEquals(2, PythonExecutioner.evalInteger("a")); - - PythonExecutioner.setInterpreter("interp3"); - Assert.assertEquals(3, PythonExecutioner.evalInteger("a")); - } - - @Test - public void testNDArray(){ - PythonExecutioner.setInterpreter("main"); - PythonExecutioner.exec("import numpy as np"); - PythonExecutioner.exec("a = np.zeros(5)"); - - PythonExecutioner.setInterpreter("main"); - //PythonExecutioner.exec("import numpy as np"); - PythonExecutioner.exec("a = np.zeros(5)"); - - PythonExecutioner.setInterpreter("main"); - PythonExecutioner.exec("a += 2"); - - PythonExecutioner.setInterpreter("main"); - PythonExecutioner.exec("a += 3"); - - PythonExecutioner.setInterpreter("main"); - //PythonExecutioner.exec("import numpy as np"); - // PythonExecutioner.exec("a = np.zeros(5)"); - - PythonExecutioner.setInterpreter("main"); - Assert.assertEquals(25, PythonExecutioner.evalNdArray("a").getNd4jArray().sum().getDouble(), 1e-5); - } - - @Test - public void testNumpyRandom(){ - PythonExecutioner.setInterpreter("main"); - PythonExecutioner.exec("import numpy as np; print(np.random.randint(5))"); - } -} diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonExecutioner.java b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonExecutioner.java index c8e67febb..bb436e808 100644 --- a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonExecutioner.java +++ b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonExecutioner.java @@ -15,13 +15,17 @@ ******************************************************************************/ package org.datavec.python; + +import org.bytedeco.javacpp.BytePointer; import org.junit.Assert; +import org.junit.Ignore; import org.junit.Test; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; @javax.annotation.concurrent.NotThreadSafe @@ -29,12 +33,12 @@ public class TestPythonExecutioner { @org.junit.Test - public void testPythonSysVersion() { - PythonExecutioner.exec("import sys; print(sys.version)"); + public void testPythonSysVersion() throws PythonException { + Python.exec("import sys; print(sys.version)"); } @Test - public void testStr() throws Exception{ + public void testStr() throws Exception { PythonVariables pyInputs = new PythonVariables(); PythonVariables pyOutputs = new PythonVariables(); @@ -46,7 +50,7 @@ public class TestPythonExecutioner { String code = "z = x + ' ' + y"; - PythonExecutioner.exec(code, pyInputs, pyOutputs); + Python.exec(code, pyInputs, pyOutputs); String z = pyOutputs.getStrValue("z"); @@ -56,7 +60,7 @@ public class TestPythonExecutioner { } @Test - public void testInt()throws Exception{ + public void testInt() throws Exception { PythonVariables pyInputs = new PythonVariables(); PythonVariables pyOutputs = new PythonVariables(); @@ -68,7 +72,7 @@ public class TestPythonExecutioner { pyOutputs.addInt("z"); - PythonExecutioner.exec(code, pyInputs, pyOutputs); + Python.exec(code, pyInputs, pyOutputs); long z = pyOutputs.getIntValue("z"); @@ -77,7 +81,7 @@ public class TestPythonExecutioner { } @Test - public void testList() throws Exception{ + public void testList() throws Exception { PythonVariables pyInputs = new PythonVariables(); PythonVariables pyOutputs = new PythonVariables(); @@ -92,30 +96,28 @@ public class TestPythonExecutioner { pyOutputs.addList("z"); - PythonExecutioner.exec(code, pyInputs, pyOutputs); + Python.exec(code, pyInputs, pyOutputs); - Object[] z = pyOutputs.getListValue("z"); + Object[] z = pyOutputs.getListValue("z").toArray(); Assert.assertEquals(z.length, x.length + y.length); for (int i = 0; i < x.length; i++) { - if(x[i] instanceof Number) { + if (x[i] instanceof Number) { Number xNum = (Number) x[i]; Number zNum = (Number) z[i]; Assert.assertEquals(xNum.intValue(), zNum.intValue()); - } - else { + } else { Assert.assertEquals(x[i], z[i]); } } - for (int i = 0; i < y.length; i++){ - if(y[i] instanceof Number) { + for (int i = 0; i < y.length; i++) { + if (y[i] instanceof Number) { Number yNum = (Number) y[i]; Number zNum = (Number) z[x.length + i]; Assert.assertEquals(yNum.intValue(), zNum.intValue()); - } - else { + } else { Assert.assertEquals(y[i], z[x.length + i]); } @@ -125,7 +127,7 @@ public class TestPythonExecutioner { } @Test - public void testNDArrayFloat()throws Exception{ + public void testNDArrayFloat() throws Exception { PythonVariables pyInputs = new PythonVariables(); PythonVariables pyOutputs = new PythonVariables(); @@ -135,8 +137,8 @@ public class TestPythonExecutioner { String code = "z = x + y"; - PythonExecutioner.exec(code, pyInputs, pyOutputs); - INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray(); + Python.exec(code, pyInputs, pyOutputs); + INDArray z = pyOutputs.getNDArrayValue("z"); Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5); @@ -144,12 +146,13 @@ public class TestPythonExecutioner { } @Test - public void testTensorflowCustomAnaconda() { - PythonExecutioner.exec("import tensorflow as tf"); + @Ignore + public void testTensorflowCustomAnaconda() throws PythonException { + Python.exec("import tensorflow as tf"); } @Test - public void testNDArrayDouble()throws Exception { + public void testNDArrayDouble() throws Exception { PythonVariables pyInputs = new PythonVariables(); PythonVariables pyOutputs = new PythonVariables(); @@ -159,14 +162,14 @@ public class TestPythonExecutioner { String code = "z = x + y"; - PythonExecutioner.exec(code, pyInputs, pyOutputs); - INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray(); + Python.exec(code, pyInputs, pyOutputs); + INDArray z = pyOutputs.getNDArrayValue("z"); Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5); } @Test - public void testNDArrayShort()throws Exception{ + public void testNDArrayShort() throws Exception { PythonVariables pyInputs = new PythonVariables(); PythonVariables pyOutputs = new PythonVariables(); @@ -176,15 +179,15 @@ public class TestPythonExecutioner { String code = "z = x + y"; - PythonExecutioner.exec(code, pyInputs, pyOutputs); - INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray(); + Python.exec(code, pyInputs, pyOutputs); + INDArray z = pyOutputs.getNDArrayValue("z"); Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5); } @Test - public void testNDArrayInt()throws Exception{ + public void testNDArrayInt() throws Exception { PythonVariables pyInputs = new PythonVariables(); PythonVariables pyOutputs = new PythonVariables(); @@ -194,15 +197,15 @@ public class TestPythonExecutioner { String code = "z = x + y"; - PythonExecutioner.exec(code, pyInputs, pyOutputs); - INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray(); + Python.exec(code, pyInputs, pyOutputs); + INDArray z = pyOutputs.getNDArrayValue("z"); Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5); } @Test - public void testNDArrayLong()throws Exception{ + public void testNDArrayLong() throws Exception { PythonVariables pyInputs = new PythonVariables(); PythonVariables pyOutputs = new PythonVariables(); @@ -212,12 +215,91 @@ public class TestPythonExecutioner { String code = "z = x + y"; - PythonExecutioner.exec(code, pyInputs, pyOutputs); - INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray(); + Python.exec(code, pyInputs, pyOutputs); + INDArray z = pyOutputs.getNDArrayValue("z"); Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5); + } + @Test + public void testByteBufferInput() throws Exception{ + //ByteBuffer buff = ByteBuffer.allocateDirect(3); + INDArray buff = Nd4j.zeros(new int[]{3}, DataType.BYTE); + buff.putScalar(0, 97); // a + buff.putScalar(1, 98); // b + buff.putScalar(2, 99); // c + + + PythonVariables pyInputs = new PythonVariables(); + pyInputs.addBytes("buff", new BytePointer(buff.data().pointer())); + + PythonVariables pyOutputs= new PythonVariables(); + pyOutputs.addStr("out"); + + String code = "out = buff.decode()"; + Python.exec(code, pyInputs, pyOutputs); + Assert.assertEquals("abc", pyOutputs.getStrValue("out")); + + } + + @Test + public void testByteBufferOutputNoCopy() throws Exception{ + INDArray buff = Nd4j.zeros(new int[]{3}, DataType.BYTE); + buff.putScalar(0, 97); // a + buff.putScalar(1, 98); // b + buff.putScalar(2, 99); // c + + + PythonVariables pyInputs = new PythonVariables(); + pyInputs.addBytes("buff", new BytePointer(buff.data().pointer())); + + PythonVariables pyOutputs = new PythonVariables(); + pyOutputs.addBytes("buff"); // same name as input, because inplace update + + String code = "buff[0]=99\nbuff[1]=98\nbuff[2]=97"; + Python.exec(code, pyInputs, pyOutputs); + Assert.assertEquals("cba", pyOutputs.getBytesValue("buff").getString()); + } + + @Test + public void testByteBufferOutputWithCopy() throws Exception{ + INDArray buff = Nd4j.zeros(new int[]{3}, DataType.BYTE); + buff.putScalar(0, 97); // a + buff.putScalar(1, 98); // b + buff.putScalar(2, 99); // c + + + PythonVariables pyInputs = new PythonVariables(); + pyInputs.addBytes("buff", new BytePointer(buff.data().pointer())); + + PythonVariables pyOutputs = new PythonVariables(); + pyOutputs.addBytes("out"); + + String code = "buff[0]=99\nbuff[1]=98\nbuff[2]=97\nout=bytes(buff)"; + Python.exec(code, pyInputs, pyOutputs); + Assert.assertEquals("cba", pyOutputs.getBytesValue("out").getString()); + } + @Test + public void testBadCode() throws Exception{ + Python.setContext("badcode"); + PythonVariables pyInputs = new PythonVariables(); + PythonVariables pyOutputs = new PythonVariables(); + + pyInputs.addNDArray("x", Nd4j.zeros(DataType.LONG, 2, 3)); + pyInputs.addNDArray("y", Nd4j.ones(DataType.LONG, 2, 3)); + pyOutputs.addNDArray("z"); + + String code = "z = x + a"; + + try{ + Python.exec(code, pyInputs, pyOutputs); + fail("No exception thrown"); + } catch (PythonException pe ){ + Assert.assertEquals("NameError: name 'a' is not defined", pe.getMessage()); + } + + Python.setMainContext(); } } diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonJob.java b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonJob.java new file mode 100644 index 000000000..bb79d2837 --- /dev/null +++ b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonJob.java @@ -0,0 +1,326 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.datavec.python; +import org.junit.Assert; +import org.junit.Test; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import static org.junit.Assert.assertEquals; + + +@javax.annotation.concurrent.NotThreadSafe +public class TestPythonJob { + + @Test + public void testPythonJobBasic() throws Exception{ + PythonContextManager.deleteNonMainContexts(); + + String code = "c = a + b"; + PythonJob job = new PythonJob("job1", code, false); + + PythonVariables inputs = new PythonVariables(); + inputs.addInt("a", 2); + inputs.addInt("b", 3); + + PythonVariables outputs = new PythonVariables(); + outputs.addInt("c"); + + job.exec(inputs, outputs); + + assertEquals(5L, (long)outputs.getIntValue("c")); + + inputs = new PythonVariables(); + inputs.addFloat("a", 3.0); + inputs.addFloat("b", 4.0); + + outputs = new PythonVariables(); + outputs.addFloat("c"); + + + job.exec(inputs, outputs); + + assertEquals(7.0, outputs.getFloatValue("c"), 1e-5); + + + inputs = new PythonVariables(); + inputs.addNDArray("a", Nd4j.zeros(3, 2).add(4)); + inputs.addNDArray("b", Nd4j.zeros(3, 2).add(5)); + + outputs = new PythonVariables(); + outputs.addNDArray("c"); + + + job.exec(inputs, outputs); + + assertEquals(Nd4j.zeros(3, 2).add(9), outputs.getNDArrayValue("c")); + } + + @Test + public void testPythonJobReturnAllVariables()throws Exception{ + PythonContextManager.deleteNonMainContexts(); + + String code = "c = a + b"; + PythonJob job = new PythonJob("job1", code, false); + + PythonVariables inputs = new PythonVariables(); + inputs.addInt("a", 2); + inputs.addInt("b", 3); + + + PythonVariables outputs = job.execAndReturnAllVariables(inputs); + + assertEquals(5L, (long)outputs.getIntValue("c")); + + inputs = new PythonVariables(); + inputs.addFloat("a", 3.0); + inputs.addFloat("b", 4.0); + + outputs = job.execAndReturnAllVariables(inputs); + + assertEquals(7.0, outputs.getFloatValue("c"), 1e-5); + + + inputs = new PythonVariables(); + inputs.addNDArray("a", Nd4j.zeros(3, 2).add(4)); + inputs.addNDArray("b", Nd4j.zeros(3, 2).add(5)); + + outputs = job.execAndReturnAllVariables(inputs); + + assertEquals(Nd4j.zeros(3, 2).add(9), outputs.getNDArrayValue("c")); + } + + @Test + public void testMultiplePythonJobsParallel()throws Exception{ + PythonContextManager.deleteNonMainContexts(); + + String code1 = "c = a + b"; + PythonJob job1 = new PythonJob("job1", code1, false); + + String code2 = "c = a - b"; + PythonJob job2 = new PythonJob("job2", code2, false); + + PythonVariables inputs = new PythonVariables(); + inputs.addInt("a", 2); + inputs.addInt("b", 3); + + PythonVariables outputs = new PythonVariables(); + outputs.addInt("c"); + + job1.exec(inputs, outputs); + + assertEquals(5L, (long)outputs.getIntValue("c")); + + job2.exec(inputs, outputs); + + assertEquals(-1L, (long)outputs.getIntValue("c")); + + inputs = new PythonVariables(); + inputs.addFloat("a", 3.0); + inputs.addFloat("b", 4.0); + + outputs = new PythonVariables(); + outputs.addFloat("c"); + + + job1.exec(inputs, outputs); + + assertEquals(7.0, outputs.getFloatValue("c"), 1e-5); + + job2.exec(inputs, outputs); + + assertEquals(-1L, outputs.getFloatValue("c"), 1e-5); + + + inputs = new PythonVariables(); + inputs.addNDArray("a", Nd4j.zeros(3, 2).add(4)); + inputs.addNDArray("b", Nd4j.zeros(3, 2).add(5)); + + outputs = new PythonVariables(); + outputs.addNDArray("c"); + + + job1.exec(inputs, outputs); + + assertEquals(Nd4j.zeros(3, 2).add(9), outputs.getNDArrayValue("c")); + + job2.exec(inputs, outputs); + + assertEquals(Nd4j.zeros(3, 2).sub(1), outputs.getNDArrayValue("c")); + } + @Test + public void testPythonJobSetupRun()throws Exception{ + PythonContextManager.deleteNonMainContexts(); + + String code = "five=None\n" + + "def setup():\n" + + " global five\n"+ + " five = 5\n\n" + + "def run(a, b):\n" + + " c = a + b + five\n"+ + " return {'c':c}\n\n"; + PythonJob job = new PythonJob("job1", code, true); + + PythonVariables inputs = new PythonVariables(); + inputs.addInt("a", 2); + inputs.addInt("b", 3); + + PythonVariables outputs = new PythonVariables(); + outputs.addInt("c"); + + job.exec(inputs, outputs); + + assertEquals(10L, (long)outputs.getIntValue("c")); + + inputs = new PythonVariables(); + inputs.addFloat("a", 3.0); + inputs.addFloat("b", 4.0); + + outputs = new PythonVariables(); + outputs.addFloat("c"); + + + job.exec(inputs, outputs); + + assertEquals(12.0, outputs.getFloatValue("c"), 1e-5); + + + inputs = new PythonVariables(); + inputs.addNDArray("a", Nd4j.zeros(3, 2).add(4)); + inputs.addNDArray("b", Nd4j.zeros(3, 2).add(5)); + + outputs = new PythonVariables(); + outputs.addNDArray("c"); + + + job.exec(inputs, outputs); + + assertEquals(Nd4j.zeros(3, 2).add(14), outputs.getNDArrayValue("c")); + } + @Test + public void testPythonJobSetupRunAndReturnAllVariables()throws Exception{ + PythonContextManager.deleteNonMainContexts(); + + String code = "five=None\n" + + "def setup():\n" + + " global five\n"+ + " five = 5\n\n" + + "def run(a, b):\n" + + " c = a + b + five\n"+ + " return {'c':c}\n\n"; + PythonJob job = new PythonJob("job1", code, true); + + PythonVariables inputs = new PythonVariables(); + inputs.addInt("a", 2); + inputs.addInt("b", 3); + + + PythonVariables outputs = job.execAndReturnAllVariables(inputs); + + assertEquals(10L, (long)outputs.getIntValue("c")); + + inputs = new PythonVariables(); + inputs.addFloat("a", 3.0); + inputs.addFloat("b", 4.0); + + outputs = job.execAndReturnAllVariables(inputs); + + assertEquals(12.0, outputs.getFloatValue("c"), 1e-5); + + + inputs = new PythonVariables(); + inputs.addNDArray("a", Nd4j.zeros(3, 2).add(4)); + inputs.addNDArray("b", Nd4j.zeros(3, 2).add(5)); + + outputs = job.execAndReturnAllVariables(inputs); + + assertEquals(Nd4j.zeros(3, 2).add(14), outputs.getNDArrayValue("c")); + } + + @Test + public void testMultiplePythonJobsSetupRunParallel()throws Exception{ + PythonContextManager.deleteNonMainContexts(); + + String code1 = "five=None\n" + + "def setup():\n" + + " global five\n"+ + " five = 5\n\n" + + "def run(a, b):\n" + + " c = a + b + five\n"+ + " return {'c':c}\n\n"; + PythonJob job1 = new PythonJob("job1", code1, true); + + String code2 = "five=None\n" + + "def setup():\n" + + " global five\n"+ + " five = 5\n\n" + + "def run(a, b):\n" + + " c = a + b - five\n"+ + " return {'c':c}\n\n"; + PythonJob job2 = new PythonJob("job2", code2, true); + + PythonVariables inputs = new PythonVariables(); + inputs.addInt("a", 2); + inputs.addInt("b", 3); + + PythonVariables outputs = new PythonVariables(); + outputs.addInt("c"); + + job1.exec(inputs, outputs); + + assertEquals(10L, (long)outputs.getIntValue("c")); + + job2.exec(inputs, outputs); + + assertEquals(0L, (long)outputs.getIntValue("c")); + + inputs = new PythonVariables(); + inputs.addFloat("a", 3.0); + inputs.addFloat("b", 4.0); + + outputs = new PythonVariables(); + outputs.addFloat("c"); + + + job1.exec(inputs, outputs); + + assertEquals(12.0, outputs.getFloatValue("c"), 1e-5); + + job2.exec(inputs, outputs); + + assertEquals(2L, outputs.getFloatValue("c"), 1e-5); + + + inputs = new PythonVariables(); + inputs.addNDArray("a", Nd4j.zeros(3, 2).add(4)); + inputs.addNDArray("b", Nd4j.zeros(3, 2).add(5)); + + outputs = new PythonVariables(); + outputs.addNDArray("c"); + + + job1.exec(inputs, outputs); + + assertEquals(Nd4j.zeros(3, 2).add(14), outputs.getNDArrayValue("c")); + + job2.exec(inputs, outputs); + + assertEquals(Nd4j.zeros(3, 2).add(4), outputs.getNDArrayValue("c")); + } + +} diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonList.java b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonList.java new file mode 100644 index 000000000..7362a5e49 --- /dev/null +++ b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonList.java @@ -0,0 +1,108 @@ + +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.datavec.python; + + +import lombok.var; +import org.json.JSONArray; +import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.*; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +@javax.annotation.concurrent.NotThreadSafe +public class TestPythonList { + + @Test + public void testPythonListFromIntArray() { + PythonObject pyList = new PythonObject(new Integer[]{1, 2, 3, 4, 5}); + pyList.attr("append").call(6); + pyList.attr("append").call(7); + pyList.attr("append").call(8); + assertEquals(8, Python.len(pyList).toInt()); + for (int i = 0; i < 8; i++) { + assertEquals(i + 1, pyList.get(i).toInt()); + } + + } + + @Test + public void testPythonListFromLongArray() { + PythonObject pyList = new PythonObject(new Long[]{1L, 2L, 3L, 4L, 5L}); + pyList.attr("append").call(6); + pyList.attr("append").call(7); + pyList.attr("append").call(8); + assertEquals(8, Python.len(pyList).toInt()); + for (int i = 0; i < 8; i++) { + assertEquals(i + 1, pyList.get(i).toInt()); + } + + } + + @Test + public void testPythonListFromDoubleArray() { + PythonObject pyList = new PythonObject(new Double[]{1., 2., 3., 4., 5.}); + pyList.attr("append").call(6); + pyList.attr("append").call(7); + pyList.attr("append").call(8); + assertEquals(8, Python.len(pyList).toInt()); + for (int i = 0; i < 8; i++) { + assertEquals(i + 1, pyList.get(i).toInt()); + assertEquals((double) i + 1, pyList.get(i).toDouble(), 1e-5); + } + + } + + @Test + public void testPythonListFromStringArray() { + PythonObject pyList = new PythonObject(new String[]{"abcd", "efg"}); + pyList.attr("append").call("hijk"); + pyList.attr("append").call("lmnop"); + assertEquals("abcdefghijklmnop", new PythonObject("").attr("join").call(pyList).toString()); + } + + @Test + public void testPythonListFromMixedArray()throws Exception { + Map map = new HashMap<>(); + map.put(1, "a"); + map.put("a", Arrays.asList("a", "b", "c")); + map.put("arr", Nd4j.linspace(1, 4, 4)); + Object[] objs = new Object[]{ + 1, 2, "a", 3f, 4L, 5.0, Arrays.asList(10, + 20, "b", 30f, 40L, 50.0, map + + ), map + }; + PythonObject pyList = new PythonObject(objs); + System.out.println(pyList.toString()); + String expectedStr = "[1, 2, 'a', 3.0, 4, 5.0, [10" + + ", 20, 'b', 30.0, 40, 50.0, {'arr': array([1.," + + " 2., 3., 4.], dtype=float32), 1: 'a', 'a': [" + + "'a', 'b', 'c']}], {'arr': array([1., 2., 3.," + + " 4.], dtype=float32), 1: 'a', 'a': ['a', 'b', 'c']}]"; + assertEquals(expectedStr, pyList.toString()); + List objs2 = pyList.toList(); + PythonObject pyList2 = new PythonObject(objs2); + assertEquals(pyList.toString(), pyList2.toString()); + } + +} diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonSetupAndRun.java b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonSetupAndRun.java deleted file mode 100644 index 42a22c07f..000000000 --- a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonSetupAndRun.java +++ /dev/null @@ -1,27 +0,0 @@ -package org.datavec.python; - -import org.junit.Test; - -import static org.junit.Assert.assertEquals; - -@javax.annotation.concurrent.NotThreadSafe -public class TestPythonSetupAndRun { - @Test - public void testPythonWithSetupAndRun() throws Exception{ - String code = "def setup():" + - "global counter;counter=0\n" + - "def run(step):" + - "global counter;" + - "counter+=step;" + - "return {\"counter\":counter}"; - PythonVariables pyInputs = new PythonVariables(); - pyInputs.addInt("step", 2); - PythonVariables pyOutputs = new PythonVariables(); - pyOutputs.addInt("counter"); - PythonExecutioner.execWithSetupAndRun(code, pyInputs, pyOutputs); - assertEquals((long)pyOutputs.getIntValue("counter"), 2L); - pyInputs.addInt("step", 3); - PythonExecutioner.execWithSetupAndRun(code, pyInputs, pyOutputs); - assertEquals((long)pyOutputs.getIntValue("counter"), 5L); - } -} \ No newline at end of file diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonVariables.java b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonVariables.java index 0f14cb756..22f8ba230 100644 --- a/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonVariables.java +++ b/datavec/datavec-python/src/test/java/org/datavec/python/TestPythonVariables.java @@ -22,11 +22,14 @@ package org.datavec.python; +import org.bytedeco.javacpp.BytePointer; import org.junit.Test; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; import java.util.Arrays; import java.util.Collections; +import java.util.List; import static junit.framework.TestCase.assertNotNull; import static junit.framework.TestCase.assertNull; @@ -36,59 +39,50 @@ import static org.junit.Assert.assertTrue; public class TestPythonVariables { - - @Test - public void testImportNumpy(){ - Nd4j.scalar(1.0); - System.out.println(System.getProperty("org.bytedeco.openblas.load")); - PythonExecutioner.exec("import numpy as np"); - } - - - @Test - public void testDataAssociations() { + public void testDataAssociations() throws PythonException{ PythonVariables pythonVariables = new PythonVariables(); - PythonVariables.Type[] types = { - PythonVariables.Type.INT, - PythonVariables.Type.FLOAT, - PythonVariables.Type.STR, - PythonVariables.Type.BOOL, - PythonVariables.Type.DICT, - PythonVariables.Type.LIST, - PythonVariables.Type.LIST, - PythonVariables.Type.FILE, - PythonVariables.Type.NDARRAY + PythonType[] types = { + PythonType.INT, + PythonType.FLOAT, + PythonType.STR, + PythonType.BOOL, + PythonType.DICT, + PythonType.LIST, + PythonType.LIST, + PythonType.NDARRAY, + PythonType.BYTES }; - NumpyArray npArr = new NumpyArray(Nd4j.scalar(1.0)); + INDArray arr = Nd4j.scalar(1.0); + BytePointer bp = new BytePointer(arr.data().pointer()); Object[] values = { 1L,1.0,"1",true, Collections.singletonMap("1",1), - new Object[]{1}, Arrays.asList(1),"type", npArr + new Object[]{1}, Arrays.asList(1), arr, bp }; Object[] expectedValues = { 1L,1.0,"1",true, Collections.singletonMap("1",1), - new Object[]{1}, new Object[]{1},"type", npArr + Arrays.asList(1), Arrays.asList(1), arr, bp }; for(int i = 0; i < types.length; i++) { - testInsertGet(pythonVariables,types[i].name() + i,values[i],types[i],expectedValues[i]); + testInsertGet(pythonVariables,types[i].getName().name() + i,values[i],types[i],expectedValues[i]); } assertEquals(types.length,pythonVariables.getVariables().length); } - private void testInsertGet(PythonVariables pythonVariables,String key,Object value,PythonVariables.Type type,Object expectedValue) { + private void testInsertGet(PythonVariables pythonVariables,String key,Object value,PythonType type,Object expectedValue) throws PythonException{ pythonVariables.add(key, type); assertNull(pythonVariables.getValue(key)); pythonVariables.setValue(key,value); assertNotNull(pythonVariables.getValue(key)); Object actualValue = pythonVariables.getValue(key); if (expectedValue instanceof Object[]){ - assertTrue(actualValue instanceof Object[]); - Object[] actualArr = (Object[])actualValue; + assertTrue(actualValue instanceof List); + Object[] actualArr = ((List)actualValue).toArray(); Object[] expectedArr = (Object[])expectedValue; assertArrayEquals(expectedArr, actualArr); } diff --git a/datavec/datavec-python/src/test/java/org/datavec/python/TestSerde.java b/datavec/datavec-python/src/test/java/org/datavec/python/TestSerde.java index cc12174f8..71d37ca91 100644 --- a/datavec/datavec-python/src/test/java/org/datavec/python/TestSerde.java +++ b/datavec/datavec-python/src/test/java/org/datavec/python/TestSerde.java @@ -44,7 +44,7 @@ public class TestSerde { String yaml = y.serialize(t); String json = j.serialize(t); - Transform t2 = y.deserializeTransform(json); + Transform t2 = y.deserializeTransform(yaml); Transform t3 = j.deserializeTransform(json); assertEquals(t, t2); assertEquals(t, t3); From 5d28e6143df90693cace3e3f2428e6c945703845 Mon Sep 17 00:00:00 2001 From: raver119 Date: Wed, 5 Feb 2020 07:27:24 +0300 Subject: [PATCH 2/9] OpContext handling (#214) * nano tweaks Signed-off-by: raver119 * OpContext tweaks Signed-off-by: raver119 * OpContext deallocators Signed-off-by: raver119 * get rid of few mkldnn safety checks Signed-off-by: raver119 * databuffer setSpecial fix Signed-off-by: raver119 --- .../layers/mkldnn/MKLDNNBatchNormHelper.java | 3 +- .../nn/layers/mkldnn/MKLDNNConvHelper.java | 11 ++---- ...KLDNNLocalResponseNormalizationHelper.java | 6 ++-- .../mkldnn/MKLDNNSubsamplingHelper.java | 5 ++- libnd4j/blas/NativeOps.h | 1 + libnd4j/blas/cpu/NativeOps.cpp | 4 +++ libnd4j/blas/cuda/NativeOps.cu | 4 +++ libnd4j/include/array/impl/DataBuffer.cpp | 5 +++ libnd4j/include/graph/Context.h | 7 ++++ libnd4j/include/graph/impl/Context.cpp | 10 ++++++ .../declarable/platform/mkldnn/batchnorm.cpp | 4 --- .../ops/declarable/platform/mkldnn/conv2d.cpp | 4 --- .../ops/declarable/platform/mkldnn/conv3d.cpp | 5 --- .../declarable/platform/mkldnn/deconv2d.cpp | 4 --- .../declarable/platform/mkldnn/deconv3d.cpp | 4 --- .../platform/mkldnn/depthwiseConv2d.cpp | 5 --- .../deallocation/DeallocatorService.java | 2 +- .../nd4j/linalg/api/ops/BaseOpContext.java | 6 ++++ .../org/nd4j/linalg/api/ops/OpContext.java | 5 +++ .../java/org/nd4j/nativeblas/NativeOps.java | 1 + .../ops/executioner/CudaOpContext.java | 32 +++++++++++++++-- .../executioner/CudaOpContextDeallocator.java | 34 +++++++++++++++++++ .../java/org/nd4j/nativeblas/Nd4jCuda.java | 8 +++++ .../nativecpu/buffer/BaseCpuDataBuffer.java | 4 +-- .../cpu/nativecpu/buffer/CpuDeallocator.java | 2 +- .../cpu/nativecpu/buffer/LongBuffer.java | 3 +- .../cpu/nativecpu/ops/CpuOpContext.java | 33 ++++++++++++++++-- .../ops/CpuOpContextDeallocator.java | 34 +++++++++++++++++++ .../java/org/nd4j/nativeblas/Nd4jCpu.java | 8 +++++ .../test/java/org/nd4j/linalg/Nd4jTestsC.java | 25 ++++++++++++++ 30 files changed, 229 insertions(+), 50 deletions(-) create mode 100644 nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContextDeallocator.java create mode 100644 nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContextDeallocator.java diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNBatchNormHelper.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNBatchNormHelper.java index 2e8c04aa3..027f9d80d 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNBatchNormHelper.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNBatchNormHelper.java @@ -147,8 +147,7 @@ public class MKLDNNBatchNormHelper implements BatchNormalizationHelper { } //Note: batchnorm op expects rank 1 inputs for mean/var etc, not rank 2 shape [1,x] - context.getInputArrays().clear(); - context.getOutputArrays().clear(); + context.purge(); context.setInputArray(0, x); context.setInputArray(1, m); context.setInputArray(2, v); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNConvHelper.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNConvHelper.java index 244f7c1fc..9bbf4deae 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNConvHelper.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/mkldnn/MKLDNNConvHelper.java @@ -89,8 +89,7 @@ public class MKLDNNConvHelper implements ConvolutionHelper { INDArray[] inputsArr = biasGradView == null ? new INDArray[]{input, weightsPermute, delta} : new INDArray[]{input, weightsPermute, bias, delta}; INDArray[] outputArr = biasGradView == null ? new INDArray[]{gradAtInput, weightGradViewPermute} : new INDArray[]{gradAtInput, weightGradViewPermute, biasGradView}; - contextBwd.getInputArrays().clear(); - contextBwd.getOutputArrays().clear(); + contextBwd.purge(); for( int i=0; isetExecutionMode((samediff::ExecutionMode) execMode); } +void ctxPurge(OpaqueContext* ptr) { + ptr->clearFastPath(); +} + nd4j::graph::RandomGenerator* createRandomGenerator(Nd4jLong rootSeed, Nd4jLong nodeSeed) { return new nd4j::graph::RandomGenerator(rootSeed, nodeSeed); } diff --git a/libnd4j/blas/cuda/NativeOps.cu b/libnd4j/blas/cuda/NativeOps.cu index d65dcaed5..07ce876ea 100755 --- a/libnd4j/blas/cuda/NativeOps.cu +++ b/libnd4j/blas/cuda/NativeOps.cu @@ -3771,6 +3771,10 @@ void ctxShapeFunctionOverride(OpaqueContext* ptr, bool reallyOverride) { ptr->setShapeFunctionOverride(reallyOverride); } +void ctxPurge(OpaqueContext* ptr) { + ptr->clearFastPath(); +} + int binaryLevel() { return 0; } diff --git a/libnd4j/include/array/impl/DataBuffer.cpp b/libnd4j/include/array/impl/DataBuffer.cpp index 49527026c..36758c684 100644 --- a/libnd4j/include/array/impl/DataBuffer.cpp +++ b/libnd4j/include/array/impl/DataBuffer.cpp @@ -305,12 +305,17 @@ namespace nd4j { if (_primaryBuffer != nullptr && _isOwnerPrimary) { deletePrimary(); } + _primaryBuffer = buffer; _isOwnerPrimary = false; _lenInBytes = length * DataTypeUtils::sizeOf(_dataType); } void DataBuffer::setSpecialBuffer(void *buffer, size_t length) { + if (_specialBuffer != nullptr && _isOwnerSpecial) { + deleteSpecial(); + } + this->setSpecial(buffer, false); _lenInBytes = length * DataTypeUtils::sizeOf(_dataType); } diff --git a/libnd4j/include/graph/Context.h b/libnd4j/include/graph/Context.h index 96b7e1c79..d1e8a4dad 100644 --- a/libnd4j/include/graph/Context.h +++ b/libnd4j/include/graph/Context.h @@ -204,6 +204,13 @@ namespace nd4j { void setBArguments(const std::vector &tArgs); void setDArguments(const std::vector &dArgs); + /** + * This method purges fastpath in/out contents and releases all the handles. + * + * PLEASE NOTE: I/T/B/D args will stay intact + */ + void clearFastPath(); + void setCudaContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer allocationPointer); void allowHelpers(bool reallyAllow); diff --git a/libnd4j/include/graph/impl/Context.cpp b/libnd4j/include/graph/impl/Context.cpp index 4c7a19133..5add8280d 100644 --- a/libnd4j/include/graph/impl/Context.cpp +++ b/libnd4j/include/graph/impl/Context.cpp @@ -563,6 +563,16 @@ namespace nd4j { for (auto d:dArgs) _dArgs.emplace_back(d); } + + void Context::clearFastPath() { + _fastpath_in.clear(); + _fastpath_out.clear(); + + for (auto v:_handles) + delete v; + + _handles.clear(); + } } } diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp index 8974cef14..0ebee8fbf 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp @@ -456,10 +456,6 @@ PLATFORM_IMPL(batchnorm, ENGINE_CPU) { ////////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(batchnorm, ENGINE_CPU) { - // we don't want to use mkldnn if cpu doesn't support avx/avx2 - // if (::optimalLevel() < 2) - // return false; - auto input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw, 5D:ncdhw auto mean = INPUT_VARIABLE(1); // [c] auto variance = INPUT_VARIABLE(2); // [c] diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp index 559edf2cd..1b90812b1 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp @@ -265,10 +265,6 @@ PLATFORM_IMPL(conv2d, ENGINE_CPU) { } PLATFORM_CHECK(conv2d, ENGINE_CPU) { - // we don't want to use mkldnn if cpu doesn't support avx/avx2 - if (::optimalLevel() < 2) - return false; - auto input = INPUT_VARIABLE(0); auto weights = INPUT_VARIABLE(1); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp index 747d84c36..096839d79 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp @@ -270,10 +270,6 @@ PLATFORM_IMPL(conv3dnew, ENGINE_CPU) { } PLATFORM_CHECK(conv3dnew, ENGINE_CPU) { - // we don't want to use mkldnn if cpu doesn't support avx/avx2 - if (::optimalLevel() < 2) - return false; - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] @@ -335,7 +331,6 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) { } PLATFORM_CHECK(conv3dnew_bp, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp index d95052c5a..e63d7440c 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp @@ -407,10 +407,6 @@ PLATFORM_IMPL(deconv2d, ENGINE_CPU) { } PLATFORM_CHECK(deconv2d, ENGINE_CPU) { - // we don't want to use mkldnn if cpu doesn't support avx/avx2 - // if (::optimalLevel() < 2) - // return false; - auto input = INPUT_VARIABLE(0); auto weights = INPUT_VARIABLE(1); auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp index a678e0185..490ce4535 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp @@ -422,10 +422,6 @@ PLATFORM_IMPL(deconv3d, ENGINE_CPU) { } PLATFORM_CHECK(deconv3d, ENGINE_CPU) { - // we don't want to use mkldnn if cpu doesn't support avx/avx2 - // if (::optimalLevel() < 2) - // return false; - auto input = INPUT_VARIABLE(0); auto weights = INPUT_VARIABLE(1); auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp index fc7a1e9e3..d6722c009 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp @@ -401,10 +401,6 @@ PLATFORM_IMPL(depthwise_conv2d, ENGINE_CPU) { ////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(depthwise_conv2d, ENGINE_CPU) { - // we don't want to use mkldnn if cpu doesn't support avx/avx2 - if (::optimalLevel() < 2) - return false; - auto input = INPUT_VARIABLE(0); auto weights = INPUT_VARIABLE(1); auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; @@ -477,7 +473,6 @@ PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CPU) { ////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(depthwise_conv2d_bp, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC] diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/memory/deallocation/DeallocatorService.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/memory/deallocation/DeallocatorService.java index dffb93a7b..ded5bc938 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/memory/deallocation/DeallocatorService.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/memory/deallocation/DeallocatorService.java @@ -43,7 +43,7 @@ public class DeallocatorService { private Map referenceMap = new ConcurrentHashMap<>(); private List>> deviceMap = new ArrayList<>(); - private AtomicLong counter = new AtomicLong(0); + private final transient AtomicLong counter = new AtomicLong(0); public DeallocatorService() { // we need to have at least 2 threads, but for CUDA we'd need at least numDevices threads, due to thread->device affinity diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOpContext.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOpContext.java index 4a56e2a88..0139a9db5 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOpContext.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/BaseOpContext.java @@ -153,4 +153,10 @@ public abstract class BaseOpContext implements OpContext { for (int e = 0; e < arrays.length; e++) setOutputArray(e, arrays[e]); } + + @Override + public void purge() { + fastpath_in.clear(); + fastpath_out.clear(); + } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java index 4063746b3..62a4906a7 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/OpContext.java @@ -162,4 +162,9 @@ public interface OpContext extends AutoCloseable { * @param mode */ void setExecutionMode(ExecutionMode mode); + + /** + * This method removes all in/out arrays from this OpContext + */ + void purge(); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java index d284974eb..1a01bf278 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/src/main/java/org/nd4j/nativeblas/NativeOps.java @@ -1161,6 +1161,7 @@ public interface NativeOps { void ctxAllowHelpers(OpaqueContext ptr, boolean reallyAllow); void ctxSetExecutionMode(OpaqueContext ptr, int execMode); void ctxShapeFunctionOverride(OpaqueContext ptr, boolean reallyOverride); + void ctxPurge(OpaqueContext ptr); void deleteGraphContext(OpaqueContext ptr); OpaqueRandomGenerator createRandomGenerator(long rootSeed, long nodeSeed); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java index 01127e891..5e26b3ea3 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContext.java @@ -23,6 +23,8 @@ import org.nd4j.jita.allocator.impl.AtomicAllocator; import org.nd4j.jita.allocator.pointers.cuda.cudaStream_t; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.concurrency.AffinityManager; +import org.nd4j.linalg.api.memory.Deallocatable; +import org.nd4j.linalg.api.memory.Deallocator; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseOpContext; import org.nd4j.linalg.api.ops.ExecutionMode; @@ -40,14 +42,19 @@ import org.nd4j.nativeblas.OpaqueRandomGenerator; * CUDA wrapper for op Context * @author raver119@gmail.com */ -public class CudaOpContext extends BaseOpContext implements OpContext { +public class CudaOpContext extends BaseOpContext implements OpContext, Deallocatable { // we might want to have configurable private NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); private OpaqueContext context = nativeOps.createGraphContext(1); + private final transient long id = Nd4j.getDeallocatorService().nextValue(); + + public CudaOpContext() { + Nd4j.getDeallocatorService().pickObject(this); + } @Override public void close() { - nativeOps.deleteGraphContext(context); + // no-op } @Override @@ -143,4 +150,25 @@ public class CudaOpContext extends BaseOpContext implements OpContext { super.setExecutionMode(mode); nativeOps.ctxSetExecutionMode(context, mode.ordinal()); } + + @Override + public void purge() { + super.purge(); + nativeOps.ctxPurge(context); + } + + @Override + public String getUniqueId() { + return new String("CTX_" + id); + } + + @Override + public Deallocator deallocator() { + return new CudaOpContextDeallocator(this); + } + + @Override + public int targetDevice() { + return 0; + } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContextDeallocator.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContextDeallocator.java new file mode 100644 index 000000000..62b5e4a00 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaOpContextDeallocator.java @@ -0,0 +1,34 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.jcublas.ops.executioner; + +import org.nd4j.linalg.api.memory.Deallocator; +import org.nd4j.nativeblas.NativeOpsHolder; +import org.nd4j.nativeblas.OpaqueContext; + +public class CudaOpContextDeallocator implements Deallocator { + private transient final OpaqueContext context; + + public CudaOpContextDeallocator(CudaOpContext ctx) { + context = (OpaqueContext) ctx.contextPointer(); + } + + @Override + public void deallocate() { + NativeOpsHolder.getInstance().getDeviceNativeOps().deleteGraphContext(context); + } +} diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java index f85ae9cf1..e7ddcda11 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/nativeblas/Nd4jCuda.java @@ -3090,6 +3090,7 @@ public native OpaqueRandomGenerator getGraphContextRandomGenerator(OpaqueContext public native void ctxAllowHelpers(OpaqueContext ptr, @Cast("bool") boolean reallyAllow); public native void ctxShapeFunctionOverride(OpaqueContext ptr, @Cast("bool") boolean reallyOverride); public native void ctxSetExecutionMode(OpaqueContext ptr, int execMode); +public native void ctxPurge(OpaqueContext ptr); public native void markGraphContextInplace(OpaqueContext ptr, @Cast("bool") boolean reallyInplace); public native void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer); public native void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo); @@ -6453,6 +6454,13 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); public native void setDArguments(@Cast("nd4j::DataType*") @StdVector IntBuffer dArgs); public native void setDArguments(@Cast("nd4j::DataType*") @StdVector int[] dArgs); + /** + * This method purges fastpath in/out contents and releases all the handles. + * + * PLEASE NOTE: I/T/B/D args will stay intact + */ + public native void clearFastPath(); + public native void setCudaContext(@Cast("Nd4jPointer") Pointer cudaStream, @Cast("Nd4jPointer") Pointer reductionPointer, @Cast("Nd4jPointer") Pointer allocationPointer); public native void allowHelpers(@Cast("bool") boolean reallyAllow); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BaseCpuDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BaseCpuDataBuffer.java index 71583638a..a51666f78 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BaseCpuDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BaseCpuDataBuffer.java @@ -43,7 +43,7 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo protected transient OpaqueDataBuffer ptrDataBuffer; - private final long instanceId = Nd4j.getDeallocatorService().nextValue(); + private transient final long instanceId = Nd4j.getDeallocatorService().nextValue(); protected BaseCpuDataBuffer() { @@ -52,7 +52,7 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo @Override public String getUniqueId() { - return "BCDB_" + instanceId; + return new String("BCDB_" + instanceId); } @Override diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/CpuDeallocator.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/CpuDeallocator.java index 3b8a46fa6..e808ebaa3 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/CpuDeallocator.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/CpuDeallocator.java @@ -28,7 +28,7 @@ import org.nd4j.nativeblas.OpaqueDataBuffer; */ @Slf4j public class CpuDeallocator implements Deallocator { - private OpaqueDataBuffer opaqueDataBuffer; + private final transient OpaqueDataBuffer opaqueDataBuffer; public CpuDeallocator(BaseCpuDataBuffer buffer) { opaqueDataBuffer = buffer.getOpaqueDataBuffer(); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/LongBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/LongBuffer.java index 898a125f2..19ad6f907 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/LongBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/LongBuffer.java @@ -28,6 +28,7 @@ import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.pointers.PagedPointer; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.nativeblas.NativeOpsHolder; +import org.nd4j.nativeblas.OpaqueDataBuffer; import java.nio.ByteBuffer; @@ -123,7 +124,7 @@ public class LongBuffer extends BaseCpuDataBuffer { // we still want this buffer to have native representation - ptrDataBuffer = NativeOpsHolder.getInstance().getDeviceNativeOps().allocateDataBuffer(0, DataType.INT64.toInt(), false); + ptrDataBuffer = OpaqueDataBuffer.allocateDataBuffer(0, DataType.INT64, false); NativeOpsHolder.getInstance().getDeviceNativeOps().dbSetPrimaryBuffer(ptrDataBuffer, this.pointer, numberOfElements); Nd4j.getDeallocatorService().pickObject(this); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java index 461646311..9d79e6545 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContext.java @@ -20,11 +20,14 @@ import lombok.NonNull; import lombok.val; import org.bytedeco.javacpp.*; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.memory.Deallocatable; +import org.nd4j.linalg.api.memory.Deallocator; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.BaseOpContext; import org.nd4j.linalg.api.ops.ExecutionMode; import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.cpu.nativecpu.buffer.BaseCpuDataBuffer; +import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.primitives.Pair; import org.nd4j.nativeblas.NativeOps; import org.nd4j.nativeblas.NativeOpsHolder; @@ -38,14 +41,19 @@ import java.util.List; * * @author raver119@gmail.com */ -public class CpuOpContext extends BaseOpContext implements OpContext { +public class CpuOpContext extends BaseOpContext implements OpContext, Deallocatable { // we might want to have configurable private NativeOps nativeOps = NativeOpsHolder.getInstance().getDeviceNativeOps(); private OpaqueContext context = nativeOps.createGraphContext(1); + private final transient long id = Nd4j.getDeallocatorService().nextValue(); + + public CpuOpContext() { + Nd4j.getDeallocatorService().pickObject(this); + } @Override public void close() { - nativeOps.deleteGraphContext(context); + // no-op } @Override @@ -136,4 +144,25 @@ public class CpuOpContext extends BaseOpContext implements OpContext { super.setExecutionMode(mode); nativeOps.ctxSetExecutionMode(context, mode.ordinal()); } + + @Override + public void purge() { + super.purge(); + nativeOps.ctxPurge(context); + } + + @Override + public String getUniqueId() { + return new String("CTX_" + id); + } + + @Override + public Deallocator deallocator() { + return new CpuOpContextDeallocator(this); + } + + @Override + public int targetDevice() { + return 0; + } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContextDeallocator.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContextDeallocator.java new file mode 100644 index 000000000..621f882bd --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/CpuOpContextDeallocator.java @@ -0,0 +1,34 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.cpu.nativecpu.ops; + +import org.nd4j.linalg.api.memory.Deallocator; +import org.nd4j.nativeblas.NativeOpsHolder; +import org.nd4j.nativeblas.OpaqueContext; + +public class CpuOpContextDeallocator implements Deallocator { + private transient final OpaqueContext context; + + public CpuOpContextDeallocator(CpuOpContext ctx) { + context = (OpaqueContext) ctx.contextPointer(); + } + + @Override + public void deallocate() { + NativeOpsHolder.getInstance().getDeviceNativeOps().deleteGraphContext(context); + } +} diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index 5522141be..b954a4a34 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -3093,6 +3093,7 @@ public native OpaqueRandomGenerator getGraphContextRandomGenerator(OpaqueContext public native void ctxAllowHelpers(OpaqueContext ptr, @Cast("bool") boolean reallyAllow); public native void ctxShapeFunctionOverride(OpaqueContext ptr, @Cast("bool") boolean reallyOverride); public native void ctxSetExecutionMode(OpaqueContext ptr, int execMode); +public native void ctxPurge(OpaqueContext ptr); public native void markGraphContextInplace(OpaqueContext ptr, @Cast("bool") boolean reallyInplace); public native void setGraphContextCudaContext(OpaqueContext ptr, Pointer stream, Pointer reductionPointer, Pointer allocationPointer); public native void setGraphContextInputArray(OpaqueContext ptr, int index, Pointer buffer, Pointer shapeInfo, Pointer specialBuffer, Pointer specialShapeInfo); @@ -6456,6 +6457,13 @@ public native @Cast("bool") boolean isOptimalRequirementsMet(); public native void setDArguments(@Cast("nd4j::DataType*") @StdVector IntBuffer dArgs); public native void setDArguments(@Cast("nd4j::DataType*") @StdVector int[] dArgs); + /** + * This method purges fastpath in/out contents and releases all the handles. + * + * PLEASE NOTE: I/T/B/D args will stay intact + */ + public native void clearFastPath(); + public native void setCudaContext(@Cast("Nd4jPointer") Pointer cudaStream, @Cast("Nd4jPointer") Pointer reductionPointer, @Cast("Nd4jPointer") Pointer allocationPointer); public native void allowHelpers(@Cast("bool") boolean reallyAllow); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index d96c0ed31..ad5bacc4e 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -8262,6 +8262,31 @@ public class Nd4jTestsC extends BaseNd4jTest { assertArrayEquals(new long[]{10, 0}, out2.shape()); } + @Test + public void testDealloc_1() throws Exception { + + for (int e = 0; e < 5000; e++){ + try(val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace("someid")) { + val x = Nd4j.createUninitialized(DataType.FLOAT, 1, 1000); + //val y = x.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 100)).reshape('c', 10, 10); + //val z = x.get(NDArrayIndex.point(0), NDArrayIndex.interval(100, 200)).reshape('c', 10, 10); + //val a = x.get(NDArrayIndex.point(0), NDArrayIndex.interval(200, 300)).reshape('f', 10, 10); + } finally { + //System.gc(); + } + } + + Thread.sleep(1000); + System.gc(); + + Thread.sleep(1000); + System.gc(); + System.gc(); + System.gc(); + + //Nd4j.getMemoryManager().printRemainingStacks(); + } + @Override public char ordering() { return 'c'; From 569a46f87d3aef2a6365eb9740f90aec3e80491d Mon Sep 17 00:00:00 2001 From: Alex Black Date: Wed, 5 Feb 2020 17:07:36 +1100 Subject: [PATCH 3/9] Fixes (#213) * Increase timeouts for 2 tests occasionally failing on CI Signed-off-by: AlexDBlack * Explicitly set character encoding via argline for maven surefire tests Signed-off-by: AlexDBlack * CUDA gradient check timeout fix + simple rnn masking fix Signed-off-by: AlexDBlack --- arbiter/pom.xml | 2 +- datavec/pom.xml | 2 +- .../GradientCheckTestsMasking.java | 32 +++++++++++-------- .../deeplearning4j-dataimport-solrj/pom.xml | 2 +- .../deeplearning4j-modelexport-solr/pom.xml | 2 +- .../pom.xml | 2 +- .../clustering/kmeans/KMeansTest.java | 5 +++ .../nn/layers/recurrent/SimpleRnn.java | 6 ++++ deeplearning4j/pom.xml | 2 +- .../nd4j-backend-impls/nd4j-cuda/pom.xml | 2 +- nd4j/nd4j-backends/nd4j-backend-impls/pom.xml | 2 +- .../nd4j-tests-tensorflow/pom.xml | 4 +-- nd4j/nd4j-backends/nd4j-tests/pom.xml | 2 +- .../java/org/nd4j/autodiff/TestOpMapping.java | 5 +++ nd4j/nd4j-serde/nd4j-aeron/pom.xml | 2 +- nd4j/nd4j-serde/nd4j-arrow/pom.xml | 2 +- nd4j/nd4j-serde/nd4j-gson/pom.xml | 2 +- nd4j/nd4j-serde/nd4j-kryo/pom.xml | 2 +- 18 files changed, 49 insertions(+), 29 deletions(-) diff --git a/arbiter/pom.xml b/arbiter/pom.xml index 364c6d904..93f877968 100644 --- a/arbiter/pom.xml +++ b/arbiter/pom.xml @@ -192,7 +192,7 @@ maven-surefire-plugin ${maven-surefire-plugin.version} - -Ddtype=double -Xmx3024m -Xms3024m + -Ddtype=double -Dfile.encoding=UTF-8 -Xmx3024m -Xms3024m *.java diff --git a/deeplearning4j/deeplearning4j-modelexport-solr/pom.xml b/deeplearning4j/deeplearning4j-modelexport-solr/pom.xml index 02ce30a40..383eb1c8c 100644 --- a/deeplearning4j/deeplearning4j-modelexport-solr/pom.xml +++ b/deeplearning4j/deeplearning4j-modelexport-solr/pom.xml @@ -33,7 +33,7 @@ org.apache.maven.plugins maven-surefire-plugin - -Ddtype=float -Xmx8g -Dtest.solr.allowed.securerandom=NativePRNG + -Ddtype=float -Dfile.encoding=UTF-8 -Xmx8g -Dtest.solr.allowed.securerandom=NativePRNG *.java diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml index 2d4a4da14..911432cf0 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml @@ -36,7 +36,7 @@ org.apache.maven.plugins maven-surefire-plugin - -Ddtype=float -Xmx8g + -Ddtype=float -Dfile.encoding=UTF-8 -Xmx8g *.java diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kmeans/KMeansTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kmeans/KMeansTest.java index 2f2619e78..abbfa04bc 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kmeans/KMeansTest.java +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kmeans/KMeansTest.java @@ -38,6 +38,11 @@ public class KMeansTest extends BaseDL4JTest { private boolean[] useKMeansPlusPlus = {true, false}; + @Override + public long getTimeoutMilliseconds() { + return 60000L; + } + @Test public void testKMeans() { Nd4j.getRandom().setSeed(7); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/SimpleRnn.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/SimpleRnn.java index 044f444c0..87d88efcb 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/SimpleRnn.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/recurrent/SimpleRnn.java @@ -282,6 +282,12 @@ public class SimpleRnn extends BaseRecurrentLayer true false - -Ddtype=float -Xmx8g + -Ddtype=float -Dfile.encoding=UTF-8 -Xmx8g *.java diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml index d98c7a6d1..36f25d636 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/pom.xml @@ -60,7 +60,7 @@ Maximum heap size was set to 6g, as a minimum required value for tests run. Depending on a build machine, default value is not always enough. --> - -Ddtype=float -Xmx8g + -Ddtype=float -Dfile.encoding=UTF-8 -Xmx8g diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/pom.xml b/nd4j/nd4j-backends/nd4j-backend-impls/pom.xml index 86ce07ff7..c6da5e6f0 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/pom.xml +++ b/nd4j/nd4j-backends/nd4j-backend-impls/pom.xml @@ -117,7 +117,7 @@ Maximum heap size was set to 8g, as a minimum required value for tests run. Depending on a build machine, default value is not always enough. --> - -Ddtype=float -Xmx8g + -Ddtype=float -Dfile.encoding=UTF-8 -Xmx8g diff --git a/nd4j/nd4j-backends/nd4j-tests-tensorflow/pom.xml b/nd4j/nd4j-backends/nd4j-tests-tensorflow/pom.xml index 5f5c5fa90..6a3cc6eda 100644 --- a/nd4j/nd4j-backends/nd4j-tests-tensorflow/pom.xml +++ b/nd4j/nd4j-backends/nd4j-tests-tensorflow/pom.xml @@ -224,7 +224,7 @@ Depending on a build machine, default value is not always enough. --> false - -Xmx6g + -Xmx6g -Dfile.encoding=UTF-8 @@ -296,7 +296,7 @@ Maximum heap size was set to 6g, as a minimum required value for tests run. Depending on a build machine, default value is not always enough. --> - -Xmx6g + -Xmx6g -Dfile.encoding=UTF-8 false false diff --git a/nd4j/nd4j-backends/nd4j-tests/pom.xml b/nd4j/nd4j-backends/nd4j-tests/pom.xml index 9d098189f..8d250629a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/pom.xml +++ b/nd4j/nd4j-backends/nd4j-tests/pom.xml @@ -252,7 +252,7 @@ Maximum heap size was set to 6g, as a minimum required value for tests run. Depending on a build machine, default value is not always enough. --> - -Ddtype=float -Xmx6g + -Ddtype=float -Dfile.encoding=UTF-8 -Xmx6g diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestOpMapping.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestOpMapping.java index ab56ae281..e88f195c0 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestOpMapping.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestOpMapping.java @@ -44,6 +44,11 @@ public class TestOpMapping extends BaseNd4jTest { return 'c'; } + @Override + public long getTimeoutMilliseconds() { + return 60000L; + } + @Test public void testOpMappingCoverage() throws Exception { Reflections reflections = new Reflections("org.nd4j"); diff --git a/nd4j/nd4j-serde/nd4j-aeron/pom.xml b/nd4j/nd4j-serde/nd4j-aeron/pom.xml index 87e9347dd..827afb23a 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/pom.xml +++ b/nd4j/nd4j-serde/nd4j-aeron/pom.xml @@ -91,7 +91,7 @@ For testing large zoo models, this may not be enough (so comment it out). --> - -Ddtype=float -Xmx8g + -Ddtype=float -Dfile.encoding=UTF-8 -Xmx8g diff --git a/nd4j/nd4j-serde/nd4j-arrow/pom.xml b/nd4j/nd4j-serde/nd4j-arrow/pom.xml index ddadc2df1..3a768c1a5 100644 --- a/nd4j/nd4j-serde/nd4j-arrow/pom.xml +++ b/nd4j/nd4j-serde/nd4j-arrow/pom.xml @@ -109,7 +109,7 @@ For testing large zoo models, this may not be enough (so comment it out). --> - -Ddtype=float -Xmx8g + -Ddtype=float -Dfile.encoding=UTF-8 -Xmx8g diff --git a/nd4j/nd4j-serde/nd4j-gson/pom.xml b/nd4j/nd4j-serde/nd4j-gson/pom.xml index f7215436a..f488bfde5 100644 --- a/nd4j/nd4j-serde/nd4j-gson/pom.xml +++ b/nd4j/nd4j-serde/nd4j-gson/pom.xml @@ -100,7 +100,7 @@ For testing large zoo models, this may not be enough (so comment it out). --> - -Ddtype=float -Xmx8g + -Ddtype=float -Dfile.encoding=UTF-8 -Xmx8g diff --git a/nd4j/nd4j-serde/nd4j-kryo/pom.xml b/nd4j/nd4j-serde/nd4j-kryo/pom.xml index 25acac26f..02970d5e2 100644 --- a/nd4j/nd4j-serde/nd4j-kryo/pom.xml +++ b/nd4j/nd4j-serde/nd4j-kryo/pom.xml @@ -220,7 +220,7 @@ Maximum heap size was set to 6g, as a minimum required value for tests run. Depending on a build machine, default value is not always enough. --> - -Ddtype=float -Xmx6g + -Ddtype=float -Dfile.encoding=UTF-8 -Xmx6g From 5ae40f6e38650ffaf2f41b393b12e0414e058987 Mon Sep 17 00:00:00 2001 From: shugeo Date: Thu, 6 Feb 2020 20:06:50 +0200 Subject: [PATCH 4/9] Shugeo sequence mask fix2 (#216) * Fixed sequence_mask op and tests. Signed-off-by: shugeo * Cuda fix for sequence_mask op. Signed-off-by: shugeo * Fixed sequence_mask op for both platforms and tests. Signed-off-by: shugeo * Fixed solve and triangular_solve for more than 2D for adjoint cases. Signed-off-by: shugeo * Added adjoint solve test again. Signed-off-by: shugeo * Added a set of tests for triangual_solve and generic solve ops. Signed-off-by: shugeo * Added a pair tests for triangular_solve Signed-off-by: shugeo * Added tests for triangular_solve op. Signed-off-by: shugeo --- .../generic/parity_ops/sequence_mask.cpp | 33 +-- .../declarable/helpers/cpu/sequence_mask.cpp | 4 +- .../ops/declarable/helpers/cpu/solve.cpp | 4 +- .../helpers/cpu/triangular_solve.cpp | 9 +- .../declarable/helpers/cuda/sequence_mask.cu | 4 +- .../layers_tests/DeclarableOpsTests11.cpp | 235 ++++++++++++++++++ .../layers_tests/DeclarableOpsTests7.cpp | 62 ++++- 7 files changed, 326 insertions(+), 25 deletions(-) diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/sequence_mask.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/sequence_mask.cpp index e7694b409..477b298a3 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/sequence_mask.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/sequence_mask.cpp @@ -58,30 +58,31 @@ namespace nd4j { int outRank = shape::rank(in) + 1; auto input = INPUT_VARIABLE(0); auto dtype = DataType::BOOL; - Nd4jLong maxInd = input->argMax(); - Nd4jLong max = input->e(maxInd); + auto argMaxInd = input->argMax(); + Nd4jLong max = input->e(argMaxInd); + Nd4jLong maxInd = max; - if (block.getIArguments()->size() > 0) { - if (block.width() < 2) { - maxInd = INT_ARG(0); - if (maxInd < max) - maxInd = static_cast(max); - if (block.getIArguments()->size() > 1) - dtype = (DataType)INT_ARG(1); - } - else { - dtype = (DataType)INT_ARG(0); - } - } + if (block.numD() > 0) + dtype = D_ARG(0); if (block.width() > 1) { auto maxlen = INPUT_VARIABLE(1); Nd4jLong tmaxlen = maxlen->e(0); if (tmaxlen > max) maxInd = static_cast(tmaxlen); + if (block.numI() > 0) { + dtype = (DataType) INT_ARG(0); + } + } + else { + if (block.numI() > 0) { + maxInd = INT_ARG(0); + } + if (maxInd < max) + maxInd = max; + if (block.numI() > 1) + dtype = (DataType)INT_ARG(1); // to work with legacy code } - else - maxInd = static_cast(max); int lastDimension = maxInd; ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); diff --git a/libnd4j/include/ops/declarable/helpers/cpu/sequence_mask.cpp b/libnd4j/include/ops/declarable/helpers/cpu/sequence_mask.cpp index bf3463afe..c175fd96d 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/sequence_mask.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/sequence_mask.cpp @@ -38,10 +38,10 @@ namespace helpers { } void sequenceMask(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int maxIndex) { - BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sequenceMask_, (input, output, maxIndex), INTEGER_TYPES, BOOL_TYPES); + BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sequenceMask_, (input, output, maxIndex), INTEGER_TYPES, LIBND4J_TYPES_EXTENDED); } - BUILD_DOUBLE_TEMPLATE(template void sequenceMask_, (NDArray* input, NDArray* output, int maxIndex), INTEGER_TYPES, BOOL_TYPES); + BUILD_DOUBLE_TEMPLATE(template void sequenceMask_, (NDArray* input, NDArray* output, int maxIndex), INTEGER_TYPES, LIBND4J_TYPES_EXTENDED); } } } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/solve.cpp b/libnd4j/include/ops/declarable/helpers/cpu/solve.cpp index 8583d9cba..48f7f0d9a 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/solve.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/solve.cpp @@ -36,10 +36,12 @@ namespace helpers { static void adjointMatrix_(nd4j::LaunchContext* context, NDArray const* input, NDArray* output) { auto inputPart = input->allTensorsAlongDimension({-2, -1}); auto outputPart = output->allTensorsAlongDimension({-2, -1}); + auto rows = input->sizeAt(-2); output->assign(input); + auto batchLoop = PRAGMA_THREADS_FOR { for (auto batch = start; batch < stop; batch += increment) { - for (auto r = 0; r < input->rows(); r++) { + for (auto r = 0; r < rows; r++) { for (auto c = 0; c < r; c++) { math::nd4j_swap(outputPart[batch]->t(r, c) , outputPart[batch]->t(c, r)); } diff --git a/libnd4j/include/ops/declarable/helpers/cpu/triangular_solve.cpp b/libnd4j/include/ops/declarable/helpers/cpu/triangular_solve.cpp index e904d219c..ceb228439 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/triangular_solve.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/triangular_solve.cpp @@ -108,17 +108,20 @@ namespace helpers { static void adjointTriangularMatrix_(nd4j::LaunchContext* context, NDArray const* input, bool const lower, NDArray* output) { auto inputPart = input->allTensorsAlongDimension({-2, -1}); auto outputPart = output->allTensorsAlongDimension({-2, -1}); + auto cols = input->sizeAt(-1); + auto rows = input->sizeAt(-2); + auto batchLoop = PRAGMA_THREADS_FOR { for (auto batch = start; batch < stop; batch += increment) { if (!lower) { - for (auto r = 0; r < input->rows(); r++) { + for (auto r = 0; r < rows; r++) { for (auto c = 0; c <= r; c++) { outputPart[batch]->t(r, c) = inputPart[batch]->t(c, r); } } } else { - for (auto r = 0; r < input->rows(); r++) { - for (auto c = r; c < input->columns(); c++) { + for (auto r = 0; r < rows; r++) { + for (auto c = r; c < cols; c++) { outputPart[batch]->t(r, c) = inputPart[batch]->t(c, r); } } diff --git a/libnd4j/include/ops/declarable/helpers/cuda/sequence_mask.cu b/libnd4j/include/ops/declarable/helpers/cuda/sequence_mask.cu index c07db1b95..6b33a384e 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/sequence_mask.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/sequence_mask.cu @@ -55,10 +55,10 @@ namespace helpers { } void sequenceMask(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int maxIndex) { - BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sequenceMask_, (context, input, output, maxIndex), INTEGER_TYPES, BOOL_TYPES); + BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sequenceMask_, (context, input, output, maxIndex), INTEGER_TYPES, LIBND4J_TYPES_EXTENDED); } - BUILD_DOUBLE_TEMPLATE(template void sequenceMask_, (nd4j::LaunchContext* context, NDArray* input, NDArray* output, int maxIndex), INTEGER_TYPES, BOOL_TYPES); + BUILD_DOUBLE_TEMPLATE(template void sequenceMask_, (nd4j::LaunchContext* context, NDArray* input, NDArray* output, int maxIndex), INTEGER_TYPES, LIBND4J_TYPES_EXTENDED); } } } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp index de4bdc31b..465703768 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests11.cpp @@ -1667,6 +1667,241 @@ TEST_F(DeclarableOpsTests11, Solve_Test_4) { ASSERT_TRUE(exp.equalsTo(z)); delete res; } +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Solve_Test_4_1) { + + auto a = NDArrayFactory::create('c', {2, 2, 2}, { + 0.7788f, 0.8012f, 0.7244f, 0.2309f, + 0.7271f, 0.1804f, 0.5056f, 0.8925f + }); + + auto b = NDArrayFactory::create('c', {2, 2, 2}, { + 0.7717f, 0.9281f, 0.9846f, 0.4838f, 0.6433f, 0.6041f, 0.6501f, 0.7612f + }); + + auto exp = NDArrayFactory::create('c', {2, 2, 2}, { + 1.3357621f, 0.3399364f, -0.37077796f, 0.91573375f, + 0.4400987f, 0.2766527f, 0.6394467f, 0.79696566f + }); + + nd4j::ops::solve op; + + auto res = op.evaluate({&a, &b}, {true}); + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + +// z->printBuffer("4 Solve 4x4"); +// exp.printBuffer("4 Expec 4x4"); + + ASSERT_TRUE(exp.equalsTo(z)); + delete res; +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Solve_Test_4_2) { + + auto a = NDArrayFactory::create('c', {3, 3}, { + 0.7788f, 0.8012f, 0.7244f, + 0.2309f, 0.7271f, 0.1804f, + 0.5056f, 0.8925f, 0.5461f + }); + + auto b = NDArrayFactory::create('c', {3, 3}, { + 0.7717f, 0.9281f, 0.9846f, + 0.4838f, 0.6433f, 0.6041f, + 0.6501f, 0.7612f, 0.7605f + }); + + auto exp = NDArrayFactory::create('c', {3, 3}, { + 0.99088347f, 1.1917052f, 1.2642528f, + 0.35071516f, 0.50630623f, 0.42935497f, + -0.30013534f, -0.53690606f, -0.47959247f + }); + + nd4j::ops::triangular_solve op; + + auto res = op.evaluate({&a, &b}, {true, false}); + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + +// z->printBuffer("4_2 Triangular_Solve 3x3"); +// exp.printBuffer("4_2 Triangular_Expec 3x3"); + + ASSERT_TRUE(exp.equalsTo(z)); + delete res; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Solve_Test_4_3) { + + auto a = NDArrayFactory::create('c', {3, 3}, { + 0.7788f, 0.8012f, 0.7244f, + 0.2309f, 0.7271f, 0.1804f, + 0.5056f, 0.8925f, 0.5461f + }); + + auto b = NDArrayFactory::create('c', {3, 3}, { + 0.7717f, 0.9281f, 0.9846f, + 0.4838f, 0.6433f, 0.6041f, + 0.6501f, 0.7612f, 0.7605f + }); + + auto exp = NDArrayFactory::create('c', {3, 3}, { + 0.45400196f, 0.53174824f, 0.62064564f, + -0.79585856f, -0.82621557f, -0.87855506f, + 1.1904413f, 1.3938838f, 1.3926021f + }); + + nd4j::ops::triangular_solve op; + + auto res = op.evaluate({&a, &b}, {true, true}); + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + +// z->printBuffer("4_3 Triangular_Solve 3x3"); +// exp.printBuffer("4_3 Triangular_Expec 3x3"); + + ASSERT_TRUE(exp.equalsTo(z)); + delete res; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Solve_Test_4_4) { + + auto a = NDArrayFactory::create('c', {3, 3}, { + 0.7788f, 0.8012f, 0.7244f, + 0.2309f, 0.7271f, 0.1804f, + 0.5056f, 0.8925f, 0.5461f + }); + + auto b = NDArrayFactory::create('c', {3, 3}, { + 0.7717f, 0.9281f, 0.9846f, + 0.4838f, 0.6433f, 0.6041f, + 0.6501f, 0.7612f, 0.7605f + }); + + auto exp = NDArrayFactory::create('c', {3, 3}, { + 0.8959121f, 1.6109066f, 1.7501404f, + 0.49000582f, 0.66842675f, 0.5577021f, + -0.4398522f, -1.1899745f, -1.1392052f + }); + + nd4j::ops::solve op; + + auto res = op.evaluate({&a, &b}, {false}); + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + +// z->printBuffer("4_4 Solve 3x3"); +// exp.printBuffer("4_4 Expec 3x3"); + + ASSERT_TRUE(exp.equalsTo(z)); + delete res; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Solve_Test_4_5) { + + auto a = NDArrayFactory::create('c', {3, 3}, { + 0.7788f, 0.8012f, 0.7244f, + 0.2309f, 0.7271f, 0.1804f, + 0.5056f, 0.8925f, 0.5461f + }); + + auto b = NDArrayFactory::create('c', {3, 3}, { + 0.7717f, 0.9281f, 0.9846f, + 0.4838f, 0.6433f, 0.6041f, + 0.6501f, 0.7612f, 0.7605f + }); + + auto exp = NDArrayFactory::create('c', {3, 3}, { + 1.5504692f, 1.8953944f, 2.2765768f, + 0.03399149f, 0.2883001f, 0.5377323f, + -0.8774802f, -1.2155888f, -1.8049058f + }); + + nd4j::ops::solve op; + + auto res = op.evaluate({&a, &b}, {true, true}); + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + +// z->printBuffer("4_5 Solve 3x3"); +// exp.printBuffer("4_5 Expec 3x3"); + + ASSERT_TRUE(exp.equalsTo(z)); + delete res; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Solve_Test_4_6) { + + auto a = NDArrayFactory::create('c', {3, 3}, { + 0.7788f, 0.8012f, 0.7244f, + 0.2309f, 0.7271f, 0.1804f, + 0.5056f, 0.8925f, 0.5461f + }); + + auto b = NDArrayFactory::create('c', {3, 3}, { + 0.7717f, 0.9281f, 0.9846f, + 0.4838f, 0.6433f, 0.6041f, + 0.6501f, 0.7612f, 0.7605f + }); + + auto exp = NDArrayFactory::create('c', {3, 3}, { + 0.99088347f, 1.1917052f, 1.2642528f, + -0.426483f, -0.42840624f, -0.5622601f, + 0.01692283f, -0.04538865f, -0.09868701f + }); + + nd4j::ops::triangular_solve op; + + auto res = op.evaluate({&a, &b}, {false, true}); + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + + z->printBuffer("4_6 Solve 3x3"); + exp.printBuffer("4_6 Expec 3x3"); + + ASSERT_TRUE(exp.equalsTo(z)); + delete res; +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Solve_Test_4_7) { + + auto a = NDArrayFactory::create('c', {3, 3}, { +// 0.7788f, 0.2309f, 0.5056f, +// 0.8012f, 0.7271f, 0.8925f, +// 0.7244f, 0.1804f, 0.5461f + + 0.7788f, 0.2309f, 0.5056f, + 0.8012f, 0.7271f, 0.8925f, + 0.7244f, 0.1804f, 0.5461f + }); + + auto b = NDArrayFactory::create('c', {3, 3}, { + 0.7717f, 0.9281f, 0.9846f, + 0.4838f, 0.6433f, 0.6041f, + 0.6501f, 0.7612f, 0.7605f + }); + + auto exp = NDArrayFactory::create('c', {3, 3}, { + 0.99088347f, 1.1917052f, 1.2642528f, + -0.426483f, -0.42840624f, -0.5622601f, + 0.01692283f, -0.04538865f, -0.09868701f + }); + + nd4j::ops::triangular_solve op; + + auto res = op.evaluate({&a, &b}, {true, false}); + ASSERT_EQ(res->status(), ND4J_STATUS_OK); + auto z = res->at(0); + + z->printBuffer("4_7 Solve 3x3"); + exp.printBuffer("4_7 Expec 3x3"); + + ASSERT_TRUE(exp.equalsTo(z)); + delete res; +} //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests11, Solve_Test_5) { diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp index 39761ecb3..0a6f8e5e8 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests7.cpp @@ -758,7 +758,7 @@ TEST_F(DeclarableOpsTests7, Test_Dynamic_Partition_119_2) { TEST_F(DeclarableOpsTests7, Test_SequenceMask_1) { auto input = NDArrayFactory::create('c', {4, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); - auto exp = NDArrayFactory::create('c', {4, 4, 16}, {1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + auto exp = NDArrayFactory::create('c', {4, 4, 16}, {1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, @@ -802,6 +802,66 @@ TEST_F(DeclarableOpsTests7, Test_SequenceMask_2) { delete result; } +TEST_F(DeclarableOpsTests7, Test_SequenceMask_3) { + auto input = NDArrayFactory::create('c', {2, 2, 2}, {10, 20, 30, 4, 0, 6, 7, 8}); + auto exp = NDArrayFactory::create('c', {2, 2, 2, 30}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); + + nd4j::ops::sequence_mask op; + auto result = op.evaluate({&input}, {nd4j::DataType::INT32}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); +// z->printBuffer("Output"); +// z->printShapeInfo("Shape"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete result; +} + +TEST_F(DeclarableOpsTests7, Test_SequenceMask_4) { + auto input = NDArrayFactory::create({1, 3, 2}); + auto maxLen = NDArrayFactory::create(5); + auto exp = NDArrayFactory::create('c', {3,5}, { + 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 1.f, 0.f, 0.f, 1.f, 1.f, 0.f, 0.f, 0.f + }); + + nd4j::ops::sequence_mask op; + auto result = op.evaluate({&input, &maxLen}, {nd4j::DataType::FLOAT32}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); +// z->printBuffer("Output"); +// z->printShapeInfo("Shape"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete result; +} + +TEST_F(DeclarableOpsTests7, Test_SequenceMask_5) { + auto input = NDArrayFactory::create({1, 3, 2}); + auto exp = NDArrayFactory::create('c', {3,5}, { + 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 1.f, 0.f, 0.f, 1.f, 1.f, 0.f, 0.f, 0.f + }); + + nd4j::ops::sequence_mask op; + auto result = op.evaluate({&input}, {5, (int)nd4j::DataType::FLOAT32}); + ASSERT_EQ(Status::OK(), result->status()); + + auto z = result->at(0); +// z->printBuffer("Output"); +// z->printShapeInfo("Shape"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete result; +} + //////////////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests7, TestSegmentMax_1) { auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); From 948646b32de6819be3e561e4fe01911908943fff Mon Sep 17 00:00:00 2001 From: Yurii Shyrma Date: Thu, 6 Feb 2020 20:12:54 +0200 Subject: [PATCH 5/9] Shyrma mkl test (#211) * - provide nhwc format in mkl conv ops Signed-off-by: Yurii * - corrections in mkl conv3d Signed-off-by: Yurii * - corrections in mkl batchnorm Signed-off-by: Yurii * - corrections in mkl maxpooling2d Signed-off-by: Yurii * - add format format_tag::any to outputs in mkl conv ops Signed-off-by: Yurii * - complete corrections in mkl conv ops Signed-off-by: Yurii * - add test for comparison of execution speeds of mkl conv2d op with different weights format Signed-off-by: Yurii * - take into account order f in mkl conv ops Signed-off-by: Yurii --- .../generic/nn/pooling/maxpool3d.cpp | 8 +- .../declarable/platform/cudnn/cudnnUtils.cu | 2 +- .../platform/mkldnn/avgpooling2d.cpp | 186 ++------ .../platform/mkldnn/avgpooling3d.cpp | 200 +++++---- .../platform/mkldnn/avgpooling3d_bp.cpp | 154 ------- .../declarable/platform/mkldnn/batchnorm.cpp | 199 +++++---- .../ops/declarable/platform/mkldnn/conv2d.cpp | 390 +++++++++++++++-- .../ops/declarable/platform/mkldnn/conv3d.cpp | 326 +++++++++++++- .../declarable/platform/mkldnn/deconv2d.cpp | 167 +++---- .../platform/mkldnn/deconv2d_tf.cpp | 74 ++-- .../declarable/platform/mkldnn/deconv3d.cpp | 183 ++++---- .../platform/mkldnn/depthwiseConv2d.cpp | 89 ++-- .../platform/mkldnn/maxpooling2d.cpp | 223 ++-------- .../platform/mkldnn/maxpooling3d.cpp | 223 ++-------- .../platform/mkldnn/mkldnnUtils.cpp | 411 ++++++++++++++---- .../declarable/platform/mkldnn/mkldnnUtils.h | 22 +- .../layers_tests/DeclarableOpsTests1.cpp | 115 ----- .../layers_tests/DeclarableOpsTests4.cpp | 155 ++++++- .../layers_tests/PlaygroundTests.cpp | 52 +-- 19 files changed, 1737 insertions(+), 1442 deletions(-) delete mode 100644 libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d_bp.cpp diff --git a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool3d.cpp b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool3d.cpp index b82d5306a..be905e22f 100644 --- a/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool3d.cpp +++ b/libnd4j/include/ops/declarable/generic/nn/pooling/maxpool3d.cpp @@ -169,8 +169,8 @@ CUSTOM_OP_IMPL(maxpool3dnew_bp, 2, 1, false, 0, 14) { // int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW - REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3D_BP op: input should have rank of 5, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); + REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3DNEW_BP op: input should have rank of 5, but got %i instead", input->rankOf()); + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW_BP op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes @@ -178,8 +178,8 @@ CUSTOM_OP_IMPL(maxpool3dnew_bp, 2, 1, false, 0, 14) { std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2})); std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2})); - REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "MAXPOOL3D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "MAXPOOL3D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str()); + REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "MAXPOOL3DNEW_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "MAXPOOL3DNEW_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str()); if(!isNCDHW) { input = new NDArray(input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] diff --git a/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.cu b/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.cu index fa7b1ecfa..02a302e61 100644 --- a/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.cu +++ b/libnd4j/include/ops/declarable/platform/cudnn/cudnnUtils.cu @@ -250,7 +250,7 @@ void pooling3dCUDNN(const LaunchContext* context, auto handle = reinterpret_cast(context->getCuDnnHandle()); cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream()); if (err != 0) throw nd4j::cuda_exception::build("pooling3dCUDNN: can't set stream for cuDNN", err); -printf("fffffffffff\n"); + const int numDims = 5; int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d.cpp index 1c1e9d6a4..4c8a582f0 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling2d.cpp @@ -17,6 +17,7 @@ // // @author saudet // @author raver119@gmail.com +// @author Yurii Shyrma (iuriish@yahoo.com) // #include @@ -36,103 +37,44 @@ namespace platforms { ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(avgpool2d, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); - REQUIRE_TRUE(input->rankOf() == 4, 0, "Input should have rank of 4, but got %i instead", - input->rankOf()); + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - auto argI = *(block.getIArguments()); - auto output = OUTPUT_VARIABLE(0); const auto kH = INT_ARG(0); const auto kW = INT_ARG(1); const auto sH = INT_ARG(2); const auto sW = INT_ARG(3); - int pH = INT_ARG(4); - int pW = INT_ARG(5); + auto pH = INT_ARG(4); + auto pW = INT_ARG(5); const auto dH = INT_ARG(6); const auto dW = INT_ARG(7); - const auto isSameMode = static_cast(INT_ARG(8)); + const auto paddingMode = INT_ARG(8); const auto extraParam0 = INT_ARG(9); + const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC - REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}", - dH, dW); + REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D MKLDNN op: input should have rank of 4, but got %i instead", input->rankOf()); + REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D MKLDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW); - int oH = 0; - int oW = 0; + int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC - - const int iH = static_cast(isNCHW ? input->sizeAt(2) : input->sizeAt(1)); - const int iW = static_cast(isNCHW ? input->sizeAt(3) : input->sizeAt(2)); - - if (!isNCHW) { - input = new NDArray( - input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - output = new NDArray( - output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] - } - - ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - - if (isSameMode) + if (paddingMode) ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - const int bS = input->sizeAt(0); - const int iC = input->sizeAt(1); - const int oC = output->sizeAt(1); + auto mode = (extraParam0 == 0) ? algorithm::pooling_avg_exclude_padding : algorithm::pooling_avg_include_padding; - auto poolingMode = PoolingType::AVG_POOL; - - dnnl_memory_desc_t empty; - dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty); - dnnl::memory::desc user_src_md(empty), user_dst_md(empty); - dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; - dnnl::algorithm algorithm; - mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0, - true, - bS, iC, iH, iW, oC, oH, oW, input, nullptr, output, - algorithm, - &pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr, - &user_dst_md, - pool_strides, pool_kernel, pool_padding, pool_padding_r); - auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md, - pool_dst_md, - pool_strides, pool_kernel, pool_padding, pool_padding_r); - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); - auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer()); - auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer()); - auto pool_src_memory = user_src_memory; - dnnl::stream stream(engine); - if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) { - pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine); - reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory); - } - auto pool_dst_memory = user_dst_memory; - if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { - pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine); - } - pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory}, - {DNNL_ARG_DST, pool_dst_memory}}); - if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { - reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory); - } - stream.wait(); - - //streams[0].submitAndWait(); - - if (!isNCHW) { - delete input; - delete output; - } + mkldnnUtils::poolingMKLDNN(input, output, 0,kH,kW, 0,sH,sW, 0,pH,pW, isNCHW, mode); return Status::OK(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(avgpool2d, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); @@ -141,12 +83,10 @@ PLATFORM_CHECK(avgpool2d, ENGINE_CPU) { ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(avgpool2d_bp, ENGINE_CPU) { - auto input = INPUT_VARIABLE( - 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto gradO = INPUT_VARIABLE( - 1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next - auto gradI = OUTPUT_VARIABLE( - 0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon + + auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto gradO = INPUT_VARIABLE(1); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon int kH = INT_ARG(0); // filter(kernel) height int kW = INT_ARG(1); // filter(kernel) width @@ -156,92 +96,26 @@ PLATFORM_IMPL(avgpool2d_bp, ENGINE_CPU) { int pW = INT_ARG(5); // paddings width int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME int extraParam0 = INT_ARG(9); - int isNCHW = - block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC - REQUIRE_TRUE(input->rankOf() == 4, 0, - "AVGPOOL2D_BP op: input should have rank of 4, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dH != 0 && dW != 0, 0, - "AVGPOOL2D_BP op: dilation must not be zero, but got instead {%i, %i}", dH, dW); + REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D_BP MKLDNN op: input should have rank of 4, but got %i instead", input->rankOf()); + REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D_BP MKLDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW); int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, - indIiH, indWiC, indWoC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - std::string expectedGradOShape = ShapeUtils::shapeAsString( - ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1})); - std::string expectedGradIShape = ShapeUtils::shapeAsString( - ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, iH, iW, 0, indIOioC, indIiH, indIiH + 1})); - REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, - "AVGPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", - expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, - "AVGPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", - expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str()); + std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1}); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL2D_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - - if (!isNCHW) { - input = new NDArray(input->permute( - {0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - gradI = new NDArray(gradI->permute( - {0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - gradO = new NDArray(gradO->permute( - {0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] - } - - if (isSameMode) // SAME + if(paddingMode) // SAME ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - auto poolingMode = PoolingType::AVG_POOL; - - dnnl_memory_desc_t empty; - dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty); - dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty); - dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; - dnnl::algorithm algorithm; - mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0, - true, - bS, iC, iH, iW, oC, oH, oW, input, gradI, gradO, algorithm, - &pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md, - &user_diff_src_md, &user_dst_md, - pool_strides, pool_kernel, pool_padding, pool_padding_r); - auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, - input->buffer() != nullptr ? pool_src_md : pool_diff_src_md, - pool_dst_md, pool_strides, pool_kernel, pool_padding, - pool_padding_r); - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); - auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides, - pool_kernel, pool_padding, pool_padding_r); - auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc); - auto userB_src_memory = dnnl::memory(user_src_md, engine, gradI->buffer()); - auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer()); - auto poolB_src_memory = userB_src_memory; - dnnl::stream stream(engine); - if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { - poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine); - } - auto poolB_dst_memory = userB_dst_memory; - if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) { - poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine); - reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory); - } - pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory}, - {DNNL_ARG_DIFF_SRC, poolB_src_memory}}); - if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { - reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory); - } - stream.wait(); - - if (!isNCHW) { - delete input; - delete gradI; - delete gradO; - } + auto mode = (extraParam0 == 0) ? algorithm::pooling_avg_exclude_padding : algorithm::pooling_avg_include_padding; + mkldnnUtils::poolingBpMKLDNN(input, gradO, gradI, 0,kH,kW, 0,sH,sW, 0,pH,pW, isNCHW, mode); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d.cpp index 2456625ef..39e85de98 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d.cpp @@ -17,6 +17,7 @@ // // @author saudet // @author raver119@gmail.com +// @author Yurii Shyrma (iuriish@yahoo.com) // #include @@ -29,113 +30,110 @@ using namespace dnnl; -namespace nd4j { - namespace ops { - namespace platforms { - PLATFORM_IMPL(avgpool3dnew, ENGINE_CPU) { - auto input = INPUT_VARIABLE( - 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto output = OUTPUT_VARIABLE( - 0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) +namespace nd4j { +namespace ops { +namespace platforms { - int kD = INT_ARG(0); // filter(kernel) depth - int kH = INT_ARG(1); // filter(kernel) height - int kW = INT_ARG(2); // filter(kernel) width - int sD = INT_ARG(3); // strides depth - int sH = INT_ARG(4); // strides height - int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - int dD = INT_ARG(9); // dilations depth - int dH = INT_ARG(10); // dilations height - int dW = INT_ARG(11); // dilations width - int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID - int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases - int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW +////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(avgpool3dnew, ENGINE_CPU) { - REQUIRE_TRUE(input->rankOf() == 5, 0, - "MAXPOOL3DNEW OP: rank of input array must be equal to 5, but got %i instead !", - input->rankOf()); - REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, - "MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); + auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, - indIOioC, indIOioD, indWiC, indWoC, indWkD); + int kD = INT_ARG(0); // filter(kernel) depth + int kH = INT_ARG(1); // filter(kernel) height + int kW = INT_ARG(2); // filter(kernel) width + int sD = INT_ARG(3); // strides depth + int sH = INT_ARG(4); // strides height + int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + int dD = INT_ARG(9); // dilations depth + int dH = INT_ARG(10); // dilations height + int dW = INT_ARG(11); // dilations width + int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID + int extraParam0 = INT_ARG(13); + int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC - std::string expectedOutputShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx( - {bS, iC, oD, oH, oW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2})); - REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0, - "MAXPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !", - expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str()); - // REQUIRE_TRUE(iD >= kD && iH >= kH && iW >= kW, 0, "MAXPOOL3D OP: the input depth/height/width must be greater or equal to kernel(filter) depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", iD,iH,iW, kD,kH,kW); - // REQUIRE_TRUE(kD/2 >= pD && kH/2 >= pH && kW/2 >= pW, 0, "MAXPOOL3D OP: pad depth/height/width must not be greater than half of kernel depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", pD,pH,pW, kD,kH,kW); + REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW MKLDNN OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); - if (!isNCDHW) { - input = new NDArray( - input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] - output = new NDArray( - output->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW] - } + int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - if (isSameMode) // SAME - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); + if(paddingMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); + + auto mode = (extraParam0 == 0) ? algorithm::pooling_avg_exclude_padding : algorithm::pooling_avg_include_padding; + + mkldnnUtils::poolingMKLDNN(input, output, kD,kH,kW, sD,sH,sW, pD,pH,pW, isNCDHW, mode); + + return Status::OK(); +} + +////////////////////////////////////////////////////////////////////// +PLATFORM_CHECK(avgpool3dnew, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output}); +} + +////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(avgpool3dnew_bp, ENGINE_CPU) { + + auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next + auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon + + const int kD = INT_ARG(0); // filter(kernel) depth + const int kH = INT_ARG(1); // filter(kernel) height + const int kW = INT_ARG(2); // filter(kernel) width + const int sD = INT_ARG(3); // strides depth + const int sH = INT_ARG(4); // strides height + const int sW = INT_ARG(5); // strides width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width + const int dD = INT_ARG(9); // dilations depth + const int dH = INT_ARG(10); // dilations height + const int dW = INT_ARG(11); // dilations width + const int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID + const int extraParam0 = INT_ARG(13); // define what divisor to use while averaging + const int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 0-NCDHW, 1-NDHWC + + REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW_BP MKLDNN op: input should have rank of 5, but got %i instead", input->rankOf()); + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW_BP MKLDNN op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); + + int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + + std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL3DNEW_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); + + if(paddingMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); + + auto mode = (extraParam0 == 0) ? algorithm::pooling_avg_exclude_padding : algorithm::pooling_avg_include_padding; + + mkldnnUtils::poolingBpMKLDNN(input, gradO, gradI, kD,kH,kW, sD,sH,sW, pD,pH,pW, isNCDHW, mode); + + return Status::OK(); +} + +////////////////////////////////////////////////////////////////////// +PLATFORM_CHECK(avgpool3dnew_bp, ENGINE_CPU) { + + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output}); +} - auto poolingMode = PoolingType::AVG_POOL; - - dnnl_memory_desc_t empty; - dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty); - dnnl::memory::desc user_src_md(empty), user_dst_md(empty); - dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; - dnnl::algorithm algorithm; - mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, - extraParam0, true, - bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, output, - algorithm, - &pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr, - &user_dst_md, - pool_strides, pool_kernel, pool_padding, pool_padding_r); - auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md, - pool_dst_md, - pool_strides, pool_kernel, pool_padding, pool_padding_r); - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - dnnl::stream stream(engine); - auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); - auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer()); - auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer()); - auto pool_src_memory = user_src_memory; - if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) { - pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine); - reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory); - } - auto pool_dst_memory = user_dst_memory; - if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { - pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine); - } - pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory}, - {DNNL_ARG_DST, pool_dst_memory}}); - if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { - reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory); - } - stream.wait(); - - if (!isNCDHW) { - delete input; - delete output; - } - - return Status::OK(); - } - - PLATFORM_CHECK(avgpool3dnew, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output}); - } - } - } +} +} } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d_bp.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d_bp.cpp deleted file mode 100644 index 3fd8ab293..000000000 --- a/libnd4j/include/ops/declarable/platform/mkldnn/avgpooling3d_bp.cpp +++ /dev/null @@ -1,154 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author raver119@gmail.com -// - -#include -#include -#include - -#include -#include "mkldnnUtils.h" -#include - -using namespace dnnl; - -namespace nd4j { - namespace ops { - namespace platforms { - PLATFORM_IMPL(avgpool3dnew_bp, ENGINE_CPU) { - auto input = INPUT_VARIABLE( - 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto gradO = INPUT_VARIABLE( - 1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next - auto gradI = OUTPUT_VARIABLE( - 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon - - const int kD = INT_ARG(0); // filter(kernel) depth - const int kH = INT_ARG(1); // filter(kernel) height - const int kW = INT_ARG(2); // filter(kernel) width - const int sD = INT_ARG(3); // strides depth - const int sH = INT_ARG(4); // strides height - const int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width - const int dD = INT_ARG(9); // dilations depth - const int dH = INT_ARG(10); // dilations height - const int dW = INT_ARG(11); // dilations width - const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID - int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases - int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW - - REQUIRE_TRUE(input->rankOf() == 5, 0, - "MAXPOOL3D_BP op: input should have rank of 5, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, - "MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); - - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; - int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, - indIOioC, indIOioD, indWiC, indWoC, indWkD); - - std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx( - {bS, iC, oD, oH, oW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2})); - std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx( - {bS, iC, iD, iH, iW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2})); - REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, - "MAXPOOL3D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", - expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, - "MAXPOOL3D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", - expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str()); - - if (!isNCDHW) { - input = new NDArray(input->permute( - {0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] - gradI = new NDArray(gradI->permute( - {0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] - gradO = new NDArray(gradO->permute( - {0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW] - } - - if (isSameMode) // SAME - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - - - - auto poolingMode = PoolingType::AVG_POOL; - - dnnl_memory_desc_t empty; - dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty); - dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty); - dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; - dnnl::algorithm algorithm; - mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, - extraParam0, true, - bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, gradO, - algorithm, - &pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md, - &user_diff_src_md, &user_dst_md, - pool_strides, pool_kernel, pool_padding, pool_padding_r); - if (input->buffer() == nullptr) { - pool_src_md = pool_diff_src_md; - user_src_md = user_diff_src_md; - } - auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, pool_src_md, pool_dst_md, - pool_strides, pool_kernel, pool_padding, pool_padding_r); - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - dnnl::stream stream(engine); - auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); - auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides, - pool_kernel, pool_padding, pool_padding_r); - auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc); - auto userB_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer()); - auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer()); - auto poolB_src_memory = userB_src_memory; - if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { - poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine); - } - auto poolB_dst_memory = userB_dst_memory; - if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) { - poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine); - reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory); - } - pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory}, - {DNNL_ARG_DIFF_SRC, poolB_src_memory}}); - if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { - reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory); - } - stream.wait(); - - if (!isNCDHW) { - delete input; - delete gradI; - delete gradO; - } - - return Status::OK(); - } - - PLATFORM_CHECK(avgpool3dnew_bp, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); - auto output = OUTPUT_VARIABLE(0); - - return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output}); - } - } - } -} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp index 0ebee8fbf..f63690e81 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/batchnorm.cpp @@ -37,12 +37,12 @@ namespace platforms { ////////////////////////////////////////////////////////////////////////// -static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray* weights, const float epsilon, NDArray* z) { +static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray* weights, NDArray* z, + const float epsilon, const bool isNCHW) { - // unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any) - // also it gives wrong results for formats nhwc and ndhwc + // unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any) for x - // x -> 2D:nc, 4D:nchw, 5D:ncdhw + // x -> 2D:nc, 4D:nchw/nhwc, 5D:ncdhw/ndhwc // mean -> 1D [c] // variance -> 1D [c] // weights 2D [2, c], weights({0,1, 0,0}) contains gamma and weights({1,2, 0,0}) contains beta @@ -50,8 +50,6 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray const int xRank = x->rankOf(); - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - // input type dnnl::memory::data_type type = dnnl::memory::data_type::f32; @@ -63,17 +61,28 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray dnnl::memory::dims dims; dnnl::memory::format_tag format; + const int indHW = isNCHW ? 2 : 1; + const int bS = x->sizeAt(0); + const int iC = isNCHW ? x->sizeAt(1) : x->sizeAt(-1); + + int iD, iH, iW; + if(xRank == 2) { - dims = {x->sizeAt(0), x->sizeAt(1)}; + dims = {bS, iC}; format = dnnl::memory::format_tag::nc; } else if(xRank == 4) { - dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3)}; - format = dnnl::memory::format_tag::nchw; + iH = x->sizeAt(indHW); + iW = x->sizeAt(indHW + 1); + dims = {bS, iC, iH, iW}; + format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; } else { // xRank = 5 - dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3), x->sizeAt(4)}; - format = dnnl::memory::format_tag::ncdhw; + iD = x->sizeAt(indHW); + iH = x->sizeAt(indHW + 1); + iW = x->sizeAt(indHW + 2); + dims = {bS, iC, iD, iH, iW}; + format = isNCHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; } // memory descriptors for arrays @@ -81,29 +90,34 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray // x dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format); dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format); - x_user_md.data.format_kind = dnnl_blocked; // overrides format - x_user_md.data.format_desc.blocking.strides[0] = x->stridesOf()[0]; - x_user_md.data.format_desc.blocking.strides[1] = x->stridesOf()[1]; - if(xRank > 2) { - x_user_md.data.format_desc.blocking.strides[2] = x->stridesOf()[2]; - x_user_md.data.format_desc.blocking.strides[3] = x->stridesOf()[3]; + if(x->ews() != 1 || x->ordering() != 'c') { + x_user_md.data.format_kind = dnnl_blocked; // overrides format + x_user_md.data.format_desc.blocking.strides[0] = x->strideAt(0); + x_user_md.data.format_desc.blocking.strides[1] = x->strideAt(1); + if(xRank > 2) { + x_user_md.data.format_desc.blocking.strides[2] = x->strideAt(2); + x_user_md.data.format_desc.blocking.strides[3] = x->strideAt(3); + } + if(xRank > 4) + x_user_md.data.format_desc.blocking.strides[4] = x->strideAt(4); } - if(xRank > 4) - x_user_md.data.format_desc.blocking.strides[4] = x->stridesOf()[4]; // z, output - dnnl::memory::desc z_mkl_md = dnnl::memory::desc(dims, type, format); + dnnl::memory::desc z_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any); dnnl::memory::desc z_user_md = dnnl::memory::desc(dims, type, format); - z_user_md.data.format_kind = dnnl_blocked; // overrides format - z_user_md.data.format_desc.blocking.strides[0] = z->stridesOf()[0]; - z_user_md.data.format_desc.blocking.strides[1] = z->stridesOf()[1]; - if(xRank > 2) { - z_user_md.data.format_desc.blocking.strides[2] = z->stridesOf()[2]; - z_user_md.data.format_desc.blocking.strides[3] = z->stridesOf()[3]; + if(z->ews() != 1 || z->ordering() != 'c') { + z_user_md.data.format_kind = dnnl_blocked; // overrides format + z_user_md.data.format_desc.blocking.strides[0] = z->strideAt(0); + z_user_md.data.format_desc.blocking.strides[1] = z->strideAt(1); + if(xRank > 2) { + z_user_md.data.format_desc.blocking.strides[2] = z->strideAt(2); + z_user_md.data.format_desc.blocking.strides[3] = z->strideAt(3); + } + if(xRank > 4) + z_user_md.data.format_desc.blocking.strides[4] = z->strideAt(4); } - if(xRank > 4) - z_user_md.data.format_desc.blocking.strides[4] = z->stridesOf()[4]; + auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); // batchnorm forward description dnnl::batch_normalization_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, x_mkl_md, epsilon, flags); @@ -162,12 +176,11 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray ////////////////////////////////////////////////////////////////////////// static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray* dLdO, const NDArray* weights, - const float epsilon, NDArray* dLdI, NDArray* dLdW) { + NDArray* dLdI, NDArray* dLdW, const float epsilon, const bool isNCHW) { - // unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any) - // also it gives wrong results for formats nhwc and ndhwc + // unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any) for x - // x -> 2D:nc, 4D:nchw, 5D:ncdhw + // x -> 2D:nc, 4D:nchw/nhwc, 5D:ncdhw/ndhwc // mean -> 1D [c] // variance -> 1D [c] // dLdO - same shape as x @@ -177,8 +190,6 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const const int xRank = x->rankOf(); - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - // input type dnnl::memory::data_type type = dnnl::memory::data_type::f32; @@ -190,17 +201,28 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const dnnl::memory::dims dims; dnnl::memory::format_tag format; + const int indHW = isNCHW ? 2 : 1; + const int bS = x->sizeAt(0); + const int iC = isNCHW ? x->sizeAt(1) : x->sizeAt(-1); + + int iD, iH, iW; + if(xRank == 2) { - dims = {x->sizeAt(0), x->sizeAt(1)}; + dims = {bS, iC}; format = dnnl::memory::format_tag::nc; } else if(xRank == 4) { - dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3)}; - format = dnnl::memory::format_tag::nchw; + iH = x->sizeAt(indHW); + iW = x->sizeAt(indHW + 1); + dims = {bS, iC, iH, iW}; + format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; } else { // xRank = 5 - dims = {x->sizeAt(0), x->sizeAt(1), x->sizeAt(2), x->sizeAt(3), x->sizeAt(4)}; - format = dnnl::memory::format_tag::ncdhw; + iD = x->sizeAt(indHW); + iH = x->sizeAt(indHW + 1); + iW = x->sizeAt(indHW + 2); + dims = {bS, iC, iD, iH, iW}; + format = isNCHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; } // memory descriptors for arrays @@ -208,41 +230,49 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const // x dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format); dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format); - x_user_md.data.format_kind = dnnl_blocked; // overrides format - x_user_md.data.format_desc.blocking.strides[0] = x->stridesOf()[0]; - x_user_md.data.format_desc.blocking.strides[1] = x->stridesOf()[1]; - if(xRank > 2) { - x_user_md.data.format_desc.blocking.strides[2] = x->stridesOf()[2]; - x_user_md.data.format_desc.blocking.strides[3] = x->stridesOf()[3]; + if(x->ews() != 1 || x->ordering() != 'c') { + x_user_md.data.format_kind = dnnl_blocked; // overrides format + x_user_md.data.format_desc.blocking.strides[0] = x->strideAt(0); + x_user_md.data.format_desc.blocking.strides[1] = x->strideAt(1); + if(xRank > 2) { + x_user_md.data.format_desc.blocking.strides[2] = x->strideAt(2); + x_user_md.data.format_desc.blocking.strides[3] = x->strideAt(3); + } + if(xRank > 4) + x_user_md.data.format_desc.blocking.strides[4] = x->strideAt(4); } - if(xRank > 4) - x_user_md.data.format_desc.blocking.strides[4] = x->stridesOf()[4]; // dLdO - dnnl::memory::desc dLdO_mkl_md = dnnl::memory::desc(dims, type, format); + dnnl::memory::desc dLdO_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any); dnnl::memory::desc dLdO_user_md = dnnl::memory::desc(dims, type, format); - dLdO_user_md.data.format_kind = dnnl_blocked; // overrides format - dLdO_user_md.data.format_desc.blocking.strides[0] = dLdO->stridesOf()[0]; - dLdO_user_md.data.format_desc.blocking.strides[1] = dLdO->stridesOf()[1]; - if(xRank > 2) { - dLdO_user_md.data.format_desc.blocking.strides[2] = dLdO->stridesOf()[2]; - dLdO_user_md.data.format_desc.blocking.strides[3] = dLdO->stridesOf()[3]; + if(dLdO->ews() != 1 || dLdO->ordering() != 'c') { + dLdO_user_md.data.format_kind = dnnl_blocked; // overrides format + dLdO_user_md.data.format_desc.blocking.strides[0] = dLdO->strideAt(0); + dLdO_user_md.data.format_desc.blocking.strides[1] = dLdO->strideAt(1); + if(xRank > 2) { + dLdO_user_md.data.format_desc.blocking.strides[2] = dLdO->strideAt(2); + dLdO_user_md.data.format_desc.blocking.strides[3] = dLdO->strideAt(3); + } + if(xRank > 4) + dLdO_user_md.data.format_desc.blocking.strides[4] = dLdO->strideAt(4); } - if(xRank > 4) - dLdO_user_md.data.format_desc.blocking.strides[4] = dLdO->stridesOf()[4]; // dLdI - dnnl::memory::desc dLdI_mkl_md = dnnl::memory::desc(dims, type, format); + dnnl::memory::desc dLdI_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any); dnnl::memory::desc dLdI_user_md = dnnl::memory::desc(dims, type, format); - dLdI_user_md.data.format_kind = dnnl_blocked; // overrides format - dLdI_user_md.data.format_desc.blocking.strides[0] = dLdI->stridesOf()[0]; - dLdI_user_md.data.format_desc.blocking.strides[1] = dLdI->stridesOf()[1]; - if(xRank > 2) { - dLdI_user_md.data.format_desc.blocking.strides[2] = dLdI->stridesOf()[2]; - dLdI_user_md.data.format_desc.blocking.strides[3] = dLdI->stridesOf()[3]; + if(dLdI->ews() != 1 || dLdI->ordering() != 'c') { + dLdI_user_md.data.format_kind = dnnl_blocked; // overrides format + dLdI_user_md.data.format_desc.blocking.strides[0] = dLdI->strideAt(0); + dLdI_user_md.data.format_desc.blocking.strides[1] = dLdI->strideAt(1); + if(xRank > 2) { + dLdI_user_md.data.format_desc.blocking.strides[2] = dLdI->strideAt(2); + dLdI_user_md.data.format_desc.blocking.strides[3] = dLdI->strideAt(3); + } + if(xRank > 4) + dLdI_user_md.data.format_desc.blocking.strides[4] = dLdI->strideAt(4); } - if(xRank > 4) - dLdI_user_md.data.format_desc.blocking.strides[4] = dLdI->stridesOf()[4]; + + auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); // batchnorm forward description dnnl::batch_normalization_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, x_mkl_md, epsilon, flags); @@ -331,7 +361,7 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const // dLdI = dfdm / N + (2/N) * dfdv * (dvdm/2 + (x - m)) // dLdI = gamma * ( stdInv * -g_sum/N + (2/N) * dfdv * (dvdm/2 + (x - m)) ) - std::vector axes = {1}; + std::vector axes = isNCHW ? std::vector{1} : std::vector{xRank - 1}; const auto excludedAxes = ShapeUtils::evalDimsToExclude(x->rankOf(), axes); // inversed batch size 1 / N @@ -377,7 +407,7 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const PLATFORM_IMPL(batchnorm, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw, 5D:ncdhw + auto input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw/nhwc, 5D:ncdhw/ndhwc auto mean = INPUT_VARIABLE(1); // [c] auto variance = INPUT_VARIABLE(2); // [c] NDArray* gamma = nullptr; // [c] @@ -436,27 +466,19 @@ PLATFORM_IMPL(batchnorm, ENGINE_CPU) { (*weights)({1,2, 0,0}).assign(0); } - if(axes[0] == inRank - 1 && inRank > 2) { // if nhwc or ndhwc - std::vector permut = inRank == 4 ? std::vector({0,3,1,2}) : std::vector({0,4,1,2,3}); - input = new NDArray(input->permute(permut)); - output = new NDArray(output->permute(permut)); - } + const bool isNCHW = !(axes[0] == inRank - 1 && inRank > 2); - batchnormMKLDNN(input, mean, variance, weights, epsilon, output); + batchnormMKLDNN(input, mean, variance, weights, output, epsilon, isNCHW); delete weights; - if(axes[0] == inRank - 1 && inRank > 2) { - delete input; - delete output; - } - return Status::OK(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(batchnorm, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw, 5D:ncdhw + + auto input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw/nhwc, 5D:ncdhw/ndhwc auto mean = INPUT_VARIABLE(1); // [c] auto variance = INPUT_VARIABLE(2); // [c] NDArray* gamma = nullptr; // [c] @@ -630,7 +652,7 @@ PLATFORM_CHECK(batchnorm, ENGINE_CPU) { ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(batchnorm_bp, ENGINE_CPU) { - NDArray* input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw, 5D:ncdhw + NDArray* input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw/nhwc, 5D:ncdhw/ndhwc NDArray* mean = INPUT_VARIABLE(1); // [c] NDArray* variance = INPUT_VARIABLE(2); // [c] NDArray* gamma = nullptr; // [c] @@ -698,15 +720,9 @@ PLATFORM_IMPL(batchnorm_bp, ENGINE_CPU) { (*weights)({1,2, 0,0}).assign(0); } + const bool isNCHW = !(axes[0] == inRank - 1 && inRank > 2); - if(axes[0] == inRank - 1 && inRank > 2) { // if nhwc or ndhwc - std::vector permut = inRank == 4 ? std::vector({0,3,1,2}) : std::vector({0,4,1,2,3}); - input = new NDArray(input->permute(permut)); - dLdO = new NDArray(dLdO->permute(permut)); - dLdI = new NDArray(dLdI->permute(permut)); - } - - batchnormBackPropMKLDNN(input, mean, variance, dLdO, weights, epsilon, dLdI, dLdW); + batchnormBackPropMKLDNN(input, mean, variance, dLdO, weights, dLdI, dLdW, epsilon, isNCHW); *dLdM = 0; *dLdV = 0; @@ -721,17 +737,12 @@ PLATFORM_IMPL(batchnorm_bp, ENGINE_CPU) { delete dLdW; } - if(axes[0] == inRank - 1 && inRank > 2) { - delete input; - delete dLdO; - delete dLdI; - } - return Status::OK(); } ////////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(batchnorm_bp, ENGINE_CPU) { + NDArray* input = INPUT_VARIABLE(0); // 2D:nc, 4D:nchw, 5D:ncdhw NDArray* mean = INPUT_VARIABLE(1); // [c] NDArray* variance = INPUT_VARIABLE(2); // [c] diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp index 1b90812b1..2d88a73ef 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/conv2d.cpp @@ -17,6 +17,7 @@ // // @author saudet // @author raver119@gmail.com +// @author Yurii Shyrma (iuriish@yahoo.com) // #include @@ -33,6 +34,298 @@ namespace nd4j { namespace ops { namespace platforms { +////////////////////////////////////////////////////////////////////// +static void conv2dMKLDNN(const NDArray *input, const NDArray *weights, + const NDArray *bias, NDArray *output, + const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, + const int paddingMode, const int isNCHW) { + + // weights [kH, kW, iC, oC], we'll perform permutation since mkl support [oC, iC, kH, kW] + + int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + + const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d + + dnnl::memory::dims strides = { sH, sW }; + dnnl::memory::dims padding = { pH, pW }; + dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame }; + dnnl::memory::dims dilation = { dH-1, dW-1}; + + auto xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; + dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw; + + dnnl::memory::dims xDims = {bS, iC, iH, iW}; + dnnl::memory::dims wDims = {oC, iC, kH, kW}; + dnnl::memory::dims zDims = {bS, oC, oH, oW}; + + auto type = dnnl::memory::data_type::f32; + + // memory descriptors for arrays + + // input + dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat); + if(input->ews() != 1 || input->ordering() != 'c') { + x_user_md.data.format_kind = dnnl_blocked; // overrides format + x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0); + x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1); + x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2); + x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3); + } + + // weights + dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormat); + w_user_md.data.format_kind = dnnl_blocked; // overrides format + w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW] + w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(2); + w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0); + w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1); + + // bias + dnnl::memory::desc b_mkl_md; + if(bias != nullptr) + b_mkl_md = dnnl::memory::desc({oC}, type, dnnl::memory::format_tag::x); + + // output + dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFrmat); + if(output->ews() != 1 || output->ordering() != 'c') { + z_user_md.data.format_kind = dnnl_blocked; // overrides format + z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0); + z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(1); + z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(2); + z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(3); + } + + auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // operation primitive description + dnnl::convolution_forward::desc op_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, padding_r); + dnnl::convolution_forward::primitive_desc op_prim_desc(op_desc, engine); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + dnnl::stream stream(engine); + + // provide memory buffers and check whether reorder is required + + // input + auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer()); + const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc(); + auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem; + if (xReorder) + dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); + args[DNNL_ARG_SRC] = x_mkl_mem; + + // weights + auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer()); + const bool wReorder = op_prim_desc.weights_desc() != w_user_mem.get_desc(); + auto w_mkl_mem = wReorder ? dnnl::memory(op_prim_desc.weights_desc(), engine) : w_user_mem; + if (wReorder) + dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem); + args[DNNL_ARG_WEIGHTS] = w_mkl_mem; + + // bias + if(bias != nullptr) { + auto b_mkl_mem = dnnl::memory(b_mkl_md, engine, bias->getBuffer()); + args[DNNL_ARG_BIAS] = b_mkl_mem; + } + + // output + auto z_user_mem = dnnl::memory(z_user_md, engine, output->getBuffer()); + const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc(); + auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem; + args[DNNL_ARG_DST] = z_mkl_mem; + + // run calculations + dnnl::convolution_forward(op_prim_desc).execute(stream, args); + + // reorder outputs if necessary + if (zReorder) + dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); + + stream.wait(); + // shape::printArray(z_mkl_mem.map_data(),8); +} + +////////////////////////////////////////////////////////////////////// +static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const NDArray *bias, const NDArray *gradO, + NDArray *gradI, NDArray *gradW, NDArray *gradB, + const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, + const int paddingMode, const int isNCHW) { + + // weights/gradW [kH, kW, iC, oC], we'll perform permutation since mkl support [oC, iC, kH, kW] + + int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + + const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d + + dnnl::memory::dims strides = { sH, sW }; + dnnl::memory::dims padding = { pH, pW }; + dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame }; + dnnl::memory::dims dilation = { dH-1, dW-1}; + + auto xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; + dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw; + + dnnl::memory::dims xDims = {bS, iC, iH, iW}; + dnnl::memory::dims wDims = {oC, iC, kH, kW}; + dnnl::memory::dims zDims = {bS, oC, oH, oW}; + + auto type = dnnl::memory::data_type::f32; + + // memory descriptors for arrays + + // input + dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat); + if(input->ews() != 1 || input->ordering() != 'c') { + x_user_md.data.format_kind = dnnl_blocked; // overrides format + x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0); + x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1); + x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2); + x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3); + } + + // weights + dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormat); + w_user_md.data.format_kind = dnnl_blocked; // overrides format + w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW] + w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(2); + w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0); + w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1); + + // gradO + dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFrmat); + if(gradO->ews() != 1 || gradO->ordering() != 'c') { + gradO_user_md.data.format_kind = dnnl_blocked; // overrides format + gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0); + gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(1); + gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(2); + gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(3); + } + + // gradI + dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFrmat); + if(gradI->ews() != 1 || gradI->ordering() != 'c') { + gradI_user_md.data.format_kind = dnnl_blocked; // overrides format + gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0); + gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(1); + gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(2); + gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(3); + } + + // gradW + dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, type, wFormat); + gradW_user_md.data.format_kind = dnnl_blocked; // overrides format + gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(3); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW] + gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(2); + gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(0); + gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(1); + + // gradB + dnnl::memory::desc gradB_mkl_md; + if(gradB != nullptr) + gradB_mkl_md = dnnl::memory::desc({oC}, type, dnnl::memory::format_tag::x); + + auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // forward primitive description + dnnl::convolution_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); + dnnl::convolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); + + // backward data primitive description + dnnl::convolution_backward_data::desc op_data_bp_desc(dnnl::algorithm::convolution_auto, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); + dnnl::convolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc); + + // backward weights primitive description + dnnl::convolution_backward_weights::desc op_weights_bp_desc(dnnl::algorithm::convolution_auto, x_mkl_md, gradW_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); + dnnl::convolution_backward_weights::primitive_desc op_weights_bp_prim_desc(op_weights_bp_desc, engine, op_ff_prim_desc); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + dnnl::stream stream(engine); + + // provide memory buffers and check whether reorder is required + + // input + auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer()); + const bool xReorder = op_weights_bp_prim_desc.src_desc() != x_user_mem.get_desc(); + auto x_mkl_mem = xReorder ? dnnl::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem; + if (xReorder) + dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); + args[DNNL_ARG_SRC] = x_mkl_mem; + + // weights + auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer()); + const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc(); + auto w_mkl_mem = wReorder ? dnnl::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem; + if (wReorder) + dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem); + args[DNNL_ARG_WEIGHTS] = w_mkl_mem; + + // gradO + auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer()); + const bool gradOReorderW = op_weights_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); + const bool gradOReorderD = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); + auto gradO_mkl_memW = gradOReorderW ? dnnl::memory(op_weights_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; + auto gradO_mkl_memD = gradOReorderD ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; + if (gradOReorderW) + dnnl::reorder(gradO_user_mem, gradO_mkl_memW).execute(stream, gradO_user_mem, gradO_mkl_memW); + if (gradOReorderD) + dnnl::reorder(gradO_user_mem, gradO_mkl_memD).execute(stream, gradO_user_mem, gradO_mkl_memD); + args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD; + + // gradI + auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer()); + const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc(); + auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem; + args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem; + + // gradW + auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->getBuffer()); + const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc(); + auto gradW_mkl_mem = gradWReorder ? dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem; + args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem; + + // gradB + if(gradB != nullptr) { + auto gradB_mkl_mem = dnnl::memory(gradB_mkl_md, engine, gradB->getBuffer()); + args[DNNL_ARG_DIFF_BIAS] = gradB_mkl_mem; + } + + // run backward data calculations + dnnl::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args); + + if(gradOReorderW || gradOReorderD) + args[DNNL_ARG_DIFF_DST] = gradO_mkl_memW; + + // run backward weights calculations + dnnl::convolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args); + + // reorder gradI if necessary + if (gradIReorder) + dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem); + if (gradWReorder) + dnnl::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem); + + stream.wait(); + + // shape::printArray(z_mkl_mem.map_data(),8); +} + +/* ////////////////////////////////////////////////////////////////////// static void conv2dMKLDNN(nd4j::graph::Context &block, const NDArray *input, const NDArray *weights, const NDArray *bias, NDArray *output, const int kH, const int kW, const int sH, @@ -46,37 +339,37 @@ static void conv2dMKLDNN(nd4j::graph::Context &block, const NDArray *input, cons ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); dnnl_memory_desc_t empty; - dnnl::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md(empty); - dnnl::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(empty); + dnnl::memory::desc x_mkl_md(empty), w_mkl_md(empty), b_mkl_md(empty), z_mkl_md(empty); + dnnl::memory::desc x_user_md(empty), w_user_md(empty), b_user_md(empty), z_user_md(empty); - dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation; + dnnl::memory::dims strides, padding, padding_r, dilation; mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, bS, iC, iH, iW, oC, oH, oW, input, nullptr, weights, nullptr, bias, output, - &conv_src_md, nullptr, &conv_weights_md, nullptr, - &conv_bias_md, &conv_dst_md, - &user_src_md, nullptr, &user_weights_md, nullptr, - &user_bias_md, &user_dst_md, - conv_strides, conv_padding, conv_padding_r, conv_dilation); + &x_mkl_md, nullptr, &w_mkl_md, nullptr, + &b_mkl_md, &z_mkl_md, + &x_user_md, nullptr, &w_user_md, nullptr, + &b_user_md, &z_user_md, + strides, padding, padding_r, dilation); auto conv_desc = bias != nullptr ? convolution_forward::desc(prop_kind::forward, - algorithm::convolution_auto, conv_src_md, - conv_weights_md, conv_bias_md, - conv_dst_md, conv_strides, conv_dilation, conv_padding, - conv_padding_r) + algorithm::convolution_auto, x_mkl_md, + w_mkl_md, b_mkl_md, + z_mkl_md, strides, dilation, padding, + padding_r) : convolution_forward::desc(prop_kind::forward, - algorithm::convolution_auto, conv_src_md, - conv_weights_md, - conv_dst_md, conv_strides, conv_dilation, conv_padding, - conv_padding_r); + algorithm::convolution_auto, x_mkl_md, + w_mkl_md, + z_mkl_md, strides, dilation, padding, + padding_r); auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); dnnl::stream stream(engine); auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, engine); - auto user_src_memory = dnnl::memory(user_src_md, engine, const_cast(input)->buffer()); - auto user_weights_memory = dnnl::memory(user_weights_md, engine, + auto user_src_memory = dnnl::memory(x_user_md, engine, const_cast(input)->buffer()); + auto user_weights_memory = dnnl::memory(w_user_md, engine, const_cast(weights)->buffer()); - auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer()); + auto user_dst_memory = dnnl::memory(z_user_md, engine, output->buffer()); auto conv_src_memory = user_src_memory; if (conv_prim_desc.src_desc() != user_src_memory.get_desc()) { conv_src_memory = dnnl::memory(conv_prim_desc.src_desc(), engine); @@ -239,13 +532,16 @@ static void conv2dBpMKLDNN(nd4j::graph::Context &block, } } +*/ + ////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(conv2d, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) - auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always - auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) + auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always + auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC] + + auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW) int sH = INT_ARG(2); // strides height int sW = INT_ARG(3); // strides width @@ -254,16 +550,28 @@ PLATFORM_IMPL(conv2d, ENGINE_CPU) { int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - bool isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + bool isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast(weights->sizeAt(0)); // filter(kernel) height int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast(weights->sizeAt(1)); // filter(kernel) width - conv2dMKLDNN(block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW); + int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); + + std::vector expectedWeightsShape = {kH, kW, iC, oC}; + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CONV2D MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); + if (bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CONV2D MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); + + conv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW); return Status::OK(); } + PLATFORM_CHECK(conv2d, ENGINE_CPU) { auto input = INPUT_VARIABLE(0); auto weights = INPUT_VARIABLE(1); @@ -276,10 +584,10 @@ PLATFORM_CHECK(conv2d, ENGINE_CPU) { ////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(conv2d_bp, ENGINE_CPU) { - auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) + auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always - auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] - auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next + auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] + auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, oC] always @@ -293,19 +601,33 @@ PLATFORM_IMPL(conv2d_bp, ENGINE_CPU) { int pW = INT_ARG(5); // paddings width int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width - int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME - int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC - REQUIRE_TRUE(input->rankOf() == 4, 0,"CUSTOM CONV2D_BP OP: rank of input array must be equal to 4, but got %i instead !",input->rankOf()); - REQUIRE_TRUE(weights->rankOf() == 4, 0,"CUSTOM CONV2D_BP OP: rank of weights array must be equal to 4, but got %i instead !",weights->rankOf()); - REQUIRE_TRUE(gradO->rankOf() == 4, 0,"CUSTOM CONV2D_BP OP: rank of output's gradients (next epsilon) array must be equal to 4, but got %i instead !",gradO->rankOf()); + int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - conv2dBpMKLDNN(block, input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW); + int trueoH, trueoW; // true output height, width + ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode); + + if(paddingMode) // SAME + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode); + + std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); + std::vector expectedWeightsShape = {kH, kW, iC, oC}; + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CONV2D_BP MKLDNN OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CONV2D_BP MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); + if(bias) + REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CONV2D_BP MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); + + conv2dBpMKLDNN(input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW); return Status::OK(); } PLATFORM_CHECK(conv2d_bp, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW) auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp index 096839d79..7c10b0d1e 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/conv3d.cpp @@ -33,6 +33,314 @@ namespace nd4j { namespace ops { namespace platforms { +////////////////////////////////////////////////////////////////////// +static void conv3dMKLDNN(const NDArray *input, const NDArray *weights, + const NDArray *bias, NDArray *output, + const int kD, const int kH, const int kW, + const int sD, const int sH, const int sW, + const int pD, const int pH, const int pW, + const int dD, const int dH, const int dW, + const int paddingMode, const int isNCDHW) { + + // weights [kD, kH, kW, iC, oC], we'll perform permutation since mkl support [oC, iC, kD, kH, kW] + + int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + + // const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d + + dnnl::memory::dims strides = {sD, sH, sW}; + dnnl::memory::dims padding = {pD, pH, pW}; + // dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame }; + dnnl::memory::dims padding_r = {(oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW}; + dnnl::memory::dims dilation = {dD-1, dH-1, dW-1}; + + auto xzFrmat = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; + dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oidhw; + + dnnl::memory::dims xDims = {bS, iC, iD, iH, iW}; + dnnl::memory::dims wDims = {oC, iC, kD, kH, kW}; + dnnl::memory::dims zDims = {bS, oC, oD, oH, oW}; + + auto type = dnnl::memory::data_type::f32; + + // memory descriptors for arrays + + // input + dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat); + if(input->ews() != 1 || input->ordering() != 'c') { + x_user_md.data.format_kind = dnnl_blocked; // overrides format + x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0); + x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1); + x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2); + x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3); + x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(4); + } + + // weights + dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormat); + w_user_md.data.format_kind = dnnl_blocked; // overrides format + w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(4); // permute [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW] + w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(3); + w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0); + w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1); + w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(2); + + // bias + dnnl::memory::desc b_mkl_md; + if(bias != nullptr) + b_mkl_md = dnnl::memory::desc({oC}, type, dnnl::memory::format_tag::x); + + // output + dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFrmat); + if(output->ews() != 1 || output->ordering() != 'c') { + z_user_md.data.format_kind = dnnl_blocked; // overrides format + z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0); + z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(1); + z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(2); + z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(3); + z_user_md.data.format_desc.blocking.strides[4] = output->strideAt(4); + } + + auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // operation primitive description + dnnl::convolution_forward::desc op_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, padding_r); + dnnl::convolution_forward::primitive_desc op_prim_desc(op_desc, engine); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + dnnl::stream stream(engine); + + // provide memory buffers and check whether reorder is required + + // input + auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer()); + const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc(); + auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem; + if (xReorder) + dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); + args[DNNL_ARG_SRC] = x_mkl_mem; + + // weights + auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer()); + const bool wReorder = op_prim_desc.weights_desc() != w_user_mem.get_desc(); + auto w_mkl_mem = wReorder ? dnnl::memory(op_prim_desc.weights_desc(), engine) : w_user_mem; + if (wReorder) + dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem); + args[DNNL_ARG_WEIGHTS] = w_mkl_mem; + + // bias + if(bias != nullptr) { + auto b_mkl_mem = dnnl::memory(b_mkl_md, engine, bias->getBuffer()); + args[DNNL_ARG_BIAS] = b_mkl_mem; + } + + // output + auto z_user_mem = dnnl::memory(z_user_md, engine, output->getBuffer()); + const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc(); + auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem; + args[DNNL_ARG_DST] = z_mkl_mem; + + // run calculations + dnnl::convolution_forward(op_prim_desc).execute(stream, args); + + // reorder outputs if necessary + if (zReorder) + dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); + + stream.wait(); +} + +////////////////////////////////////////////////////////////////////// +static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const NDArray *bias, const NDArray *gradO, + NDArray *gradI, NDArray *gradW, NDArray *gradB, + const int kD, const int kH, const int kW, + const int sD, const int sH, const int sW, + const int pD, const int pH, const int pW, + const int dD, const int dH, const int dW, + const int paddingMode, const int isNCDHW) { + + // weights/gradW [kD, kH, kW, iC, oC], we'll perform permutation since mkl support [oC, iC, kD, kH, kW] + + int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; + int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); + + // const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d + + dnnl::memory::dims strides = {sD, sH, sW}; + dnnl::memory::dims padding = {pD, pH, pW}; + // dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame }; + dnnl::memory::dims padding_r = {(oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW}; + dnnl::memory::dims dilation = {dD-1, dH-1, dW-1}; + + auto xzFrmat = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; + dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oidhw; + + dnnl::memory::dims xDims = {bS, iC, iD, iH, iW}; + dnnl::memory::dims wDims = {oC, iC, kD, kH, kW}; + dnnl::memory::dims zDims = {bS, oC, oD, oH, oW}; + + auto type = dnnl::memory::data_type::f32; + + // memory descriptors for arrays + + // input + dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat); + if(input->ews() != 1 || input->ordering() != 'c') { + x_user_md.data.format_kind = dnnl_blocked; // overrides format + x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0); + x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1); + x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2); + x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3); + x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(4); + } + + // weights + dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormat); + w_user_md.data.format_kind = dnnl_blocked; // overrides format + w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(4); // permute [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW] + w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(3); + w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0); + w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1); + w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(2); + + // gradO + dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFrmat); + if(gradO->ews() != 1 || gradO->ordering() != 'c') { + gradO_user_md.data.format_kind = dnnl_blocked; // overrides format + gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0); + gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(1); + gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(2); + gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(3); + gradO_user_md.data.format_desc.blocking.strides[4] = gradO->strideAt(4); + } + + // gradI + dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFrmat); + if(gradI->ews() != 1 || gradI->ordering() != 'c') { + gradI_user_md.data.format_kind = dnnl_blocked; // overrides format + gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0); + gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(1); + gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(2); + gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(3); + gradI_user_md.data.format_desc.blocking.strides[4] = gradI->strideAt(4); + } + + // gradW + dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, type, wFormat); + gradW_user_md.data.format_kind = dnnl_blocked; // overrides format + gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(4); // permute [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW] + gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(3); + gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(0); + gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(1); + gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(2); + + // gradB + dnnl::memory::desc gradB_mkl_md; + if(gradB != nullptr) + gradB_mkl_md = dnnl::memory::desc({oC}, type, dnnl::memory::format_tag::x); + + auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // forward primitive description + dnnl::convolution_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); + dnnl::convolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); + + // backward data primitive description + dnnl::convolution_backward_data::desc op_data_bp_desc(dnnl::algorithm::convolution_auto, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); + dnnl::convolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc); + + // backward weights primitive description + dnnl::convolution_backward_weights::desc op_weights_bp_desc(dnnl::algorithm::convolution_auto, x_mkl_md, gradW_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r); + dnnl::convolution_backward_weights::primitive_desc op_weights_bp_prim_desc(op_weights_bp_desc, engine, op_ff_prim_desc); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + dnnl::stream stream(engine); + + // provide memory buffers and check whether reorder is required + + // input + auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer()); + const bool xReorder = op_weights_bp_prim_desc.src_desc() != x_user_mem.get_desc(); + auto x_mkl_mem = xReorder ? dnnl::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem; + if (xReorder) + dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); + args[DNNL_ARG_SRC] = x_mkl_mem; + + // weights + auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer()); + const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc(); + auto w_mkl_mem = wReorder ? dnnl::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem; + if (wReorder) + dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem); + args[DNNL_ARG_WEIGHTS] = w_mkl_mem; + + // gradO + auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer()); + const bool gradOReorderW = op_weights_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); + const bool gradOReorderD = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); + auto gradO_mkl_memW = gradOReorderW ? dnnl::memory(op_weights_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; + auto gradO_mkl_memD = gradOReorderD ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; + if (gradOReorderW) + dnnl::reorder(gradO_user_mem, gradO_mkl_memW).execute(stream, gradO_user_mem, gradO_mkl_memW); + if (gradOReorderD) + dnnl::reorder(gradO_user_mem, gradO_mkl_memD).execute(stream, gradO_user_mem, gradO_mkl_memD); + args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD; + + // gradI + auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer()); + const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc(); + auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem; + args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem; + + // gradW + auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->getBuffer()); + const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc(); + auto gradW_mkl_mem = gradWReorder ? dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem; + args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem; + + // gradB + if(gradB != nullptr) { + auto gradB_mkl_mem = dnnl::memory(gradB_mkl_md, engine, gradB->getBuffer()); + args[DNNL_ARG_DIFF_BIAS] = gradB_mkl_mem; + } + + // run backward data calculations + dnnl::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args); + + if(gradOReorderW || gradOReorderD) + args[DNNL_ARG_DIFF_DST] = gradO_mkl_memW; + + // run backward weights calculations + dnnl::convolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args); + + // reorder gradI if necessary + if (gradIReorder) + dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem); + if (gradWReorder) + dnnl::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem); + + stream.wait(); + + // shape::printArray(z_mkl_mem.map_data(),8); +} + + +/* ////////////////////////////////////////////////////////////////////// static void conv3dMKLDNN(nd4j::graph::Context &block, const NDArray *input, const NDArray *weights, const NDArray *bias, @@ -225,6 +533,7 @@ static void conv3dBpMKLDNN(nd4j::graph::Context &block, reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory, userI_src_memory); } } +*/ ////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(conv3dnew, ENGINE_CPU) { @@ -256,15 +565,15 @@ PLATFORM_IMPL(conv3dnew, ENGINE_CPU) { int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC}); - REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV3D MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str()); + std::vector expectedWeightsShape = {kD, kH, kW, iC, oC}; + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); if (bias) REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); if (paddingMode) // SAME ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - conv3dMKLDNN(block, input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW); + conv3dMKLDNN(input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW); return Status::OK(); } @@ -280,6 +589,7 @@ PLATFORM_CHECK(conv3dnew, ENGINE_CPU) { ////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) auto weights = INPUT_VARIABLE(1); // [kD, kH, kW, iC, oC] always auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] @@ -318,14 +628,14 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) { int trueoD, trueoH, trueoW; // true output depth/height/width ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode); - std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx( {bS, oC, trueoD, trueoH, trueoW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2})); - std::string expectedWeightsShape = ShapeUtils::shapeAsString({kD, kH, kW, iC, oC}); - REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(expectedWeightsShape == ShapeUtils::shapeAsString(weights), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", expectedWeightsShape.c_str(), ShapeUtils::shapeAsString(weights).c_str()); + std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx( {bS, oC, trueoD, trueoH, trueoW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}); + std::vector expectedWeightsShape = {kD, kH, kW, iC, oC}; + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); if (bias) REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf()); - conv3dBpMKLDNN(block, input, weights, bias, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW); + conv3dBpMKLDNN(input, weights, bias, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp index e63d7440c..1879ef8fb 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d.cpp @@ -34,17 +34,13 @@ namespace platforms { ////////////////////////////////////////////////////////////////////////// static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, - const int paddingMode) { + const int paddingMode, const bool isNCHW) { - // input [bS, iC, iH, iW] nchw, mkl doesn't support format nhwc - // weights [oC, iC, kH, kW] always, mkl doesn't support weights format [kH, kW, oC, iC] - // bias [oC], may be nullptr - - // output [bS, oC, oH, oW] nchw, mkl doesn't support format nhwc + // weights [oC, iC, kH, kW] always, mkl doesn't support [kH, kW, oC, iC], so we'll perform permutation int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(true, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); dnnl::memory::dims strides = { sH, sW }; dnnl::memory::dims padding = { pH, pW }; @@ -80,8 +76,7 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N else zType = dnnl::memory::data_type::s32; - - dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::nchw; // isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; + dnnl::memory::format_tag xFormat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw; dnnl::memory::dims xDims = {bS, iC, iH, iW}; @@ -93,20 +88,22 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N // input dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat); - x_user_md.data.format_kind = dnnl_blocked; // overrides format - x_user_md.data.format_desc.blocking.strides[0] = input->stridesOf()[0]; - x_user_md.data.format_desc.blocking.strides[1] = input->stridesOf()[1]; - x_user_md.data.format_desc.blocking.strides[2] = input->stridesOf()[2]; - x_user_md.data.format_desc.blocking.strides[3] = input->stridesOf()[3]; + if(input->ews() != 1 || input->ordering() != 'c') { + x_user_md.data.format_kind = dnnl_blocked; // overrides format + x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0); + x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1); + x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2); + x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3); + } // weights dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat); w_user_md.data.format_kind = dnnl_blocked; // overrides format - w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0]; - w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1]; - w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2]; - w_user_md.data.format_desc.blocking.strides[3] = weights->stridesOf()[3]; + w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(2); // [kH, kW, oC, iC] -> [oC, iC, kH, kW] + w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(3); + w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0); + w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1); // bias dnnl::memory::desc b_mkl_md; @@ -116,11 +113,13 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N // output dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any); dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormat); - z_user_md.data.format_kind = dnnl_blocked; // overrides format - z_user_md.data.format_desc.blocking.strides[0] = output->stridesOf()[0]; - z_user_md.data.format_desc.blocking.strides[1] = output->stridesOf()[1]; - z_user_md.data.format_desc.blocking.strides[2] = output->stridesOf()[2]; - z_user_md.data.format_desc.blocking.strides[3] = output->stridesOf()[3]; + if(output->ews() != 1 || output->ordering() != 'c') { + z_user_md.data.format_kind = dnnl_blocked; // overrides format + z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0); + z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(1); + z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(2); + z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(3); + } auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); @@ -179,21 +178,19 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N ////////////////////////////////////////////////////////////////////////// static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB, const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, - const int paddingMode) { + const int paddingMode, const bool isNCHW) { - // input and gradI [bS, iC, iH, iW], mkl doesn't support ndhwc format - // weights and gradW [oC, iC, kH, kW] always, mkl doesn't support weights format [kH, kW, oC, iC] - // gradB [oC], may be nullptr - // gradO [bS, oC, oH, oW] + // weights and gradW [oC, iC, kH, kW] always, mkl doesn't support [kH, kW, oC, iC], so we'll perform permutation int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(true, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH); dnnl::memory::dims strides = { sH, sW }; dnnl::memory::dims padding = { pH, pW }; dnnl::memory::dims padding_r = { (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW }; dnnl::memory::dims dilation = { dH-1, dW-1 }; + // input type dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; // weights type @@ -207,7 +204,7 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const // gradB type dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32; - dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::nchw; // isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; + dnnl::memory::format_tag xFormat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw; dnnl::memory::dims xDims = {bS, iC, iH, iW}; @@ -219,54 +216,59 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const // input dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat); - x_user_md.data.format_kind = dnnl_blocked; // overrides format - x_user_md.data.format_desc.blocking.strides[0] = input->stridesOf()[0]; - x_user_md.data.format_desc.blocking.strides[1] = input->stridesOf()[1]; - x_user_md.data.format_desc.blocking.strides[2] = input->stridesOf()[2]; - x_user_md.data.format_desc.blocking.strides[3] = input->stridesOf()[3]; + if(input->ews() != 1 || input->ordering() != 'c') { + x_user_md.data.format_kind = dnnl_blocked; // overrides format + x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0); + x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1); + x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2); + x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3); + } // weights dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat); w_user_md.data.format_kind = dnnl_blocked; // overrides format - w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0]; - w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1]; - w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2]; - w_user_md.data.format_desc.blocking.strides[3] = weights->stridesOf()[3]; + w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(2); // [kH, kW, oC, iC] -> [oC, iC, kH, kW] + w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(3); + w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0); + w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1); // gradO dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat); - gradO_user_md.data.format_kind = dnnl_blocked; // overrides format - gradO_user_md.data.format_desc.blocking.strides[0] = gradO->stridesOf()[0]; - gradO_user_md.data.format_desc.blocking.strides[1] = gradO->stridesOf()[1]; - gradO_user_md.data.format_desc.blocking.strides[2] = gradO->stridesOf()[2]; - gradO_user_md.data.format_desc.blocking.strides[3] = gradO->stridesOf()[3]; + if(gradO->ews() != 1 || gradO->ordering() != 'c') { + gradO_user_md.data.format_kind = dnnl_blocked; // overrides format + gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0); + gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(1); + gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(2); + gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(3); + } // gradI dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat); - gradI_user_md.data.format_kind = dnnl_blocked; // overrides format - gradI_user_md.data.format_desc.blocking.strides[0] = gradI->stridesOf()[0]; - gradI_user_md.data.format_desc.blocking.strides[1] = gradI->stridesOf()[1]; - gradI_user_md.data.format_desc.blocking.strides[2] = gradI->stridesOf()[2]; - gradI_user_md.data.format_desc.blocking.strides[3] = gradI->stridesOf()[3]; + if(gradI->ews() != 1 || gradI->ordering() != 'c') { + gradI_user_md.data.format_kind = dnnl_blocked; // overrides format + gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0); + gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(1); + gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(2); + gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(3); + } // gradW dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any); dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormat); gradW_user_md.data.format_kind = dnnl_blocked; // overrides format - gradW_user_md.data.format_desc.blocking.strides[0] = gradW->stridesOf()[0]; - gradW_user_md.data.format_desc.blocking.strides[1] = gradW->stridesOf()[1]; - gradW_user_md.data.format_desc.blocking.strides[2] = gradW->stridesOf()[2]; - gradW_user_md.data.format_desc.blocking.strides[3] = gradW->stridesOf()[3]; + gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(2); // [kH, kW, oC, iC] -> [oC, iC, kH, kW] + gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(3); + gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(0); + gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(1); // gradB dnnl::memory::desc gradB_mkl_md; if(gradB != nullptr) gradB_mkl_md = dnnl::memory::desc({oC}, gradBType, dnnl::memory::format_tag::x); - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); // forward primitive description @@ -306,11 +308,15 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const // gradO auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer()); - const bool gradOReorder = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); - auto gradO_mkl_mem = gradOReorder ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; - if (gradOReorder) - dnnl::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem); - args[DNNL_ARG_DIFF_DST] = gradO_mkl_mem; + const bool gradOReorderW = op_weights_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); + const bool gradOReorderD = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); + auto gradO_mkl_memW = gradOReorderW ? dnnl::memory(op_weights_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; + auto gradO_mkl_memD = gradOReorderD ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; + if (gradOReorderW) + dnnl::reorder(gradO_user_mem, gradO_mkl_memW).execute(stream, gradO_user_mem, gradO_mkl_memW); + if (gradOReorderD) + dnnl::reorder(gradO_user_mem, gradO_mkl_memD).execute(stream, gradO_user_mem, gradO_mkl_memD); + args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD; // gradI auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer()); @@ -333,6 +339,9 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const // run backward data calculations dnnl::deconvolution_backward_data(op_data_bp_prim_desc).execute(stream, args); + if(gradOReorderW || gradOReorderD) + args[DNNL_ARG_DIFF_DST] = gradO_mkl_memW; + // run backward weights calculations dnnl::deconvolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args); @@ -385,23 +394,7 @@ PLATFORM_IMPL(deconv2d, ENGINE_CPU) { ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW); } - // mkl supports only [oC, iC, kH, kW] format for weights - weights = new NDArray(weights->permute({2,3,0,1})); // [kH, kW, oC, iC] -> [oC, iC, kH, kW] - - // mkl supports only NCHW - if(!isNCHW) { - input = new NDArray(input->permute({0,3,1,2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - output = new NDArray(output->permute({0,3,1,2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW] - } - - deconv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode); - - delete weights; - - if(!isNCHW) { - delete input; - delete output; - } + deconv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW); return Status::OK(); } @@ -477,27 +470,7 @@ PLATFORM_IMPL(deconv2d_bp, ENGINE_CPU) { ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW); } - // mkl supports only [oC, iC, kH, kW] for weights - weights = new NDArray(weights->permute({2,3,0,1})); // [kH, kW, oC, iC] -> [oC, iC, kH, kW] - gradW = new NDArray(gradW->permute({2,3,0,1})); // [kH, kW, oC, iC] -> [oC, iC, kH, kW] - - // mkl supports NCHW format only - if(!isNCHW) { - input = new NDArray(input->permute({0,3,1,2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - gradI = new NDArray(gradI->permute({0,3,1,2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - gradO = new NDArray(gradO->permute({0,3,1,2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW] - } - - deconv2dBpMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode); - - delete weights; - delete gradW; - - if(!isNCHW) { - delete input; - delete gradI; - delete gradO; - } + deconv2dBpMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp index 90ddb828e..7c6582ab4 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/deconv2d_tf.cpp @@ -33,7 +33,8 @@ namespace platforms { ////////////////////////////////////////////////////////////////////////// static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* gradO, NDArray* gradI, const int bS, const int iC, const int iH, const int iW, const int oC, const int oH, const int oW, - const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW) { + const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW, + const bool isNCHW) { // gradI [bS, iH, iW, iC], mkl doesn't support ndhwc format // weights [oC, iC, kH, kW] always, mkl doesn't support weights format [kH, kW, iC, oC] @@ -51,7 +52,7 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad // gradI type dnnl::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16; - dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::nchw; // isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; + dnnl::memory::format_tag xFormat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw; dnnl::memory::dims xDims = {bS, iC, iH, iW}; @@ -67,29 +68,32 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat); w_user_md.data.format_kind = dnnl_blocked; // overrides format - w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0]; - w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1]; - w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2]; - w_user_md.data.format_desc.blocking.strides[3] = weights->stridesOf()[3]; + w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW] + w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(2); + w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0); + w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1); // gradO dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat); - gradO_user_md.data.format_kind = dnnl_blocked; // overrides format - gradO_user_md.data.format_desc.blocking.strides[0] = gradO->stridesOf()[0]; - gradO_user_md.data.format_desc.blocking.strides[1] = gradO->stridesOf()[1]; - gradO_user_md.data.format_desc.blocking.strides[2] = gradO->stridesOf()[2]; - gradO_user_md.data.format_desc.blocking.strides[3] = gradO->stridesOf()[3]; + if(gradO->ews() != 1 || gradO->ordering() != 'c') { + gradO_user_md.data.format_kind = dnnl_blocked; // overrides format + gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0); + gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(1); + gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(2); + gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(3); + } // gradI dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat); - gradI_user_md.data.format_kind = dnnl_blocked; // overrides format - gradI_user_md.data.format_desc.blocking.strides[0] = gradI->stridesOf()[0]; - gradI_user_md.data.format_desc.blocking.strides[1] = gradI->stridesOf()[1]; - gradI_user_md.data.format_desc.blocking.strides[2] = gradI->stridesOf()[2]; - gradI_user_md.data.format_desc.blocking.strides[3] = gradI->stridesOf()[3]; - + if(gradI->ews() != 1 || gradI->ordering() != 'c') { + gradI_user_md.data.format_kind = dnnl_blocked; // overrides format + gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0); + gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(1); + gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(2); + gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(3); + } auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); @@ -166,9 +170,9 @@ PLATFORM_IMPL(deconv2d_tf, ENGINE_CPU) { const int rank = gradO->rankOf(); - REQUIRE_TRUE(weights->rankOf() == rank, 0, "CUSTOM DECONV2D_TF OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); - REQUIRE_TRUE(gradIShape->rankOf() == 1, 0, "CUSTOM DECONV2D_TF OP: rank of array with output shape must be equal to 1, but got %i instead !", gradIShape->rankOf()); - REQUIRE_TRUE(gradIShape->lengthOf() == rank, 0, "CUSTOM DECONV2D_TF OP: length of array with output shape must be equal to 4, but got %i instead !", gradIShape->lengthOf()); + REQUIRE_TRUE(weights->rankOf() == rank, 0, "CUSTOM DECONV2D_TF MKLDNN OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf()); + REQUIRE_TRUE(gradIShape->rankOf() == 1, 0, "CUSTOM DECONV2D_TF MKLDNN OP: rank of array with output shape must be equal to 1, but got %i instead !", gradIShape->rankOf()); + REQUIRE_TRUE(gradIShape->lengthOf() == rank, 0, "CUSTOM DECONV2D_TF MKLDNN OP: length of array with output shape must be equal to 4, but got %i instead !", gradIShape->lengthOf()); int indIOioC, indIiH, indWoC(3), indOoH; if(!isNCHW) { @@ -193,29 +197,29 @@ PLATFORM_IMPL(deconv2d_tf, ENGINE_CPU) { std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1}); std::vector expectedWeightsShape = {kH, kW, iC, oC}; - REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV2D_TF OP: wrong shape of input array, basing on array with output shape expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_TF OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV2D_TF MKLDNN OP: wrong shape of input array, basing on array with output shape expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); + REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_TF MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str()); if(isSameMode) // SAME ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - // mkl supports only [oC, iC, kH, kW] for weights - weights = new NDArray(weights->permute({3,2,0,1})); // [kH, kW, iC, oC] -> [oC, iC, kH, kW] + // // mkl supports only [oC, iC, kH, kW] for weights + // weights = new NDArray(weights->permute({3,2,0,1})); // [kH, kW, iC, oC] -> [oC, iC, kH, kW] - // mkl supports NCHW format only - if(!isNCHW) { - gradI = new NDArray(gradI->permute({0,3,1,2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - gradO = new NDArray(gradO->permute({0,3,1,2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW] - } + // // mkl supports NCHW format only + // if(!isNCHW) { + // gradI = new NDArray(gradI->permute({0,3,1,2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] + // gradO = new NDArray(gradO->permute({0,3,1,2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW] + // } - deconv2TFdBackPropMKLDNN(weights, gradO, gradI, bS, iC, iH, iW, oC, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW); + deconv2TFdBackPropMKLDNN(weights, gradO, gradI, bS, iC, iH, iW, oC, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW); - delete weights; + // delete weights; - if(!isNCHW) { - delete gradI; - delete gradO; - } + // if(!isNCHW) { + // delete gradI; + // delete gradO; + // } // ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp index 490ce4535..5daab8228 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/deconv3d.cpp @@ -34,17 +34,14 @@ namespace platforms { ////////////////////////////////////////////////////////////////////////// static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, - const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) { + const int pD, const int pH, const int pW, const int dD, const int dH, const int dW, + const bool isNCDHW) { - // input [bS, iD, iH, iW, iC] ncdhw, mkl doesn't support format ndhwc - // weights [oC, iC, kD, kH, kW] always, mkl doesn't support weights format [kD, kH, kW, oC, iC] - // bias [oC], may be nullptr - - // output [bS, oD, oH, oW, oC] ncdhw, mkl doesn't support format ndhwc + // weights [oC, iC, kD, kH, kW] always, mkl doesn't support [kD, kH, kW, oC, iC], so we'll perform permutation int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(true, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); dnnl::memory::dims strides = { sD, sH, sW }; dnnl::memory::dims padding = { pD, pH, pW }; @@ -80,8 +77,7 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N else zType = dnnl::memory::data_type::s32; - - dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::ncdhw; + dnnl::memory::format_tag xFormat = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oidhw; dnnl::memory::dims xDims = {bS, iC, iD, iH, iW}; @@ -93,22 +89,24 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N // input dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat); - x_user_md.data.format_kind = dnnl_blocked; // overrides format - x_user_md.data.format_desc.blocking.strides[0] = input->stridesOf()[0]; - x_user_md.data.format_desc.blocking.strides[1] = input->stridesOf()[1]; - x_user_md.data.format_desc.blocking.strides[2] = input->stridesOf()[2]; - x_user_md.data.format_desc.blocking.strides[3] = input->stridesOf()[3]; - x_user_md.data.format_desc.blocking.strides[4] = input->stridesOf()[4]; + if(input->ews() != 1 || input->ordering() != 'c') { + x_user_md.data.format_kind = dnnl_blocked; // overrides format + x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0); + x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1); + x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2); + x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3); + x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(4); + } // weights dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat); w_user_md.data.format_kind = dnnl_blocked; // overrides format - w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0]; - w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1]; - w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2]; - w_user_md.data.format_desc.blocking.strides[3] = weights->stridesOf()[3]; - w_user_md.data.format_desc.blocking.strides[4] = weights->stridesOf()[4]; + w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3); // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW] + w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(4); + w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0); + w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1); + w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(2); // bias dnnl::memory::desc b_mkl_md; @@ -118,12 +116,14 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N // output dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any); dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormat); - z_user_md.data.format_kind = dnnl_blocked; // overrides format - z_user_md.data.format_desc.blocking.strides[0] = output->stridesOf()[0]; - z_user_md.data.format_desc.blocking.strides[1] = output->stridesOf()[1]; - z_user_md.data.format_desc.blocking.strides[2] = output->stridesOf()[2]; - z_user_md.data.format_desc.blocking.strides[3] = output->stridesOf()[3]; - z_user_md.data.format_desc.blocking.strides[4] = output->stridesOf()[4]; + if(output->ews() !=1 || output->ordering() != 'c') { + z_user_md.data.format_kind = dnnl_blocked; // overrides format + z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0); + z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(1); + z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(2); + z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(3); + z_user_md.data.format_desc.blocking.strides[4] = output->strideAt(4); + } auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); @@ -184,16 +184,14 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, - const int dD, const int dH, const int dW) { + const int dD, const int dH, const int dW, + const bool isNCDHW) { - // input and gradI [bS, iD, iH, iW, iC], mkl doesn't support ndhwc format - // weights and gradW [oC, iC, kD, kH, kW] always, mkl doesn't support weights format [kD, kH, kW, oC, iC] - // gradB [oC], may be nullptr - // gradO [bS, oD, oH, oW, oC] + // weights and gradW [oC, iC, kD, kH, kW] always, mkl doesn't support [kD, kH, kW, oC, iC], so we'll perform permutation int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(true, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD); dnnl::memory::dims strides = { sD, sH, sW }; dnnl::memory::dims padding = { pD, pH, pW }; @@ -213,7 +211,7 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights, // gradB type dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32; - dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::ncdhw; // isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; + dnnl::memory::format_tag xFormat = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oidhw; dnnl::memory::dims xDims = {bS, iC, iD, iH, iW}; @@ -225,52 +223,58 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights, // input dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat); - x_user_md.data.format_kind = dnnl_blocked; // overrides format - x_user_md.data.format_desc.blocking.strides[0] = input->stridesOf()[0]; - x_user_md.data.format_desc.blocking.strides[1] = input->stridesOf()[1]; - x_user_md.data.format_desc.blocking.strides[2] = input->stridesOf()[2]; - x_user_md.data.format_desc.blocking.strides[3] = input->stridesOf()[3]; - x_user_md.data.format_desc.blocking.strides[4] = input->stridesOf()[4]; + if(input->ews() != 1 || input->ordering() != 'c') { + x_user_md.data.format_kind = dnnl_blocked; // overrides format + x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0); + x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1); + x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2); + x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3); + x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(4); + } // weights dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat); w_user_md.data.format_kind = dnnl_blocked; // overrides format - w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0]; - w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1]; - w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2]; - w_user_md.data.format_desc.blocking.strides[3] = weights->stridesOf()[3]; - w_user_md.data.format_desc.blocking.strides[4] = weights->stridesOf()[4]; + w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3); // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW] + w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(4); + w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0); + w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1); + w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(2); // gradO dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat); - gradO_user_md.data.format_kind = dnnl_blocked; // overrides format - gradO_user_md.data.format_desc.blocking.strides[0] = gradO->stridesOf()[0]; - gradO_user_md.data.format_desc.blocking.strides[1] = gradO->stridesOf()[1]; - gradO_user_md.data.format_desc.blocking.strides[2] = gradO->stridesOf()[2]; - gradO_user_md.data.format_desc.blocking.strides[3] = gradO->stridesOf()[3]; - gradO_user_md.data.format_desc.blocking.strides[4] = gradO->stridesOf()[4]; + if(gradO->ews() != 1 || gradO->ordering() != 'c') { + gradO_user_md.data.format_kind = dnnl_blocked; // overrides format + gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0); + gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(1); + gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(2); + gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(3); + gradO_user_md.data.format_desc.blocking.strides[4] = gradO->strideAt(4); + } // gradI dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat); - gradI_user_md.data.format_kind = dnnl_blocked; // overrides format - gradI_user_md.data.format_desc.blocking.strides[0] = gradI->stridesOf()[0]; - gradI_user_md.data.format_desc.blocking.strides[1] = gradI->stridesOf()[1]; - gradI_user_md.data.format_desc.blocking.strides[2] = gradI->stridesOf()[2]; - gradI_user_md.data.format_desc.blocking.strides[3] = gradI->stridesOf()[3]; - gradI_user_md.data.format_desc.blocking.strides[4] = gradI->stridesOf()[4]; + if(gradI->ews() != 1 || gradI->ordering() != 'c') { + gradI_user_md.data.format_kind = dnnl_blocked; // overrides format + gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0); + gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(1); + gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(2); + gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(3); + gradI_user_md.data.format_desc.blocking.strides[4] = gradI->strideAt(4); + } // gradW dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, wFormat); dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormat); gradW_user_md.data.format_kind = dnnl_blocked; // overrides format - gradW_user_md.data.format_desc.blocking.strides[0] = gradW->stridesOf()[0]; - gradW_user_md.data.format_desc.blocking.strides[1] = gradW->stridesOf()[1]; - gradW_user_md.data.format_desc.blocking.strides[2] = gradW->stridesOf()[2]; - gradW_user_md.data.format_desc.blocking.strides[3] = gradW->stridesOf()[3]; - gradW_user_md.data.format_desc.blocking.strides[4] = gradW->stridesOf()[4]; + gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(3); // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW] + gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(4); + gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(0); + gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(1); + gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(2); // gradB dnnl::memory::desc gradB_mkl_md; @@ -317,11 +321,15 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights, // gradO auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer()); - const bool gradOReorder = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); - auto gradO_mkl_mem = gradOReorder ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; - if (gradOReorder) - dnnl::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem); - args[DNNL_ARG_DIFF_DST] = gradO_mkl_mem; + const bool gradOReorderW = op_weights_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); + const bool gradOReorderD = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); + auto gradO_mkl_memW = gradOReorderW ? dnnl::memory(op_weights_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; + auto gradO_mkl_memD = gradOReorderD ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; + if (gradOReorderW) + dnnl::reorder(gradO_user_mem, gradO_mkl_memW).execute(stream, gradO_user_mem, gradO_mkl_memW); + if (gradOReorderD) + dnnl::reorder(gradO_user_mem, gradO_mkl_memD).execute(stream, gradO_user_mem, gradO_mkl_memD); + args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD; // gradI auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer()); @@ -344,6 +352,9 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights, // run backward data calculations dnnl::deconvolution_backward_data(op_data_bp_prim_desc).execute(stream, args); + if(gradOReorderW || gradOReorderD) + args[DNNL_ARG_DIFF_DST] = gradO_mkl_memW; + // run backward weights calculations dnnl::deconvolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args); @@ -400,23 +411,7 @@ PLATFORM_IMPL(deconv3d, ENGINE_CPU) { ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW); } - // mkl supports only [oC, iC, kD, kH, kW] format for weights - weights = new NDArray(weights->permute({3,4,0,1,2})); // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW] - - // mkl supports only NCDHW - if(!isNCDHW) { - input = new NDArray(input->permute({0,4,1,2,3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] - output = new NDArray(output->permute({0,4,1,2,3})); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW] - } - - deconv3dMKLDNN(input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW); - - delete weights; - - if(!isNCDHW) { - delete input; - delete output; - } + deconv3dMKLDNN(input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW); return Status::OK(); } @@ -495,27 +490,7 @@ PLATFORM_IMPL(deconv3d_bp, ENGINE_CPU) { if(isSameMode) // Note: we're intentionally swapping iH and oH, to calculated the padding for a"normal" conv (not deconv) forward pass ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - // mkl supports only [oC, iC, kD, kH, kW] for weights - weights = new NDArray(weights->permute({3,4,0,1,2})); // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW] - gradW = new NDArray(gradW->permute({3,4,0,1,2})); // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW] - - // mkl supports NCDHW format only - if(!isNCDHW) { - input = new NDArray(input->permute({0,4,1,2,3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] - gradI = new NDArray(gradI->permute({0,4,1,2,3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] - gradO = new NDArray(gradO->permute({0,4,1,2,3})); // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW] - } - - deconv3dBackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW); - - delete weights; - delete gradW; - - if(!isNCDHW) { - delete input; - delete gradI; - delete gradO; - } + deconv3dBackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp index d6722c009..4da2c2cb0 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/depthwiseConv2d.cpp @@ -86,7 +86,7 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights, else zType = dnnl::memory::data_type::s32; - dnnl::memory::format_tag xzFrmat = dnnl::memory::format_tag::nchw; + dnnl::memory::format_tag xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::goihw; dnnl::memory::dims xDims = {bS, iC, iH, iW}; @@ -98,11 +98,13 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights, // input dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFrmat); - x_user_md.data.format_kind = dnnl_blocked; // overrides format NHWC -> NCHW - x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0); - x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(isNCHW ? 1 : 3); - x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(isNCHW ? 2 : 1); - x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(isNCHW ? 3 : 2); + if(input->ews() != 1 || input->ordering() != 'c') { + x_user_md.data.format_kind = dnnl_blocked; // overrides format + x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0); + x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1); // do permutation NHWC -> NCHW + x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2); + x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3); + } // weights, make permute [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW]; dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); @@ -122,11 +124,13 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights, // output dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any); dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xzFrmat); - z_user_md.data.format_kind = dnnl_blocked; // overrides format - z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0); - z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(isNCHW ? 1 : 3); - z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(isNCHW ? 2 : 1); - z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(isNCHW ? 3 : 2); + if(output->ews() != 1 || output->ordering() != 'c') { + z_user_md.data.format_kind = dnnl_blocked; // overrides format + z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0); + z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(1); // do permutation NHWC -> NCHW + z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(2); + z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(3); + } auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); @@ -219,7 +223,7 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w // gradB type dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32; - dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::nchw; // isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; + dnnl::memory::format_tag xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::goihw; dnnl::memory::dims xDims = {bS, iC, iH, iW}; @@ -230,12 +234,14 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w // input dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any); - dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat); - x_user_md.data.format_kind = dnnl_blocked; // overrides format - x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0); - x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(isNCHW ? 1 : 3); - x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(isNCHW ? 2 : 1); - x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(isNCHW ? 3 : 2); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFrmat); + if(input->ews() != 1 || input->ordering() != 'c') { + x_user_md.data.format_kind = dnnl_blocked; // overrides format + x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0); + x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1); + x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2); + x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3); + } // weights, make permute [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW]; dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any); @@ -249,21 +255,25 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w // gradO dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any); - dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat); - gradO_user_md.data.format_kind = dnnl_blocked; // overrides format - gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0); - gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(isNCHW ? 1 : 3); - gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(isNCHW ? 2 : 1); - gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(isNCHW ? 3 : 2); + dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xzFrmat); + if(gradO->ews() != 1 || gradO->ordering() != 'c') { + gradO_user_md.data.format_kind = dnnl_blocked; // overrides format + gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0); + gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(1); + gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(2); + gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(3); + } // gradI dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any); - dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat); - gradI_user_md.data.format_kind = dnnl_blocked; // overrides format - gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0); - gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(isNCHW ? 1 : 3); - gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(isNCHW ? 2 : 1); - gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(isNCHW ? 3 : 2); + dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xzFrmat); + if(gradI->ews() != 1 || gradI->ordering() != 'c') { + gradI_user_md.data.format_kind = dnnl_blocked; // overrides format + gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0); + gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(1); + gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(2); + gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(3); + } // gradW, make permute [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW]; dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any); @@ -319,11 +329,15 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w // gradO auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer()); - const bool gradOReorder = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); - auto gradO_mkl_mem = gradOReorder ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; - if (gradOReorder) - dnnl::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem); - args[DNNL_ARG_DIFF_DST] = gradO_mkl_mem; + const bool gradOReorderW = op_weights_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); + const bool gradOReorderD = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); + auto gradO_mkl_memW = gradOReorderW ? dnnl::memory(op_weights_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; + auto gradO_mkl_memD = gradOReorderD ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; + if (gradOReorderW) + dnnl::reorder(gradO_user_mem, gradO_mkl_memW).execute(stream, gradO_user_mem, gradO_mkl_memW); + if (gradOReorderD) + dnnl::reorder(gradO_user_mem, gradO_mkl_memD).execute(stream, gradO_user_mem, gradO_mkl_memD); + args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD; // gradI auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer()); @@ -346,6 +360,9 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w // run backward data calculations dnnl::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args); + if(gradOReorderW || gradOReorderD) + args[DNNL_ARG_DIFF_DST] = gradO_mkl_memW; + // run backward weights calculations dnnl::convolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args); @@ -401,6 +418,7 @@ PLATFORM_IMPL(depthwise_conv2d, ENGINE_CPU) { ////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(depthwise_conv2d, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); auto weights = INPUT_VARIABLE(1); auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; @@ -473,6 +491,7 @@ PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CPU) { ////////////////////////////////////////////////////////////////////// PLATFORM_CHECK(depthwise_conv2d_bp, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NDHWC) or [bS, iC, iH, iW] (NCDHW) auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, mC] always auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC] = [iC*mC] diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d.cpp index 69aee8fad..3e7979f2f 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling2d.cpp @@ -17,6 +17,7 @@ // // @author saudet // @author raver119@gmail.com +// @author Yurii Shyrma (iuriish@yahoo.com) // #include @@ -33,105 +34,38 @@ namespace nd4j { namespace ops { namespace platforms { + ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(maxpool2d, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); - - REQUIRE_TRUE(input->rankOf() == 4, 0, "Input should have rank of 4, but got %i instead", - input->rankOf()); - - // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - auto argI = *(block.getIArguments()); auto output = OUTPUT_VARIABLE(0); - const auto kH = INT_ARG(0); - const auto kW = INT_ARG(1); - const auto sH = INT_ARG(2); - const auto sW = INT_ARG(3); - int pH = INT_ARG(4); - int pW = INT_ARG(5); - const auto dH = INT_ARG(6); - const auto dW = INT_ARG(7); - const auto isSameMode = static_cast(INT_ARG(8)); + REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D MKLDNN OP: input array should have rank of 4, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}", - dH, dW); + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + const int kH = INT_ARG(0); + const int kW = INT_ARG(1); + const int sH = INT_ARG(2); + const int sW = INT_ARG(3); + int pH = INT_ARG(4); + int pW = INT_ARG(5); + const int dH = INT_ARG(6); + const int dW = INT_ARG(7); + const int paddingMode = INT_ARG(8); + // const int extraParam0 = INT_ARG(9); + const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 1-NHWC, 0-NCHW - int oH = 0; - int oW = 0; + REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D MKLDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW); - int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - const int iH = static_cast(isNCHW ? input->sizeAt(2) : input->sizeAt(1)); - const int iW = static_cast(isNCHW ? input->sizeAt(3) : input->sizeAt(2)); - - if (!isNCHW) { - input = new NDArray( - input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - output = new NDArray( - output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] - } - - ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode); - - if (isSameMode) + if (paddingMode) ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - const int bS = input->sizeAt(0); - const int iC = input->sizeAt(1); - const int oC = output->sizeAt(1); - - auto poolingMode = PoolingType::MAX_POOL; - int extraParam0 = 1; - - dnnl_memory_desc_t empty; - dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty); - dnnl::memory::desc user_src_md(empty), user_dst_md(empty); - dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; - dnnl::algorithm algorithm; - - mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0, - true, - bS, iC, iH, iW, oC, oH, oW, input, nullptr, output, - algorithm, - &pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr, - &user_dst_md, - pool_strides, pool_kernel, pool_padding, pool_padding_r); - - auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md, - pool_dst_md, - pool_strides, pool_kernel, pool_padding, pool_padding_r); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); - auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer()); - auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer()); - - auto pool_src_memory = user_src_memory; - dnnl::stream stream(engine); - if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) { - pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine); - reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory); - } - - auto pool_dst_memory = user_dst_memory; - if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { - pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine); - } - - pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory}, - {DNNL_ARG_DST, pool_dst_memory}}); - - if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { - reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory); - } - - stream.wait(); - - if (!isNCHW) { - delete input; - delete output; - } + mkldnnUtils::poolingMKLDNN(input, output, 0,kH,kW, 0,sH,sW, 0,pH,pW, isNCHW, algorithm::pooling_max); return Status::OK(); } @@ -159,117 +93,24 @@ PLATFORM_IMPL(maxpool2d_bp, ENGINE_CPU) { int pW = INT_ARG(5); // paddings width int dH = INT_ARG(6); // dilations height int dW = INT_ARG(7); // dilations width - int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME - int extraParam0 = INT_ARG(9); - int isNCHW = - block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME + // int extraParam0 = INT_ARG(9); + int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC - REQUIRE_TRUE(input->rankOf() == 4, 0, - "AVGPOOL2D_BP op: input should have rank of 4, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dH != 0 && dW != 0, 0, - "AVGPOOL2D_BP op: dilation must not be zero, but got instead {%i, %i}", dH, dW); + REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D_BP MKLDNN op: input should have rank of 4, but got %i instead", input->rankOf()); + REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D_BP MKLDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW); int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, - indIiH, indWiC, indWoC, indWkH, indOoH); + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); - std::string expectedGradOShape = ShapeUtils::shapeAsString( - ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1})); - std::string expectedGradIShape = ShapeUtils::shapeAsString( - ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, iH, iW, 0, indIOioC, indIiH, indIiH + 1})); - REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, - "AVGPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", - expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, - "AVGPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", - expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str()); + std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1}); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL2D_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - - if (!isNCHW) { - input = new NDArray(input->permute( - {0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - gradI = new NDArray(gradI->permute( - {0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW] - gradO = new NDArray(gradO->permute( - {0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW] - } - - if (isSameMode) // SAME + if (paddingMode) // SAME ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); - auto poolingMode = PoolingType::MAX_POOL; - - dnnl_memory_desc_t empty; - dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty); - dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty); - dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; - dnnl::algorithm algorithm; - - mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0, - true, - bS, iC, iH, iW, oC, oH, oW, input, gradI, gradO, algorithm, - &pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md, - &user_diff_src_md, &user_dst_md, - pool_strides, pool_kernel, pool_padding, pool_padding_r); - - // input is sometimes null, so we can't rely on pool_src_md being valid - auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, - input->buffer() != nullptr ? pool_src_md : pool_diff_src_md, - pool_dst_md, pool_strides, pool_kernel, pool_padding, - pool_padding_r); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - dnnl::stream stream(engine); - auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); - - auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, - pool_strides, pool_kernel, pool_padding, pool_padding_r); - - auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc); - auto userB_src_memory = dnnl::memory(user_src_md, engine, gradI->buffer()); - auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer()); - - auto poolB_src_memory = userB_src_memory; - if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { - poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine); - } - - auto poolB_dst_memory = userB_dst_memory; - if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) { - poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine); - reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory); - } - - auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer()); - auto pool_src_memory = user_src_memory; - if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) { - pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine); - reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory); - } - - auto pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine); - auto pool_workspace_memory = dnnl::memory(pool_prim_desc.workspace_desc(), engine); - - pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory}, - {DNNL_ARG_DST, pool_dst_memory}, - {DNNL_ARG_WORKSPACE, pool_workspace_memory}}); - // probably wrong, fix that - pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory}, - {DNNL_ARG_WORKSPACE, pool_workspace_memory}, - {DNNL_ARG_DIFF_SRC, poolB_src_memory}}); - - if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { - reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory); - } - - stream.wait(); - - if (!isNCHW) { - delete input; - delete gradI; - delete gradO; - } + mkldnnUtils::poolingBpMKLDNN(input, gradO, gradI, 0,kH,kW, 0,sH,sW, 0,pH,pW, isNCHW, algorithm::pooling_max); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling3d.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling3d.cpp index a37422c55..7f6e95418 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling3d.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/maxpooling3d.cpp @@ -16,6 +16,7 @@ // // @author raver119@gmail.com +// @author Yurii Shyrma (iuriish@yahoo.com) // #include @@ -34,10 +35,9 @@ namespace platforms { ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(maxpool3dnew, ENGINE_CPU) { - auto input = INPUT_VARIABLE( - 0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) - auto output = OUTPUT_VARIABLE( - 0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) + + auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) + auto output = OUTPUT_VARIABLE(0); // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW) int kD = INT_ARG(0); // filter(kernel) depth int kH = INT_ARG(1); // filter(kernel) height @@ -51,95 +51,24 @@ PLATFORM_IMPL(maxpool3dnew, ENGINE_CPU) { int dD = INT_ARG(9); // dilations depth int dH = INT_ARG(10); // dilations height int dW = INT_ARG(11); // dilations width - int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID + int paddingMode = INT_ARG(12); // 1-SAME, 0-VALID // int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases - int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW + int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW - REQUIRE_TRUE(input->rankOf() == 5, 0, - "MAXPOOL3DNEW OP: rank of input array must be equal to 5, but got %i instead !", - input->rankOf()); - REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, - "MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); + REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3DNEW MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf()); + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW MKLDNN op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, - indIOioC, indIOioD, indWiC, indWoC, indWkD); + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - std::string expectedOutputShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx( - {bS, iC, oD, oH, oW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2})); - REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0, - "MAXPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !", - expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str()); - // REQUIRE_TRUE(iD >= kD && iH >= kH && iW >= kW, 0, "MAXPOOL3D OP: the input depth/height/width must be greater or equal to kernel(filter) depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", iD,iH,iW, kD,kH,kW); - // REQUIRE_TRUE(kD/2 >= pD && kH/2 >= pH && kW/2 >= pW, 0, "MAXPOOL3D OP: pad depth/height/width must not be greater than half of kernel depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", pD,pH,pW, kD,kH,kW); + if(paddingMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - if (!isNCDHW) { - input = new NDArray( - input->permute({0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] - output = new NDArray( - output->permute({0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW] - } - - if (isSameMode) // SAME - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, - dW); - - - auto poolingMode = PoolingType::MAX_POOL; - auto extraParam0 = 1; - - dnnl_memory_desc_t empty; - dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty); - dnnl::memory::desc user_src_md(empty), user_dst_md(empty); - dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; - dnnl::algorithm algorithm; - - mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, - extraParam0, true, - bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, output, - algorithm, - &pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr, - &user_dst_md, - pool_strides, pool_kernel, pool_padding, pool_padding_r); - - auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md, - pool_dst_md, pool_strides, pool_kernel, pool_padding, - pool_padding_r); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - dnnl::stream stream(engine); - auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); - auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer()); - auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer()); - - auto pool_src_memory = user_src_memory; - if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) { - pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine); - reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory); - } - - auto pool_dst_memory = user_dst_memory; - if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { - pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine); - } - - pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory}, - {DNNL_ARG_DST, pool_dst_memory}}); - - if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) { - reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory); - } - - stream.wait(); - - - if (!isNCDHW) { - delete input; - delete output; - } + mkldnnUtils::poolingMKLDNN(input, output, kD,kH,kW, sD,sH,sW, pD,pH,pW, isNCDHW, algorithm::pooling_max); return Status::OK(); + } ////////////////////////////////////////////////////////////////////////// @@ -152,6 +81,7 @@ PLATFORM_CHECK(maxpool3dnew, ENGINE_CPU) { ////////////////////////////////////////////////////////////////////////// PLATFORM_IMPL(maxpool3dnew_bp, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW) auto gradO = INPUT_VARIABLE(1); // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next auto gradI = OUTPUT_VARIABLE(0); // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon @@ -162,127 +92,30 @@ PLATFORM_IMPL(maxpool3dnew_bp, ENGINE_CPU) { const int sD = INT_ARG(3); // strides depth const int sH = INT_ARG(4); // strides height const int sW = INT_ARG(5); // strides width - int pD = INT_ARG(6); // paddings depth - int pH = INT_ARG(7); // paddings height - int pW = INT_ARG(8); // paddings width + int pD = INT_ARG(6); // paddings depth + int pH = INT_ARG(7); // paddings height + int pW = INT_ARG(8); // paddings width const int dD = INT_ARG(9); // dilations depth const int dH = INT_ARG(10); // dilations height const int dW = INT_ARG(11); // dilations width - const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID + const int paddngMode = INT_ARG(12); // 1-SAME, 0-VALID // int extraParam0 = INT_ARG(13); // unnecessary for max case, required only for avg and pnorm cases - int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW + int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1; // 1-NDHWC, 0-NCDHW - REQUIRE_TRUE(input->rankOf() == 5, 0, - "MAXPOOL3D_BP op: input should have rank of 5, but got %i instead", input->rankOf()); - REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, - "MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); + REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3DNEW_BP MKLDNN op: input should have rank of 5, but got %i instead", input->rankOf()); + REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW_BP MKLDNN op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW); - int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; + int bS, iC, iD, iH, iW, oC, oD, oH, oW; // batch size, input channels, input depth/height/width, output channels, output depth/height/width; int indIOioC, indIOioD, indWoC, indWiC, indWkD; // corresponding indexes - ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, - indIOioC, indIOioD, indWiC, indWoC, indWkD); + ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD); - std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx( - {bS, iC, oD, oH, oW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2})); - std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx( - {bS, iC, iD, iH, iW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2})); - REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, - "MAXPOOL3D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", - expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, - "MAXPOOL3D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", - expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str()); + std::vector expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}); + REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL3DNEW_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str()); - if (!isNCDHW) { - input = new NDArray(input->permute( - {0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] - gradI = new NDArray(gradI->permute( - {0, 4, 1, 2, 3})); // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW] - gradO = new NDArray(gradO->permute( - {0, 4, 1, 2, 3})); // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW] - } + if(paddngMode) // SAME + ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW); - if (isSameMode) // SAME - ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, - dW); - - - auto poolingMode = PoolingType::MAX_POOL; - auto extraParam0 = 1; - - dnnl_memory_desc_t empty; - dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty); - dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty); - dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r; - dnnl::algorithm algorithm; - - mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, - extraParam0, true, - bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, gradO, - algorithm, - &pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md, - &user_diff_src_md, &user_dst_md, - pool_strides, pool_kernel, pool_padding, pool_padding_r); - - // input is sometimes null, so we can't rely on pool_src_md being valid - if (input->buffer() == nullptr) { - pool_src_md = pool_diff_src_md; - user_src_md = user_diff_src_md; - } - auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, pool_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r); - - auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); - dnnl::stream stream(engine); - auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine); - - auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r); - - auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc); - auto userB_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer()); - auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer()); - - auto poolB_src_memory = userB_src_memory; - if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { - poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine); - } - - auto poolB_dst_memory = userB_dst_memory; - if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) { - poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine); - reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory); - } - - - auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer()); - - auto pool_src_memory = user_src_memory; - if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) { - pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine); - reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory); - } - - auto pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine); - auto pool_workspace_memory = dnnl::memory(pool_prim_desc.workspace_desc(), engine); - - pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory}, - {DNNL_ARG_DST, pool_dst_memory}, - {DNNL_ARG_WORKSPACE, pool_workspace_memory}}); - pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory}, - {DNNL_ARG_WORKSPACE, pool_workspace_memory}, - {DNNL_ARG_DIFF_SRC, poolB_src_memory}}); - - - if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) { - reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory); - } - - stream.wait(); - - if (!isNCDHW) { - delete input; - delete gradI; - delete gradO; - } + mkldnnUtils::poolingBpMKLDNN(input, gradO, gradI, kD,kH,kW, sD,sH,sW, pD,pH,pW, isNCDHW, algorithm::pooling_max); return Status::OK(); } diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp index 0b81de76d..02bba4300 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp @@ -16,9 +16,11 @@ // // @author saudet +// @author Yurii Shyrma (iuriish@yahoo.com) // #include +#include #include "mkldnnUtils.h" using namespace dnnl; @@ -26,6 +28,314 @@ using namespace dnnl; namespace nd4j { namespace mkldnnUtils { +////////////////////////////////////////////////////////////////////// +void poolingMKLDNN(const NDArray *input, NDArray *output, + const int kD, const int kH, const int kW, + const int sD, const int sH, const int sW, + const int pD, const int pH, const int pW, + const int isNCHW, const dnnl::algorithm mode) { + + // unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any) for input + const int rank = input->rankOf(); + + int bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; + dnnl::memory::dims strides, kernel, padding, padding_r, xDims, zDims; + dnnl::memory::format_tag xzFrmat; + + const auto type = dnnl::memory::data_type::f32; + + if(rank == 4) { // 2d + + ops::ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + + strides = { sH, sW }; + kernel = { kH, kW }; + padding = { pH, pW }; + padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW }; + xDims = {bS, iC, iH, iW}; + zDims = {bS, oC, oH, oW}; + + xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; + } + else { // 3d + + ops::ConvolutionUtils::getSizesAndIndexesConv3d(isNCHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH); + + strides = { sD, sH, sW }; + kernel = { kD, kH, kW }; + padding = { pD, pH, pW }; + padding_r = { (oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW }; + xDims = {bS, iC, iD, iH, iW}; + zDims = {bS, oC, oD, oH, oW}; + + xzFrmat = isNCHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; + } + + // memory descriptors for arrays + + // input + dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, xzFrmat); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat); + if(input->ews() != 1 || input->ordering() != 'c') { + x_user_md.data.format_kind = dnnl_blocked; // overrides format + x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0); + x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(isNCHW ? 1 :-1); + x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(isNCHW ? 2 : 1); + x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(isNCHW ? 3 : 2); + if(rank == 5) + x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(isNCHW ? 4 : 3); + } + + // output + dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFrmat); + if(output->ews() != 1 || output->ordering() != 'c') { + z_user_md.data.format_kind = dnnl_blocked; // overrides format + z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0); + z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(isNCHW ? 1 :-1); + z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(isNCHW ? 2 : 1); + z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(isNCHW ? 3 : 2); + if(rank == 5) + z_user_md.data.format_desc.blocking.strides[4] = output->strideAt(isNCHW ? 4 : 3); + } + + auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + + // operation primitive description + dnnl::pooling_forward::desc op_desc(dnnl::prop_kind::forward_inference, mode, x_mkl_md, z_mkl_md, strides, kernel, padding, padding_r); + dnnl::pooling_forward::primitive_desc op_prim_desc(op_desc, engine); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + dnnl::stream stream(engine); + + // provide memory buffers and check whether reorder is required + + // input + auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer()); + const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc(); + auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem; + if (xReorder) + dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); + args[DNNL_ARG_SRC] = x_mkl_mem; + + // output + auto z_user_mem = dnnl::memory(z_user_md, engine, output->getBuffer()); + const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc(); + auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem; + args[DNNL_ARG_DST] = z_mkl_mem; + + // run calculations + dnnl::pooling_forward(op_prim_desc).execute(stream, args); + + // reorder outputs if necessary + if (zReorder) + dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); + + stream.wait(); +} + +////////////////////////////////////////////////////////////////////// +void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI, + const int kD, const int kH, const int kW, + const int sD, const int sH, const int sW, + const int pD, const int pH, const int pW, + const int isNCHW, const dnnl::algorithm mode) { + + // unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any) for input + + const int rank = input->rankOf(); + + int bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; + dnnl::memory::dims strides, kernel, padding, padding_r, xDims, zDims; + dnnl::memory::format_tag xzFrmat; + + const auto type = dnnl::memory::data_type::f32; + + if(rank == 4) { // 2d + + ops::ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + + strides = { sH, sW }; + kernel = { kH, kW }; + padding = { pH, pW }; + padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW }; + xDims = {bS, iC, iH, iW}; + zDims = {bS, oC, oH, oW}; + + xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; + } + else { // 3d + + ops::ConvolutionUtils::getSizesAndIndexesConv3d(isNCHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH); + + strides = { sD, sH, sW }; + kernel = { kD, kH, kW }; + padding = { pD, pH, pW }; + padding_r = { (oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW }; + xDims = {bS, iC, iD, iH, iW}; + zDims = {bS, oC, oD, oH, oW}; + + xzFrmat = isNCHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc; + } + + // memory descriptors for arrays + + // input + dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, xzFrmat); + dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat); + if(input->ews() != 1 || input->ordering() != 'c') { + x_user_md.data.format_kind = dnnl_blocked; // overrides format + x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0); + x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(isNCHW ? 1 :-1); + x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(isNCHW ? 2 : 1); + x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(isNCHW ? 3 : 2); + if(rank == 5) + x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(isNCHW ? 4 : 3); + } + + // gradO + dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFrmat); + if(gradO->ews() != 1 || gradO->ordering() != 'c') { + gradO_user_md.data.format_kind = dnnl_blocked; // overrides format + gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0); + gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(isNCHW ? 1 :-1); + gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(isNCHW ? 2 : 1); + gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(isNCHW ? 3 : 2); + if(rank == 5) + gradO_user_md.data.format_desc.blocking.strides[4] = gradO->strideAt(isNCHW ? 4 : 3); + } + + // gradI + dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any); + dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFrmat); + if(gradI->ews() != 1 || gradI->ordering() != 'c') { + gradI_user_md.data.format_kind = dnnl_blocked; // overrides format + gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0); + gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(isNCHW ? 1 :-1); + gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(isNCHW ? 2 : 1); + gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(isNCHW ? 3 : 2); + if(rank == 5) + gradI_user_md.data.format_desc.blocking.strides[4] = gradI->strideAt(isNCHW ? 4 : 3); + } + + auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); + dnnl::stream stream(engine); + + // forward primitive description + dnnl::pooling_forward::desc op_ff_desc(dnnl::prop_kind::forward, mode, x_mkl_md, gradO_mkl_md, strides, kernel, padding, padding_r); + dnnl::pooling_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine); + + // backward primitive description + dnnl::pooling_backward::desc op_bp_desc(mode, gradI_mkl_md, gradO_mkl_md, strides, kernel, padding, padding_r); + dnnl::pooling_backward::primitive_desc op_bp_prim_desc(op_bp_desc, engine, op_ff_prim_desc); + + // arguments (memory buffers) necessary for calculations + std::unordered_map args; + + // gradO + auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer()); + const bool gradOReorder = op_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc(); + auto gradO_mkl_mem = gradOReorder ? dnnl::memory(op_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem; + if (gradOReorder) + dnnl::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem); + args[DNNL_ARG_DIFF_DST] = gradO_mkl_mem; + + // gradI + auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer()); + const bool gradIReorder = op_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc(); + auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem; + args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem; + + if(mode == algorithm::pooling_max) { + + // input + auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer()); + const bool xReorder = op_ff_prim_desc.src_desc() != x_user_mem.get_desc(); + auto x_mkl_mem = xReorder ? dnnl::memory(op_ff_prim_desc.src_desc(), engine) : x_user_mem; + if (xReorder) + dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); + args[DNNL_ARG_SRC] = x_mkl_mem; + + // z + auto z_mkl_mem = dnnl::memory(op_ff_prim_desc.dst_desc(), engine); + args[DNNL_ARG_DST] = z_mkl_mem; + + // auxiliary memory allocation + auto workspace = dnnl::memory(op_ff_prim_desc.workspace_desc(), engine); + args[DNNL_ARG_WORKSPACE] = workspace; + + // run forward calculations + dnnl::pooling_forward(op_ff_prim_desc).execute(stream, args); + } + + // run backward calculations + dnnl::pooling_backward(op_bp_prim_desc).execute(stream, args); + + + // reorder gradI if necessary + if (gradIReorder) + dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem); + + stream.wait(); +} + +////////////////////////////////////////////////////////////////////////// +void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst, + dnnl::memory::desc* lrn_src_md, dnnl::memory::desc* lrn_diff_src_md, dnnl::memory::desc* lrn_dst_md, + dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) { + const Nd4jLong* shape = src->getShapeInfo(); + long rank = shape[0]; + long dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one + long dim2 = axis >= 2 ? 1 : 2; + long dim3 = axis >= 3 ? 2 : 3; + dnnl::memory::dims lrn_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1}; + + auto type = dnnl::memory::data_type::f32; + auto format = axis == 1 ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; + auto supposed_to_be_any_format = format; // doesn't work with "any" + + if (src != nullptr && src->getBuffer() != nullptr && lrn_src_md != nullptr) { + *lrn_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format); + *user_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format); + user_src_md->data.format_kind = dnnl_blocked; + user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0]; + user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1]; + user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1; + user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1; + } + + if (diff_src != nullptr && diff_src->getBuffer() != nullptr && lrn_diff_src_md != nullptr) { + *lrn_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format); + *user_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format); + user_diff_src_md->data.format_kind = dnnl_blocked; + user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0]; + user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1]; + user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1; + user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1; + } + + if (dst != nullptr && dst->getBuffer() != nullptr && lrn_dst_md != nullptr) { + *lrn_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format); + *user_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, format); + user_dst_md->data.format_kind = dnnl_blocked; + user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0]; + user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1]; + user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1; + user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1; + } +} + +////////////////////////////////////////////////////////////////////////// +dnnl::engine& getEngine(void *ptr) { + auto eng = reinterpret_cast(ptr); + return *eng; +} + + +/* ////////////////////////////////////////////////////////////////////////// void getMKLDNNMemoryDescPool2d( int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW, @@ -307,104 +617,51 @@ void getMKLDNNMemoryDescConv3d( } }; - -// void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst, -// dnnl::memory::desc* batchnorm_src_md, dnnl::memory::desc* batchnorm_diff_src_md, dnnl::memory::desc* batchnorm_dst_md, -// dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) { -// const Nd4jLong* shape = src->getShapeInfo(); -// Nd4jLong rank = shape[0]; -// Nd4jLong dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one -// Nd4jLong dim2 = axis >= 2 ? 1 : 2; -// Nd4jLong dim3 = axis >= 3 ? 2 : 3; -// dnnl::memory::dims batchnorm_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1}; - -// auto type = dnnl::memory::data_type::f32; -// auto format = dnnl::memory::format_tag::nchw; -// auto supposed_to_be_any_format = dnnl::memory::format_tag::nChw8c; // doesn't work with "any" - -// if (src != nullptr && src->getBuffer() != nullptr && batchnorm_src_md != nullptr) { -// *batchnorm_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); -// *user_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format); -// user_src_md->data.format_kind = dnnl_blocked; // overrides format -// user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0]; -// user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1]; -// user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1; -// user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1; -// } - -// if (diff_src != nullptr && diff_src->getBuffer() != nullptr && batchnorm_diff_src_md != nullptr) { -// *batchnorm_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); -// *user_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format); -// user_diff_src_md->data.format_kind = dnnl_blocked; // overrides format -// user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0]; -// user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1]; -// user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1; -// user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1; -// } - -// if (dst != nullptr && dst->getBuffer() != nullptr && batchnorm_dst_md != nullptr) { -// *batchnorm_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); -// *user_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format); -// user_dst_md->data.format_kind = dnnl_blocked; // overrides format -// user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0]; -// user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1]; -// user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1; -// user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1; -// } -// }; - -////////////////////////////////////////////////////////////////////////// -void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst, - dnnl::memory::desc* lrn_src_md, dnnl::memory::desc* lrn_diff_src_md, dnnl::memory::desc* lrn_dst_md, - dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) { +void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst, + dnnl::memory::desc* batchnorm_src_md, dnnl::memory::desc* batchnorm_diff_src_md, dnnl::memory::desc* batchnorm_dst_md, + dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) { const Nd4jLong* shape = src->getShapeInfo(); - long rank = shape[0]; - long dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one - long dim2 = axis >= 2 ? 1 : 2; - long dim3 = axis >= 3 ? 2 : 3; - dnnl::memory::dims lrn_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1}; + Nd4jLong rank = shape[0]; + Nd4jLong dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one + Nd4jLong dim2 = axis >= 2 ? 1 : 2; + Nd4jLong dim3 = axis >= 3 ? 2 : 3; + dnnl::memory::dims batchnorm_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1}; auto type = dnnl::memory::data_type::f32; - auto format = axis == 1 ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc; - auto supposed_to_be_any_format = format; // doesn't work with "any" + auto format = dnnl::memory::format_tag::nchw; + auto supposed_to_be_any_format = dnnl::memory::format_tag::nChw8c; // doesn't work with "any" - if (src != nullptr && src->getBuffer() != nullptr && lrn_src_md != nullptr) { - *lrn_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format); - *user_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format); - user_src_md->data.format_kind = dnnl_blocked; + if (src != nullptr && src->getBuffer() != nullptr && batchnorm_src_md != nullptr) { + *batchnorm_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); + *user_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format); + user_src_md->data.format_kind = dnnl_blocked; // overrides format user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0]; user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1]; user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1; user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1; } - if (diff_src != nullptr && diff_src->getBuffer() != nullptr && lrn_diff_src_md != nullptr) { - *lrn_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format); - *user_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format); - user_diff_src_md->data.format_kind = dnnl_blocked; + if (diff_src != nullptr && diff_src->getBuffer() != nullptr && batchnorm_diff_src_md != nullptr) { + *batchnorm_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); + *user_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format); + user_diff_src_md->data.format_kind = dnnl_blocked; // overrides format user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0]; user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1]; user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1; user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1; } - if (dst != nullptr && dst->getBuffer() != nullptr && lrn_dst_md != nullptr) { - *lrn_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format); - *user_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, format); - user_dst_md->data.format_kind = dnnl_blocked; + if (dst != nullptr && dst->getBuffer() != nullptr && batchnorm_dst_md != nullptr) { + *batchnorm_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format); + *user_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format); + user_dst_md->data.format_kind = dnnl_blocked; // overrides format user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0]; user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1]; user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1; user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1; } -} - -////////////////////////////////////////////////////////////////////////// -dnnl::engine& getEngine(void *ptr) { - auto eng = reinterpret_cast(ptr); - return *eng; -} - +}; +*/ } } \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h index b55103a02..c8b34a6c0 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h +++ b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h @@ -16,6 +16,7 @@ // // @author saudet +// @author Yurii Shyrma (iuriish@yahoo.com) // #ifndef DEV_TESTS_MKLDNNUTILS_H @@ -81,17 +82,27 @@ namespace nd4j{ DECLARE_PLATFORM(deconv3d_bp, ENGINE_CPU); DECLARE_PLATFORM(depthwise_conv2d, ENGINE_CPU); - + DECLARE_PLATFORM(depthwise_conv2d_bp, ENGINE_CPU); } } namespace mkldnnUtils { + void poolingMKLDNN(const NDArray *input, NDArray *output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int isNCHW, const dnnl::algorithm mode); + + void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int isNCHW, const dnnl::algorithm mode); + + void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst, + dnnl::memory::desc* lrn_src_md, dnnl::memory::desc* lrn_diff_src_md, dnnl::memory::desc* lrn_dst_md, + dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis); + + dnnl::engine& getEngine(void *ptr); + /** * Utility methods for MKLDNN */ - void getMKLDNNMemoryDescConv2d( +/* void getMKLDNNMemoryDescConv2d( int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, const int paddingMode, bool isNCHW, int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src, const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst, @@ -130,12 +141,7 @@ namespace nd4j{ void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst, dnnl::memory::desc* batchnorm_src_md, dnnl::memory::desc* batchnorm_diff_src_md, dnnl::memory::desc* batchnorm_dst_md, dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis); - - void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst, - dnnl::memory::desc* lrn_src_md, dnnl::memory::desc* lrn_diff_src_md, dnnl::memory::desc* lrn_dst_md, - dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis); - - dnnl::engine& getEngine(void *ptr); +*/ } } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp index 15524a901..795a7da4d 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests1.cpp @@ -2031,121 +2031,6 @@ TEST_F(DeclarableOpsTests1, Sum1) { } */ -////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests1, Avgpool2d_test1) { - - auto x = NDArrayFactory::create_('c', {bS,iD,iH,iW}); - auto exp = NDArrayFactory::create('c',{bS,iD,oH,oW}); - // auto z('c',{bS,iD,oH,oW}); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - // variableSpace->putVariable(1, &z); - - auto block = new Context(1, variableSpace, false); - block->fillInputs({-1}); - std::vector* argI = block->getIArguments(); - *argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - - nd4j::ops::avgpool2d pooling; - Nd4jStatus status = pooling.execute(block); - ASSERT_EQ(ND4J_STATUS_OK, status); - - auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); - ASSERT_TRUE(exp.isSameShape(result)); - - - delete variableSpace; - delete block; -} - -////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests1, Avgpool2d_test2) { - const int bS = 2; - const int iD = 1; - const int iH = 28; - const int iW = 28; - const int kH = 5; - const int kW = 5; - const int sH = 1; - const int sW = 1; - const int pH = 0; - const int pW = 0; - const int dH = 1; - const int dW = 1; - const int oH = (iH - kH - (kH-1)*(dH-1) + 2*pH)/sH + 1; // output height - const int oW = (iW - kW - (kW-1)*(dW-1) + 2*pW)/sW + 1; // output width - - - auto x = NDArrayFactory::create_('c', {bS,iD,iH,iW}); - auto exp = NDArrayFactory::create('c',{bS,iD,oH,oW}); - // auto z('c',{bS,iD,oH,oW}); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - // variableSpace->putVariable(1, &z); - - auto block = new Context(1, variableSpace, false); - block->fillInputs({-1}); - std::vector* argI = block->getIArguments(); - *argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - - nd4j::ops::avgpool2d pooling; - Nd4jStatus status = pooling.execute(block); - ASSERT_EQ(ND4J_STATUS_OK, status); - - auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); - // result->printShapeInfo(); - ASSERT_TRUE(exp.isSameShape(result)); - - delete variableSpace; - delete block; -} - -////////////////////////////////////////////////////////////////////// -TEST_F(DeclarableOpsTests1, Avgpool2d_test3) { - const int bS = 2; - const int iD = 1; - const int iH = 28; - const int iW = 28; - const int kH = 5; - const int kW = 5; - const int sH = 1; - const int sW = 1; - const int pH = 0; - const int pW = 0; - const int dH = 1; - const int dW = 1; - const int oH = (int) nd4j::math::nd4j_ceil(iH * 1.f / sH); - const int oW = (int) nd4j::math::nd4j_ceil(iW * 1.f / sW); - - - auto x = NDArrayFactory::create_('c', {bS,iD,iH,iW}); - auto exp = NDArrayFactory::create('c',{bS,iD,oH,oW}); - // auto z('c',{bS,iD,oH,oW}); - - auto variableSpace = new VariableSpace(); - variableSpace->putVariable(-1, x); - // variableSpace->putVariable(1, &z); - - auto block = new Context(1, variableSpace, false); - block->fillInputs({-1}); - std::vector* argI = block->getIArguments(); - *argI = {kH,kW, sH,sW, pH,pW, dW,dH, 1, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; - - nd4j::ops::avgpool2d pooling; - Nd4jStatus status = pooling.execute(block); - ASSERT_EQ(ND4J_STATUS_OK, status); - - auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); - // result->printShapeInfo(); - ASSERT_TRUE(exp.isSameShape(result)); - - delete variableSpace; - delete block; -} - - ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests1, Pnormpool2d1) { diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp index 1e085d46c..f04d24395 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests4.cpp @@ -360,7 +360,6 @@ TEST_F(DeclarableOpsTests4, avgpool2d_12) { 917.5, 918.5, 919.5, 925. , 926. , 927. , 934. , 935. , 936. , 941.5, 942.5, 943.5, 992.5, 993.5, 994.5,1000. , 1001. , 1002. ,1009. , 1010. , 1011. ,1016.5, 1017.5, 1018.5, 1082.5, 1083.5, 1084.5,1090. , 1091. , 1092. ,1099. , 1100. , 1101. ,1106.5, 1107.5, 1108.5,1157.5, 1158.5, 1159.5,1165. , 1166. , 1167. ,1174. , 1175. , 1176. ,1181.5, 1182.5, 1183.5}); input.linspace(1.); - input.syncToDevice(); nd4j::ops::avgpool2d op; auto results = op.evaluate({&input}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 0, dataFormat}); @@ -377,6 +376,160 @@ TEST_F(DeclarableOpsTests4, avgpool2d_12) { delete results; } +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, avgpool2d_13) { + + const int bS = 2; // batch size + const int iD = 1; // input depth (number of picture channels, for example rgb=3) + const int iH = 28; // picture height in pixels + const int iW = 28; // picture width in pixels + const int kH = 5; // kernel height in pixels + const int kW = 5; // kernel width in pixels + const int sH = 1; // stride step in horizontal direction + const int sW = 1; // stride step in vertical direction + const int pH = 0; // padding height + const int pW = 0; // padding width + const int dH = 2; // dilation height + const int dW = 2; // dilation width + const int oH = (iH - kH - (kH-1)*(dH-1) + 2*pH)/sH + 1; // output height + const int oW = (iW - kW - (kW-1)*(dW-1) + 2*pW)/sW + 1; // output width + + auto x = NDArrayFactory::create_('c', {bS,iD,iH,iW}); + auto exp = NDArrayFactory::create('c',{bS,iD,oH,oW}); + // auto z('c',{bS,iD,oH,oW}); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + // variableSpace->putVariable(1, &z); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1}); + std::vector* argI = block->getIArguments(); + *argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + + nd4j::ops::avgpool2d pooling; + Nd4jStatus status = pooling.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + ASSERT_TRUE(exp.isSameShape(result)); + + + delete variableSpace; + delete block; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, avgpool2d_14) { + const int bS = 2; + const int iD = 1; + const int iH = 28; + const int iW = 28; + const int kH = 5; + const int kW = 5; + const int sH = 1; + const int sW = 1; + const int pH = 0; + const int pW = 0; + const int dH = 1; + const int dW = 1; + const int oH = (iH - kH - (kH-1)*(dH-1) + 2*pH)/sH + 1; // output height + const int oW = (iW - kW - (kW-1)*(dW-1) + 2*pW)/sW + 1; // output width + + + auto x = NDArrayFactory::create_('c', {bS,iD,iH,iW}); + auto exp = NDArrayFactory::create('c',{bS,iD,oH,oW}); + // auto z('c',{bS,iD,oH,oW}); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + // variableSpace->putVariable(1, &z); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1}); + std::vector* argI = block->getIArguments(); + *argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + + nd4j::ops::avgpool2d pooling; + Nd4jStatus status = pooling.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + // result->printShapeInfo(); + ASSERT_TRUE(exp.isSameShape(result)); + + delete variableSpace; + delete block; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, Avgpool2d_test15) { + const int bS = 2; + const int iD = 1; + const int iH = 28; + const int iW = 28; + const int kH = 5; + const int kW = 5; + const int sH = 1; + const int sW = 1; + const int pH = 0; + const int pW = 0; + const int dH = 1; + const int dW = 1; + const int oH = (int) nd4j::math::nd4j_ceil(iH * 1.f / sH); + const int oW = (int) nd4j::math::nd4j_ceil(iW * 1.f / sW); + + + auto x = NDArrayFactory::create_('c', {bS,iD,iH,iW}); + auto exp = NDArrayFactory::create('c',{bS,iD,oH,oW}); + // auto z('c',{bS,iD,oH,oW}); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + // variableSpace->putVariable(1, &z); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1}); + std::vector* argI = block->getIArguments(); + *argI = {kH,kW, sH,sW, pH,pW, dW,dH, 1, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + + nd4j::ops::avgpool2d pooling; + Nd4jStatus status = pooling.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + // result->printShapeInfo(); + ASSERT_TRUE(exp.isSameShape(result)); + + delete variableSpace; + delete block; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, avgpool2d_16) { + + int bS=2, iH=4,iW=4, iC=2, kH=2,kW=2, sH=2,sW=2, pH=0,pW=0, dH=1,dW=1; + int oH=2,oW=2; + int paddingMode = 1; // 1-SAME, 0-VALID + int dataFormat = 1; // 1-NHWC, 0-NDHW + + NDArray input('c', {bS, iH, iW, iC}, nd4j::DataType::FLOAT32); + NDArray output('f', {bS, oH, oW, iC}, nd4j::DataType::FLOAT32); + NDArray expected('c', {bS, oH, oW, iC}, {6.f, 7.f, 10.f, 11.f, 22.f, 23.f, 26.f, 27.f, 38.f, 39.f, 42.f, 43.f, 54.f, 55.f, 58.f, 59.f}, nd4j::DataType::FLOAT32); + + input.linspace(1.); + + nd4j::ops::avgpool2d op; + auto status = op.execute({&input}, {&output}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 0, dataFormat}, {}); + + ASSERT_EQ(Status::OK(), status); + + // output.printBuffer(); + //expected.printIndexedBuffer("expected"); + + ASSERT_TRUE(expected.equalsTo(output)); +} + ////////////////////////////////////////////////////////////////////// TEST_F(DeclarableOpsTests4, biasadd_1) { auto x = NDArrayFactory::create('c', {2, 3, 3, 2}); diff --git a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp index 970c119ca..4d7a0f783 100644 --- a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp @@ -422,50 +422,38 @@ TEST_F(PlaygroundTests, my) { delete variableSpace; } - -#include - TEST_F(PlaygroundTests, my) { - const int N = 10000; - const Nd4jLong dim0(128), dim1(128), dim2(128); + int N = 100; + int bS=16, iH=128,iW=128, iC=32,oC=64, kH=4,kW=4, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=128,oW=128; - NDArray input('c', {dim0,dim1,dim2}, nd4j::DataType::DOUBLE); - NDArray mean('c', {dim1}, nd4j::DataType::DOUBLE); - NDArray variance('c', {dim1}, nd4j::DataType::DOUBLE); - NDArray gamma('c', {dim1}, nd4j::DataType::DOUBLE); - NDArray beta ('c', {dim1}, nd4j::DataType::DOUBLE); + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW - NDArray output('c', {dim0,dim1,dim2}, nd4j::DataType::DOUBLE); + // NDArray input('c', {bS, iC, iH, iW}, nd4j::DataType::FLOAT32); + // NDArray output('c', {bS, oC, oH, oW}, nd4j::DataType::FLOAT32); + NDArray input('c', {bS, iH, iW, iC}, nd4j::DataType::FLOAT32); + NDArray output('c', {bS, oH, oW, oC}, nd4j::DataType::FLOAT32); + // NDArray weights('c', {kH, kW, iC, oC}, nd4j::DataType::FLOAT32); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW] + NDArray weights('c', {oC, iC, kH, kW}, nd4j::DataType::FLOAT32); + NDArray bias('c', {oC}, nd4j::DataType::FLOAT32); - input.linspace(-100, 0.1); - mean.linspace(-50, 0.15); - variance.linspace(-5, 0.2); - gamma = 1.5; - beta = -2.5; + input = 5.; + weights = 3.; + bias = 1.; - // warm up - ops::helpers::batchnorm(&input, &mean, &variance, &gamma, &beta, &output, {1}, 1e-5); + nd4j::ops::conv2d op; + auto err = op.execute({&input, &weights, &bias}, {&output}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); auto timeStart = std::chrono::system_clock::now(); for (int i = 0; i < N; ++i) - ops::helpers::batchnorm(&input, &mean, &variance, &gamma, &beta, &output, {1}, 1e-5); - + err = op.execute({&input, &weights, &bias}, {&output}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); auto timeEnd = std::chrono::system_clock::now(); - auto time = std::chrono::duration_cast ((timeEnd - timeStart)/N).count(); - - printf("time: %li \n", time); + auto time = std::chrono::duration_cast ((timeEnd - timeStart) / N).count(); + printf("time: %i \n", time); } */ - - - - - - - - - From ce6848c9fe19340916f1f2c7b0a296da4c2e6146 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Fri, 7 Feb 2020 16:25:02 +1100 Subject: [PATCH 6/9] Test fixes (#218) * Test speedups / integration test run only for CUDA - NLP Signed-off-by: AlexDBlack * nlp-uima CUDA slow tests Signed-off-by: AlexDBlack * Spark CUDA timeout fixes Signed-off-by: AlexDBlack --- .../models/WordVectorSerializerTest.java | 23 +++-- .../models/word2vec/Word2VecTests.java | 56 +++++++++++ .../java/org/deeplearning4j/TsneTest.java | 12 ++- .../ParagraphVectorsTest.java | 96 +++++++++++++------ .../models/word2vec/Word2VecTestsSmall.java | 14 ++- .../iterator/Word2VecDataSetIteratorTest.java | 4 +- ...TestSparkMultiLayerParameterAveraging.java | 55 +++++++---- 7 files changed, 197 insertions(+), 63 deletions(-) diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/WordVectorSerializerTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/WordVectorSerializerTest.java index 7807ff711..0154bc732 100755 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/WordVectorSerializerTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/WordVectorSerializerTest.java @@ -16,25 +16,19 @@ package org.deeplearning4j.models; -import org.junit.rules.Timeout; -import org.nd4j.shade.guava.io.Files; -import org.nd4j.shade.guava.primitives.Doubles; import lombok.val; import org.apache.commons.io.FileUtils; import org.apache.commons.lang.ArrayUtils; import org.apache.commons.lang3.RandomUtils; import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.models.sequencevectors.SequenceVectors; -import org.deeplearning4j.models.sequencevectors.serialization.VocabWordFactory; -import org.junit.Rule; -import org.junit.rules.TemporaryFolder; -import org.nd4j.linalg.io.ClassPathResource; import org.deeplearning4j.models.embeddings.WeightLookupTable; import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable; import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration; import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer; import org.deeplearning4j.models.embeddings.wordvectors.WordVectors; import org.deeplearning4j.models.paragraphvectors.ParagraphVectors; +import org.deeplearning4j.models.sequencevectors.SequenceVectors; +import org.deeplearning4j.models.sequencevectors.serialization.VocabWordFactory; import org.deeplearning4j.models.word2vec.VocabWord; import org.deeplearning4j.models.word2vec.Word2Vec; import org.deeplearning4j.models.word2vec.wordstore.VocabCache; @@ -48,11 +42,16 @@ import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFac import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory; import org.junit.Before; import org.junit.Ignore; +import org.junit.Rule; import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.rules.Timeout; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.io.ClassPathResource; import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.resources.Resources; +import org.nd4j.shade.guava.primitives.Doubles; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -272,7 +271,14 @@ public class WordVectorSerializerTest extends BaseDL4JTest { @Test public void testFullModelSerialization() throws Exception { + String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"); + if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) { + skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X + } + File inputFile = Resources.asFile("big/raw_sentences.txt"); + + SentenceIterator iter = UimaSentenceIterator.createWithPath(inputFile.getAbsolutePath()); // Split on white spaces in the line to get words TokenizerFactory t = new DefaultTokenizerFactory(); @@ -892,5 +898,4 @@ public class WordVectorSerializerTest extends BaseDL4JTest { fail(e.getMessage()); } } - } diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTests.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTests.java index e50a95443..7dcfb160a 100755 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTests.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp-uima/src/test/java/org/deeplearning4j/models/word2vec/Word2VecTests.java @@ -159,6 +159,11 @@ public class Word2VecTests extends BaseDL4JTest { @Test public void testWord2VecCBOW() throws Exception { + String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"); + if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) { + skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X + } + SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath()); TokenizerFactory t = new DefaultTokenizerFactory(); @@ -188,6 +193,11 @@ public class Word2VecTests extends BaseDL4JTest { @Test public void testWord2VecMultiEpoch() throws Exception { + String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"); + if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) { + skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X + } + SentenceIterator iter; if(isIntegrationTests()){ iter = new BasicLineIterator(inputFile.getAbsolutePath()); @@ -220,6 +230,11 @@ public class Word2VecTests extends BaseDL4JTest { @Test public void reproducibleResults_ForMultipleRuns() throws Exception { + String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"); + if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) { + skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X + } + log.info("reproducibleResults_ForMultipleRuns"); val shakespear = new ClassPathResource("big/rnj.txt"); val basic = new ClassPathResource("big/rnj.txt"); @@ -274,6 +289,11 @@ public class Word2VecTests extends BaseDL4JTest { @Test public void testRunWord2Vec() throws Exception { + String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"); + if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) { + skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X + } + // Strip white space before and after for each line /*val shakespear = new ClassPathResource("big/rnj.txt"); SentenceIterator iter = new BasicLineIterator(shakespear.getFile());*/ @@ -363,6 +383,11 @@ public class Word2VecTests extends BaseDL4JTest { @Test public void testLoadingWordVectors() throws Exception { + String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"); + if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) { + skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X + } + File modelFile = new File(pathToWriteto); if (!modelFile.exists()) { testRunWord2Vec(); @@ -396,6 +421,11 @@ public class Word2VecTests extends BaseDL4JTest { @Test public void testW2VnegativeOnRestore() throws Exception { + String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"); + if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) { + skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X + } + // Strip white space before and after for each line SentenceIterator iter; if(isIntegrationTests()){ @@ -453,6 +483,11 @@ public class Word2VecTests extends BaseDL4JTest { @Test public void testUnknown1() throws Exception { + String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"); + if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) { + skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X + } + // Strip white space before and after for each line SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath()); // Split on white spaces in the line to get words @@ -688,6 +723,10 @@ public class Word2VecTests extends BaseDL4JTest { @Test public void testWordVectorsPartiallyAbsentLabels() throws Exception { + String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"); + if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) { + skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X + } SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath()); // Split on white spaces in the line to get words @@ -720,6 +759,10 @@ public class Word2VecTests extends BaseDL4JTest { @Test public void testWordVectorsAbsentLabels() throws Exception { + String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"); + if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) { + skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X + } SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath()); // Split on white spaces in the line to get words @@ -745,6 +788,10 @@ public class Word2VecTests extends BaseDL4JTest { @Test public void testWordVectorsAbsentLabels_WithUnknown() throws Exception { + String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"); + if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) { + skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X + } SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath()); // Split on white spaces in the line to get words @@ -814,6 +861,10 @@ public class Word2VecTests extends BaseDL4JTest { @Test public void weightsNotUpdated_WhenLocked_CBOW() throws Exception { + String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"); + if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) { + skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X + } SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath()); @@ -851,6 +902,11 @@ public class Word2VecTests extends BaseDL4JTest { @Test public void testWordsNearestSum() throws IOException { + String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"); + if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) { + skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X + } + log.info("Load & Vectorize Sentences...."); SentenceIterator iter = new BasicLineIterator(inputFile); TokenizerFactory t = new DefaultTokenizerFactory(); diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/TsneTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/TsneTest.java index c99cb3b9a..cf0e7c7a3 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/TsneTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/TsneTest.java @@ -48,12 +48,22 @@ public class TsneTest extends BaseDL4JTest { @Override public long getTimeoutMilliseconds() { - return 60000L; + return 180000L; } @Rule public TemporaryFolder testDir = new TemporaryFolder(); + @Override + public DataType getDataType() { + return DataType.FLOAT; + } + + @Override + public DataType getDefaultFPDataType() { + return DataType.FLOAT; + } + @Test public void testSimple() throws Exception { //Simple sanity check diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectorsTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectorsTest.java index 95cd4e9a6..14495ffaf 100644 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectorsTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/models/paragraphvectors/ParagraphVectorsTest.java @@ -32,6 +32,7 @@ import org.deeplearning4j.models.sequencevectors.transformers.impl.iterables.Par import org.deeplearning4j.text.sentenceiterator.*; import org.junit.Rule; import org.junit.rules.TemporaryFolder; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.io.ClassPathResource; import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable; import org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram; @@ -80,12 +81,21 @@ public class ParagraphVectorsTest extends BaseDL4JTest { @Override public long getTimeoutMilliseconds() { - return 240000; + return isIntegrationTests() ? 600_000 : 240_000; } @Rule public TemporaryFolder testDir = new TemporaryFolder(); + @Override + public DataType getDataType() { + return DataType.FLOAT; + } + + @Override + public DataType getDefaultFPDataType() { + return DataType.FLOAT; + } /* @Test @@ -359,8 +369,13 @@ public class ParagraphVectorsTest extends BaseDL4JTest { } - @Test(timeout = 300000) + @Test public void testParagraphVectorsDM() throws Exception { + String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"); + if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) { + skipUnlessIntegrationTests(); //Skip CUDA except for integration tests due to very slow test speed + } + File file = Resources.asFile("/big/raw_sentences.txt"); SentenceIterator iter = new BasicLineIterator(file); @@ -372,10 +387,10 @@ public class ParagraphVectorsTest extends BaseDL4JTest { LabelsSource source = new LabelsSource("DOC_"); ParagraphVectors vec = new ParagraphVectors.Builder().minWordFrequency(1).iterations(2).seed(119).epochs(1) - .layerSize(100).learningRate(0.025).labelsSource(source).windowSize(5).iterate(iter) - .trainWordVectors(true).vocabCache(cache).tokenizerFactory(t).negativeSample(0) - .useHierarchicSoftmax(true).sampling(0).workers(1).usePreciseWeightInit(true) - .sequenceLearningAlgorithm(new DM()).build(); + .layerSize(100).learningRate(0.025).labelsSource(source).windowSize(5).iterate(iter) + .trainWordVectors(true).vocabCache(cache).tokenizerFactory(t).negativeSample(0) + .useHierarchicSoftmax(true).sampling(0).workers(1).usePreciseWeightInit(true) + .sequenceLearningAlgorithm(new DM()).build(); vec.fit(); @@ -404,7 +419,9 @@ public class ParagraphVectorsTest extends BaseDL4JTest { double similarityX = vec.similarity("DOC_3720", "DOC_9852"); log.info("3720/9852 similarity: " + similarityX); - assertTrue(similarityX < 0.5d); + if(isIntegrationTests()) { + assertTrue(similarityX < 0.5d); + } // testing DM inference now @@ -418,7 +435,6 @@ public class ParagraphVectorsTest extends BaseDL4JTest { log.info("Cos O/A: {}", cosAO1); log.info("Cos A/B: {}", cosAB1); - } @@ -501,6 +517,11 @@ public class ParagraphVectorsTest extends BaseDL4JTest { @Test(timeout = 300000) public void testParagraphVectorsWithWordVectorsModelling1() throws Exception { + String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"); + if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) { + skipUnlessIntegrationTests(); //Skip CUDA except for integration tests due to very slow test speed + } + File file = Resources.asFile("/big/raw_sentences.txt"); SentenceIterator iter = new BasicLineIterator(file); @@ -705,8 +726,12 @@ public class ParagraphVectorsTest extends BaseDL4JTest { In this test we'll build w2v model, and will use it's vocab and weights for ParagraphVectors. there's no need in this test within travis, use it manually only for problems detection */ - @Test(timeout = 300000) + @Test public void testParagraphVectorsOverExistingWordVectorsModel() throws Exception { + String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend"); + if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) { + skipUnlessIntegrationTests(); //Skip CUDA except for integration tests due to very slow test speed + } // we build w2v from multiple sources, to cover everything File resource_sentences = Resources.asFile("/big/raw_sentences.txt"); @@ -997,14 +1022,18 @@ public class ParagraphVectorsTest extends BaseDL4JTest { log.info("SimilarityB: {}", simB); } - @Test(timeout = 300000) + @Test + @Ignore //AB 2020/02/06 - https://github.com/eclipse/deeplearning4j/issues/8677 public void testDirectInference() throws Exception { - File resource_sentences = Resources.asFile("/big/raw_sentences.txt"); + boolean isIntegration = isIntegrationTests(); + File resource = Resources.asFile("/big/raw_sentences.txt"); + SentenceIterator sentencesIter = getIterator(isIntegration, resource); + ClassPathResource resource_mixed = new ClassPathResource("paravec/"); File local_resource_mixed = testDir.newFolder(); resource_mixed.copyDirectory(local_resource_mixed); SentenceIterator iter = new AggregatingSentenceIterator.Builder() - .addSentenceIterator(new BasicLineIterator(resource_sentences)) + .addSentenceIterator(sentencesIter) .addSentenceIterator(new FileSentenceIterator(local_resource_mixed)).build(); TokenizerFactory t = new DefaultTokenizerFactory(); @@ -1154,24 +1183,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest { public void testDoubleFit() throws Exception { boolean isIntegration = isIntegrationTests(); File resource = Resources.asFile("/big/raw_sentences.txt"); - SentenceIterator iter; - if(isIntegration){ - iter = new BasicLineIterator(resource); - } else { - List lines = new ArrayList<>(); - try(InputStream is = new BufferedInputStream(new FileInputStream(resource))){ - LineIterator lineIter = IOUtils.lineIterator(is, StandardCharsets.UTF_8); - try{ - for( int i=0; i<500 && lineIter.hasNext(); i++ ){ - lines.add(lineIter.next()); - } - } finally { - lineIter.close(); - } - } - - iter = new CollectionSentenceIterator(lines); - } + SentenceIterator iter = getIterator(isIntegration, resource); TokenizerFactory t = new DefaultTokenizerFactory(); @@ -1197,6 +1209,30 @@ public class ParagraphVectorsTest extends BaseDL4JTest { assertEquals(num1, num2); } + + public static SentenceIterator getIterator(boolean isIntegration, File file) throws IOException { + return getIterator(isIntegration, file, 500); + } + + public static SentenceIterator getIterator(boolean isIntegration, File file, int linesForUnitTest) throws IOException { + if(isIntegration){ + return new BasicLineIterator(file); + } else { + List lines = new ArrayList<>(); + try(InputStream is = new BufferedInputStream(new FileInputStream(file))){ + LineIterator lineIter = IOUtils.lineIterator(is, StandardCharsets.UTF_8); + try{ + for( int i=0; i data = MLUtils .loadLibSVMFile(sc.sc(), @@ -125,7 +142,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { } - @Test(timeout = 120000L) + @Test public void testFromSvmLight() throws Exception { JavaRDD data = MLUtils .loadLibSVMFile(sc.sc(), @@ -155,7 +172,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { master.fitLabeledPoint(data); } - @Test(timeout = 120000L) + @Test public void testRunIteration() { DataSet dataSet = new IrisDataSetIterator(5, 5).next(); @@ -175,7 +192,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { assertEquals(expectedParams.size(1), actualParams.size(1)); } - @Test(timeout = 120000L) + @Test public void testUpdaters() { SparkDl4jMultiLayer sparkNet = getBasicNetwork(); MultiLayerNetwork netCopy = sparkNet.getNetwork().clone(); @@ -197,7 +214,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { } - @Test(timeout = 120000L) + @Test public void testEvaluation() { SparkDl4jMultiLayer sparkNet = getBasicNetwork(); @@ -228,7 +245,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { } } - @Test(timeout = 120000L) + @Test public void testSmallAmountOfData() { //Idea: Test spark training where some executors don't get any data //in this case: by having fewer examples (2 DataSets) than executors (local[*]) @@ -255,7 +272,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { } - @Test(timeout = 120000L) + @Test public void testDistributedScoring() { MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().l1(0.1).l2(0.1) @@ -333,7 +350,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { - @Test(timeout = 120000L) + @Test public void testParameterAveragingMultipleExamplesPerDataSet() throws Exception { int dataSetObjSize = 5; int batchSizePerExecutor = 25; @@ -382,7 +399,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { } - @Test(timeout = 120000L) + @Test public void testFitViaStringPaths() throws Exception { Path tempDir = testDir.newFolder("DL4J-testFitViaStringPaths").toPath(); @@ -445,7 +462,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { sparkNet.getTrainingMaster().deleteTempFiles(sc); } - @Test(timeout = 120000L) + @Test public void testFitViaStringPathsSize1() throws Exception { Path tempDir = testDir.newFolder("DL4J-testFitViaStringPathsSize1").toPath(); @@ -525,7 +542,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { } - @Test(timeout = 120000L) + @Test public void testFitViaStringPathsCompGraph() throws Exception { Path tempDir = testDir.newFolder("DL4J-testFitViaStringPathsCG").toPath(); @@ -618,7 +635,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { } - @Test(timeout = 120000L) + @Test @Ignore("AB 2019/05/23 - Failing on CI only - passing locally. Possible precision or threading issue") public void testSeedRepeatability() throws Exception { @@ -691,7 +708,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { } - @Test(timeout = 120000L) + @Test public void testIterationCounts() throws Exception { int dataSetObjSize = 5; int batchSizePerExecutor = 25; @@ -737,7 +754,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { } } - @Test(timeout = 120000L) + @Test public void testIterationCountsGraph() throws Exception { int dataSetObjSize = 5; int batchSizePerExecutor = 25; @@ -783,7 +800,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { } - @Test(timeout = 120000L) @Ignore //Ignored 2019/04/09 - low priority: https://github.com/deeplearning4j/deeplearning4j/issues/6656 + @Test + @Ignore //Ignored 2019/04/09 - low priority: https://github.com/deeplearning4j/deeplearning4j/issues/6656 public void testVaePretrainSimple() { //Simple sanity check on pretraining int nIn = 8; @@ -818,7 +836,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { sparkNet.fit(data); } - @Test(timeout = 120000L) @Ignore //Ignored 2019/04/09 - low priority: https://github.com/deeplearning4j/deeplearning4j/issues/6656 + @Test + @Ignore //Ignored 2019/04/09 - low priority: https://github.com/deeplearning4j/deeplearning4j/issues/6656 public void testVaePretrainSimpleCG() { //Simple sanity check on pretraining int nIn = 8; @@ -854,7 +873,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { } - @Test(timeout = 120000L) + @Test public void testROC() { int nArrays = 100; @@ -909,7 +928,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest { } - @Test(timeout = 120000L) + @Test public void testROCMultiClass() { int nArrays = 100; From 937a27ae27f1ae66c528d3e2aeeef236eb656bac Mon Sep 17 00:00:00 2001 From: Alex Black Date: Fri, 7 Feb 2020 18:38:50 +1100 Subject: [PATCH 7/9] Small number of test fixes (#220) * Workaround for apache solr RNG algorithm enforcement in tests Signed-off-by: AlexDBlack * DeepWalk gradient check spam reduction Signed-off-by: AlexDBlack * Timeout increase Signed-off-by: AlexDBlack --- .../io/stream/TupleStreamDataSetIteratorTest.java | 14 ++++++++++++++ .../models/deepwalk/DeepWalkGradientCheck.java | 14 +++++++------- .../handler/ModelTupleStreamIntegrationTest.java | 14 ++++++++++++++ .../solr/handler/ModelTupleStreamTest.java | 13 +++++++++++++ .../java/org/deeplearning4j/zoo/MiscTests.java | 2 +- 5 files changed, 49 insertions(+), 8 deletions(-) diff --git a/deeplearning4j/deeplearning4j-dataimport-solrj/src/test/java/org/deeplearning4j/nn/dataimport/solr/client/solrj/io/stream/TupleStreamDataSetIteratorTest.java b/deeplearning4j/deeplearning4j-dataimport-solrj/src/test/java/org/deeplearning4j/nn/dataimport/solr/client/solrj/io/stream/TupleStreamDataSetIteratorTest.java index a35c612b1..b0757e67f 100644 --- a/deeplearning4j/deeplearning4j-dataimport-solrj/src/test/java/org/deeplearning4j/nn/dataimport/solr/client/solrj/io/stream/TupleStreamDataSetIteratorTest.java +++ b/deeplearning4j/deeplearning4j-dataimport-solrj/src/test/java/org/deeplearning4j/nn/dataimport/solr/client/solrj/io/stream/TupleStreamDataSetIteratorTest.java @@ -18,6 +18,8 @@ package org.deeplearning4j.nn.dataimport.solr.client.solrj.io.stream; import com.carrotsearch.randomizedtesting.annotations.ThreadLeakFilters; import com.carrotsearch.randomizedtesting.ThreadFilter; + +import java.security.SecureRandom; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -44,6 +46,18 @@ import org.nd4j.rng.deallocator.NativeRandomDeallocator; }) public class TupleStreamDataSetIteratorTest extends SolrCloudTestCase { + static { + /* + This is a hack around the backend-dependent nature of secure random implementations + though we can set the secure random algorithm in our pom.xml files (via maven surefire and test.solr.allowed.securerandom) + there isn't a mechanism that is completely platform independent. + By setting it there (for example, to NativePRNG) that makes it pass on some platforms like Linux but fails on some JVMs on Windows + For testing purposes, we don't need strict guarantees around RNG, hence we don't want to enforce the RNG algorithm + */ + String algorithm = new SecureRandom().getAlgorithm(); + System.setProperty("test.solr.allowed.securerandom", algorithm); + } + public static class PrivateDeallocatorThreadsFilter implements ThreadFilter { /** * Reject deallocator threads over whose cleanup this test has no control. diff --git a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/DeepWalkGradientCheck.java b/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/DeepWalkGradientCheck.java index 39e91921a..c1aedd47a 100644 --- a/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/DeepWalkGradientCheck.java +++ b/deeplearning4j/deeplearning4j-graph/src/test/java/org/deeplearning4j/graph/models/deepwalk/DeepWalkGradientCheck.java @@ -66,7 +66,7 @@ public class DeepWalkGradientCheck extends BaseDL4JTest { for (int i = 0; i < 7; i++) { INDArray vector = deepWalk.getVertexVector(i); assertArrayEquals(new long[] {vectorSize}, vector.shape()); - System.out.println(Arrays.toString(vector.dup().data().asFloat())); +// System.out.println(Arrays.toString(vector.dup().data().asFloat())); } GraphWalkIterator iter = new RandomWalkIterator<>(graph, 8); @@ -182,10 +182,10 @@ public class DeepWalkGradientCheck extends BaseDL4JTest { if (relError > MAX_REL_ERROR && absErr > MIN_ABS_ERROR) fail(msg); - else - System.out.println(msg); +// else +// System.out.println(msg); } - System.out.println(); +// System.out.println(); } } @@ -216,7 +216,7 @@ public class DeepWalkGradientCheck extends BaseDL4JTest { for (int i = 0; i < nVertices; i++) { INDArray vector = deepWalk.getVertexVector(i); assertArrayEquals(new long[] {vectorSize}, vector.shape()); - System.out.println(Arrays.toString(vector.dup().data().asFloat())); +// System.out.println(Arrays.toString(vector.dup().data().asFloat())); } GraphWalkIterator iter = new RandomWalkIterator<>(graph, 10); @@ -295,8 +295,8 @@ public class DeepWalkGradientCheck extends BaseDL4JTest { if (relError > MAX_REL_ERROR && absErr > minAbsError) fail(msg); - else - System.out.println(msg); +// else +// System.out.println(msg); } } diff --git a/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/handler/ModelTupleStreamIntegrationTest.java b/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/handler/ModelTupleStreamIntegrationTest.java index 899e3f8fd..31633889a 100644 --- a/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/handler/ModelTupleStreamIntegrationTest.java +++ b/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/handler/ModelTupleStreamIntegrationTest.java @@ -18,6 +18,7 @@ package org.deeplearning4j.nn.modelexport.solr.handler; import java.io.File; import java.nio.file.Path; +import java.security.SecureRandom; import com.carrotsearch.randomizedtesting.ThreadFilter; import com.carrotsearch.randomizedtesting.annotations.ThreadLeakFilters; @@ -49,6 +50,19 @@ import org.nd4j.rng.deallocator.NativeRandomDeallocator; }) public class ModelTupleStreamIntegrationTest extends SolrCloudTestCase { + static { + /* + This is a hack around the backend-dependent nature of secure random implementations + though we can set the secure random algorithm in our pom.xml files (via maven surefire and test.solr.allowed.securerandom) + there isn't a mechanism that is completely platform independent. + By setting it there (for example, to NativePRNG) that makes it pass on some platforms like Linux but fails on some JVMs on Windows + For testing purposes, we don't need strict guarantees around RNG, hence we don't want to enforce the RNG algorithm + */ + String algorithm = new SecureRandom().getAlgorithm(); + System.setProperty("test.solr.allowed.securerandom", algorithm); + } + + public static class PrivateDeallocatorThreadsFilter implements ThreadFilter { /** * Reject deallocator threads over whose cleanup this test has no control. diff --git a/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/handler/ModelTupleStreamTest.java b/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/handler/ModelTupleStreamTest.java index 1cafc143d..80073677f 100644 --- a/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/handler/ModelTupleStreamTest.java +++ b/deeplearning4j/deeplearning4j-modelexport-solr/src/test/java/org/deeplearning4j/nn/modelexport/solr/handler/ModelTupleStreamTest.java @@ -19,6 +19,7 @@ package org.deeplearning4j.nn.modelexport.solr.handler; import java.io.File; import java.nio.file.Files; import java.nio.file.Path; +import java.security.SecureRandom; import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -58,6 +59,18 @@ import org.nd4j.linalg.lossfunctions.LossFunctions; public class ModelTupleStreamTest { + static { + /* + This is a hack around the backend-dependent nature of secure random implementations + though we can set the secure random algorithm in our pom.xml files (via maven surefire and test.solr.allowed.securerandom) + there isn't a mechanism that is completely platform independent. + By setting it there (for example, to NativePRNG) that makes it pass on some platforms like Linux but fails on some JVMs on Windows + For testing purposes, we don't need strict guarantees around RNG, hence we don't want to enforce the RNG algorithm + */ + String algorithm = new SecureRandom().getAlgorithm(); + System.setProperty("test.solr.allowed.securerandom", algorithm); + } + protected List floatsList(int numFloats) { final List floatsList = new ArrayList(); final float[] floats0 = new float[numFloats]; diff --git a/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/MiscTests.java b/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/MiscTests.java index c92f5acdc..9dea6629a 100644 --- a/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/MiscTests.java +++ b/deeplearning4j/deeplearning4j-zoo/src/test/java/org/deeplearning4j/zoo/MiscTests.java @@ -34,7 +34,7 @@ public class MiscTests extends BaseDL4JTest { @Override public long getTimeoutMilliseconds() { - return 120000L; + return 240000L; } @Test From a0da5a9e47ecb0e73874efe23effeb81aa68cb06 Mon Sep 17 00:00:00 2001 From: raver119 Date: Fri, 7 Feb 2020 12:34:55 +0300 Subject: [PATCH 8/9] Events removed from Java (#219) * replace mutex with lock_guards Signed-off-by: raver119 * Events ditched from Java CUDA logic Signed-off-by: raver119 --- .../include/helpers/cuda/ConstantHelper.cu | 8 ++--- .../helpers/cuda/ConstantShapeHelper.cu | 24 +++----------- .../include/helpers/cuda/ConstantTadHelper.cu | 4 +-- .../jita/allocator/impl/AtomicAllocator.java | 4 +-- .../nd4j/jita/concurrency/EventsProvider.java | 4 +-- .../flow/impl/SynchronousFlowController.java | 22 +++---------- .../jcublas/buffer/BaseCudaDataBuffer.java | 1 - .../java/org/nd4j/nativeblas/Nd4jCpu.java | 32 +++++++++++++++++++ 8 files changed, 49 insertions(+), 50 deletions(-) diff --git a/libnd4j/include/helpers/cuda/ConstantHelper.cu b/libnd4j/include/helpers/cuda/ConstantHelper.cu index 6c8eaa21d..47e276f4a 100644 --- a/libnd4j/include/helpers/cuda/ConstantHelper.cu +++ b/libnd4j/include/helpers/cuda/ConstantHelper.cu @@ -92,7 +92,7 @@ namespace nd4j { } void* ConstantHelper::replicatePointer(void *src, size_t numBytes, memory::Workspace *workspace) { - _mutex.lock(); + std::lock_guard lock(_mutex); auto deviceId = getCurrentDevice(); Nd4jPointer constantPtr = nullptr; @@ -116,7 +116,6 @@ namespace nd4j { if (res != 0) throw cuda_exception::build("cudaMemcpy failed", res); - _mutex.unlock(); return ptr; } else { auto originalBytes = numBytes; @@ -130,7 +129,6 @@ namespace nd4j { if (res != 0) throw cuda_exception::build("cudaMemcpyToSymbol failed", res); - _mutex.unlock(); return reinterpret_cast(constantPtr) + constantOffset; } } @@ -152,7 +150,7 @@ namespace nd4j { ConstantDataBuffer* result; // access to this holder instance is synchronous - holder->mutex()->lock(); + std::lock_guard lock(*holder->mutex()); if (holder->hasBuffer(dataType)) { result = holder->getConstantDataBuffer(dataType); @@ -175,8 +173,6 @@ namespace nd4j { holder->addBuffer(dataBuffer, dataType); result = holder->getConstantDataBuffer(dataType); } - // release holder lock - holder->mutex()->unlock(); return result; } diff --git a/libnd4j/include/helpers/cuda/ConstantShapeHelper.cu b/libnd4j/include/helpers/cuda/ConstantShapeHelper.cu index aae62594c..4f7a4a485 100644 --- a/libnd4j/include/helpers/cuda/ConstantShapeHelper.cu +++ b/libnd4j/include/helpers/cuda/ConstantShapeHelper.cu @@ -57,7 +57,7 @@ namespace nd4j { ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) { int deviceId = AffinityManager::currentDeviceId(); - _mutex.lock(); + std::lock_guard lock(_mutex); if (_cache[deviceId].count(descriptor) == 0) { auto hPtr = descriptor.toShapeInfo(); @@ -65,15 +65,9 @@ namespace nd4j { ConstantDataBuffer buffer(hPtr, dPtr, shape::shapeInfoLength(hPtr) * sizeof(Nd4jLong), DataType::INT64); ShapeDescriptor descriptor1(descriptor); _cache[deviceId][descriptor1] = buffer; - auto r = _cache[deviceId][descriptor1]; - _mutex.unlock(); - - return r; + return _cache[deviceId][descriptor1]; } else { - ConstantDataBuffer r = _cache[deviceId].at(descriptor); - _mutex.unlock(); - - return r; + return _cache[deviceId].at(descriptor); } } @@ -83,18 +77,10 @@ namespace nd4j { } bool ConstantShapeHelper::checkBufferExistenceForShapeInfo(ShapeDescriptor &descriptor) { - bool result; auto deviceId = AffinityManager::currentDeviceId(); - _mutex.lock(); + std::lock_guard lock(_mutex); - if (_cache[deviceId].count(descriptor) == 0) - result = false; - else - result = true; - - _mutex.unlock(); - - return result; + return _cache[deviceId].count(descriptor) != 0; } Nd4jLong* ConstantShapeHelper::createShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape) { diff --git a/libnd4j/include/helpers/cuda/ConstantTadHelper.cu b/libnd4j/include/helpers/cuda/ConstantTadHelper.cu index 8ea4067f3..747e295e2 100644 --- a/libnd4j/include/helpers/cuda/ConstantTadHelper.cu +++ b/libnd4j/include/helpers/cuda/ConstantTadHelper.cu @@ -64,7 +64,7 @@ namespace nd4j { TadPack ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) { const int deviceId = AffinityManager::currentDeviceId(); - _mutex.lock(); + std::lock_guard lock(_mutex); if (_cache[deviceId].count(descriptor) == 0) { const auto shapeInfo = descriptor.originalShape().toShapeInfo(); @@ -97,14 +97,12 @@ namespace nd4j { _cache[deviceId][descriptor] = t; TadPack r = _cache[deviceId][descriptor]; - _mutex.unlock(); delete[] shapeInfo; return r; } else { TadPack r = _cache[deviceId][descriptor]; - _mutex.unlock(); return r; } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java index aaccf9a34..46964c8f4 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/allocator/impl/AtomicAllocator.java @@ -469,8 +469,8 @@ public class AtomicAllocator implements Allocator { memoryHandler.purgeZeroObject(bucketId, objectId, point, copyback); - getFlowController().getEventsProvider().storeEvent(point.getLastWriteEvent()); - getFlowController().getEventsProvider().storeEvent(point.getLastReadEvent()); + //getFlowController().getEventsProvider().storeEvent(point.getLastWriteEvent()); + //getFlowController().getEventsProvider().storeEvent(point.getLastReadEvent()); } /** diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/EventsProvider.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/EventsProvider.java index a7412fd76..7cc3e6838 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/EventsProvider.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/concurrency/EventsProvider.java @@ -26,11 +26,11 @@ import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.atomic.AtomicLong; /** + * * @author raver119@gmail.com */ +@Deprecated public class EventsProvider { - //private static final EventsProvider INSTANCE = new EventsProvider(); - private List> queue = new ArrayList<>(); private AtomicLong newCounter = new AtomicLong(0); private AtomicLong cacheCounter = new AtomicLong(0); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/flow/impl/SynchronousFlowController.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/flow/impl/SynchronousFlowController.java index f5f68ea76..030ccad30 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/flow/impl/SynchronousFlowController.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/jita/flow/impl/SynchronousFlowController.java @@ -72,12 +72,7 @@ public class SynchronousFlowController implements FlowController { @Override public void waitTillFinished(AllocationPoint point) { - /*CudaContext context = point.getCurrentContext(); //(CudaContext) allocator.getDeviceContext().getContext(); - if (context == null) - context = (CudaContext) allocator.getDeviceContext().getContext(); - context.syncOldStream(); - */ - + // this should be always null, since synchronization happens in C++ now if (point.getLastWriteEvent() != null) { point.getLastWriteEvent().synchronize(); } @@ -181,8 +176,8 @@ public class SynchronousFlowController implements FlowController { @Override public void registerAction(CudaContext context, AllocationPoint result, AllocationPoint... operands) { - - + // this method is irrelevant now, everything happens in C++ now + /* eventsProvider.storeEvent(result.getLastWriteEvent()); result.setLastWriteEvent(eventsProvider.getEvent()); result.getLastWriteEvent().register(context.getOldStream()); @@ -194,6 +189,7 @@ public class SynchronousFlowController implements FlowController { operand.getLastReadEvent().register(context.getOldStream()); } // context.syncOldStream(); + */ } @Override @@ -204,9 +200,6 @@ public class SynchronousFlowController implements FlowController { val pointOperand = allocator.getAllocationPoint(operand); pointOperand.tickDeviceWrite(); - eventsProvider.storeEvent(pointOperand.getLastWriteEvent()); - pointOperand.setLastWriteEvent(eventsProvider.getEvent()); - pointOperand.getLastWriteEvent().register(context.getOldStream()); } } @@ -216,18 +209,13 @@ public class SynchronousFlowController implements FlowController { val point = allocator.getAllocationPoint(result); point.tickDeviceWrite(); - eventsProvider.storeEvent(point.getLastWriteEvent()); - point.setLastWriteEvent(eventsProvider.getEvent()); - point.getLastWriteEvent().register(context.getOldStream()); for (INDArray operand : operands) { if (operand == null || operand.isEmpty()) continue; val pointOperand = allocator.getAllocationPoint(operand); - eventsProvider.storeEvent(pointOperand.getLastReadEvent()); - pointOperand.setLastReadEvent(eventsProvider.getEvent()); - pointOperand.getLastReadEvent().register(context.getOldStream()); + pointOperand.tickDeviceRead(); } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java index 02b857f7f..cdec4e1be 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java @@ -307,7 +307,6 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda if (allocationPoint.getHostPointer() == null) { val location = allocationPoint.getAllocationStatus(); if (parentWorkspace == null) { - //log.info("dbAllocate step"); // let cpp allocate primary buffer NativeOpsHolder.getInstance().getDeviceNativeOps().dbAllocatePrimaryBuffer(ptrDataBuffer); } else { diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index b954a4a34..49d088f27 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -19177,6 +19177,38 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); } // #endif + /** + * solve op. - solve systems of linear equations - general method. + * + * input params: + * 0 - the tensor with dimension (x * y * z * ::: * M * M) - left parts of equations + * 1 - the tensor with dimension (x * y * z * ::: * M * K) - right parts of equations + * + * boolean args: + * 0 - adjoint - default is false (optional) - indicate input matrix or its adjoint (hermitian addition) should be used + * + * return value: + * tensor with dimension (x * y * z * ::: * M * K) with solutions + * + */ +// #if NOT_EXCLUDED(OP_solve) + @Namespace("nd4j::ops") public static class solve extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public solve(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public solve(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public solve position(long position) { + return (solve)super.position(position); + } + + public solve() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif + /** * lu op. - make LUP decomposition of given batch of 2D square matricies * From 1dfac9a7362bdb787ab8a90781f5ad13db65d9b1 Mon Sep 17 00:00:00 2001 From: raver119 Date: Fri, 7 Feb 2020 18:16:11 +0300 Subject: [PATCH 9/9] DataBuffer.write() tweak (#221) * special workaround methods for DataBuffer.write Signed-off-by: raver119 * one test removed Signed-off-by: raver119 * more of unsynced Signed-off-by: raver119 * missing asLong for BaseCudaDataBuffer Signed-off-by: raver119 --- .../linalg/api/buffer/BaseDataBuffer.java | 31 +++++++++++-------- .../compression/CompressedDataBuffer.java | 20 ++++++++++++ .../jcublas/buffer/BaseCudaDataBuffer.java | 27 ++++++++++++++++ .../nativecpu/buffer/BaseCpuDataBuffer.java | 20 ++++++++++++ .../test/java/org/nd4j/linalg/Nd4jTestsC.java | 25 --------------- 5 files changed, 85 insertions(+), 38 deletions(-) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java index 78b12e7fc..12e27e1c2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java @@ -780,7 +780,7 @@ public abstract class BaseDataBuffer implements DataBuffer { throw new IllegalArgumentException("Unable to create array of length " + length); float[] ret = new float[(int) length]; for (int i = 0; i < length; i++) - ret[i] = getFloat(i); + ret[i] = getFloatUnsynced(i); return ret; } @@ -790,7 +790,7 @@ public abstract class BaseDataBuffer implements DataBuffer { throw new IllegalArgumentException("Unable to create array of length " + length); double[] ret = new double[(int) length]; for (int i = 0; i < length; i++) - ret[i] = getDouble(i); + ret[i] = getDoubleUnsynced(i); return ret; } @@ -800,7 +800,7 @@ public abstract class BaseDataBuffer implements DataBuffer { throw new IllegalArgumentException("Unable to create array of length " + length); int[] ret = new int[(int) length]; for (int i = 0; i < length; i++) - ret[i] = getInt(i); + ret[i] = getIntUnsynced(i); return ret; } @@ -810,7 +810,7 @@ public abstract class BaseDataBuffer implements DataBuffer { throw new IllegalArgumentException("Unable to create array of length " + length); long[] ret = new long[(int) length]; for (int i = 0; i < length; i++) - ret[i] = getLong(i); + ret[i] = getLongUnsynced(i); return ret; } @@ -1662,6 +1662,11 @@ public abstract class BaseDataBuffer implements DataBuffer { } + protected abstract double getDoubleUnsynced(long index); + protected abstract float getFloatUnsynced(long index); + protected abstract long getLongUnsynced(long index); + protected abstract int getIntUnsynced(long index); + @Override public void write(DataOutputStream out) throws IOException { out.writeUTF(allocationMode.name()); @@ -1670,43 +1675,43 @@ public abstract class BaseDataBuffer implements DataBuffer { switch (dataType()) { case DOUBLE: for (long i = 0; i < length(); i++) - out.writeDouble(getDouble(i)); + out.writeDouble(getDoubleUnsynced(i)); break; case UINT64: case LONG: for (long i = 0; i < length(); i++) - out.writeLong(getLong(i)); + out.writeLong(getLongUnsynced(i)); break; case UINT32: case INT: for (long i = 0; i < length(); i++) - out.writeInt(getInt(i)); + out.writeInt(getIntUnsynced(i)); break; case UINT16: case SHORT: for (long i = 0; i < length(); i++) - out.writeShort((short) getInt(i)); + out.writeShort((short) getIntUnsynced(i)); break; case UBYTE: case BYTE: for (long i = 0; i < length(); i++) - out.writeByte((byte) getInt(i)); + out.writeByte((byte) getIntUnsynced(i)); break; case BOOL: for (long i = 0; i < length(); i++) - out.writeByte(getInt(i) == 0 ? (byte) 0 : (byte) 1); + out.writeByte(getIntUnsynced(i) == 0 ? (byte) 0 : (byte) 1); break; case BFLOAT16: for (long i = 0; i < length(); i++) - out.writeShort((short) Bfloat16Indexer.fromFloat(getFloat(i))); + out.writeShort((short) Bfloat16Indexer.fromFloat(getFloatUnsynced(i))); break; case HALF: for (long i = 0; i < length(); i++) - out.writeShort((short) HalfIndexer.fromFloat(getFloat(i))); + out.writeShort((short) HalfIndexer.fromFloat(getFloatUnsynced(i))); break; case FLOAT: for (long i = 0; i < length(); i++) - out.writeFloat(getFloat(i)); + out.writeFloat(getFloatUnsynced(i)); break; } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/compression/CompressedDataBuffer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/compression/CompressedDataBuffer.java index 0c822ce0a..f1c9ed6d9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/compression/CompressedDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/compression/CompressedDataBuffer.java @@ -210,4 +210,24 @@ public class CompressedDataBuffer extends BaseDataBuffer { public DataBuffer reallocate(long length) { throw new UnsupportedOperationException("This method isn't supported by CompressedDataBuffer"); } + + @Override + protected double getDoubleUnsynced(long index) { + return super.getDouble(index); + } + + @Override + protected float getFloatUnsynced(long index) { + return super.getFloat(index); + } + + @Override + protected long getLongUnsynced(long index) { + return super.getLong(index); + } + + @Override + protected int getIntUnsynced(long index) { + return super.getInt(index); + } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java index cdec4e1be..2f1cab334 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/buffer/BaseCudaDataBuffer.java @@ -1287,6 +1287,26 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda @Override public void destroy() {} + @Override + protected double getDoubleUnsynced(long index) { + return super.getDouble(index); + } + + @Override + protected float getFloatUnsynced(long index) { + return super.getFloat(index); + } + + @Override + protected long getLongUnsynced(long index) { + return super.getLong(index); + } + + @Override + protected int getIntUnsynced(long index) { + return super.getInt(index); + } + @Override public void write(DataOutputStream out) throws IOException { lazyAllocateHostPointer(); @@ -1510,6 +1530,13 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda return super.asInt(); } + @Override + public long[] asLong() { + lazyAllocateHostPointer(); + allocator.synchronizeHostData(this); + return super.asLong(); + } + @Override public ByteBuffer asNio() { lazyAllocateHostPointer(); diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BaseCpuDataBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BaseCpuDataBuffer.java index a51666f78..a5ddc7aef 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BaseCpuDataBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/BaseCpuDataBuffer.java @@ -208,6 +208,26 @@ public abstract class BaseCpuDataBuffer extends BaseDataBuffer implements Deallo Pointer.memcpy(ptr, temp, length * Nd4j.sizeOfDataType(dtype)); } + @Override + protected double getDoubleUnsynced(long index) { + return super.getDouble(index); + } + + @Override + protected float getFloatUnsynced(long index) { + return super.getFloat(index); + } + + @Override + protected long getLongUnsynced(long index) { + return super.getLong(index); + } + + @Override + protected int getIntUnsynced(long index) { + return super.getInt(index); + } + @Override public void pointerIndexerByCurrentType(DataType currentType) { diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index ad5bacc4e..d96c0ed31 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -8262,31 +8262,6 @@ public class Nd4jTestsC extends BaseNd4jTest { assertArrayEquals(new long[]{10, 0}, out2.shape()); } - @Test - public void testDealloc_1() throws Exception { - - for (int e = 0; e < 5000; e++){ - try(val ws = Nd4j.getWorkspaceManager().getAndActivateWorkspace("someid")) { - val x = Nd4j.createUninitialized(DataType.FLOAT, 1, 1000); - //val y = x.get(NDArrayIndex.point(0), NDArrayIndex.interval(0, 100)).reshape('c', 10, 10); - //val z = x.get(NDArrayIndex.point(0), NDArrayIndex.interval(100, 200)).reshape('c', 10, 10); - //val a = x.get(NDArrayIndex.point(0), NDArrayIndex.interval(200, 300)).reshape('f', 10, 10); - } finally { - //System.gc(); - } - } - - Thread.sleep(1000); - System.gc(); - - Thread.sleep(1000); - System.gc(); - System.gc(); - System.gc(); - - //Nd4j.getMemoryManager().printRemainingStacks(); - } - @Override public char ordering() { return 'c';