commit 77401c7d33
@@ -192,7 +192,7 @@
                 <artifactId>maven-surefire-plugin</artifactId>
                 <version>${maven-surefire-plugin.version}</version>
                 <configuration>
-                    <argLine>-Ddtype=double -Xmx3024m -Xms3024m</argLine>
+                    <argLine>-Ddtype=double -Dfile.encoding=UTF-8 -Xmx3024m -Xms3024m</argLine>
                     <!--
                         By default: Surefire will set the classpath based on the manifest. Because tests are not included
                         in the JAR, any tests that rely on class path scanning for resources in the tests directory will not
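The only change here is the addition of -Dfile.encoding=UTF-8 to the Surefire argLine, which pins the default charset of the forked test JVM instead of inheriting the platform locale, so string/byte conversions in tests behave the same on every build machine. A quick way to observe the effect (standard Java, not part of the commit):

    import java.nio.charset.Charset;

    public class EncodingCheck {
        public static void main(String[] args) {
            // With -Dfile.encoding=UTF-8 on the argLine, this prints UTF-8
            // regardless of the host platform's default locale.
            System.out.println(Charset.defaultCharset());
        }
    }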
@@ -45,7 +45,7 @@ public class TypeConversion {
     }
 
     public int convertInt(String o) {
-        return Integer.parseInt(o);
+        return (int) Double.parseDouble(o);
     }
 
     public double convertDouble(Writable writable) {
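The switch from Integer.parseInt to a cast of Double.parseDouble is a behavioral change worth noting: the new code also accepts decimal notation and truncates toward zero. A minimal illustration (standard Java behavior, not from the commit):

    public class ConvertIntDemo {
        public static void main(String[] args) {
            // Old behavior: only pure integer strings parse;
            // Integer.parseInt("2.0") throws NumberFormatException.

            // New behavior: decimal strings parse, truncated toward zero.
            System.out.println((int) Double.parseDouble("2.0"));  // 2
            System.out.println((int) Double.parseDouble("-3.7")); // -3 (truncation, not rounding)
        }
    }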
@@ -725,7 +725,6 @@ public class ArrowConverter {
             case Time: ret.add(timeVectorOf(bufferAllocator,schema.getName(i),numRows)); break;
             case NDArray: ret.add(ndarrayVectorOf(bufferAllocator,schema.getName(i),numRows)); break;
             default: throw new IllegalArgumentException("Illegal type found for creation of field vectors" + schema.getType(i));
-
         }
     }
 
@@ -802,8 +801,13 @@ public class ArrowConverter {
                 //for proper offsets
                 ByteBuffer byteBuffer = BinarySerde.toByteBuffer(arr.get());
                 nd4jArrayVector.setSafe(row,byteBuffer,0,byteBuffer.capacity());
+            case Boolean:
+                BitVector bitVector = (BitVector) fieldVector;
+                if(value instanceof Boolean)
+                    bitVector.set(row, (boolean) value ? 1 : 0);
+                else
+                    bitVector.set(row, ((BooleanWritable) value).get() ? 1 : 0);
                 break;
 
             }
         }catch(Exception e) {
             log.warn("Unable to set value at row " + row);
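The new Boolean case maps Java booleans onto Arrow's bit-packed BitVector, writing 1 for true and 0 for false. A minimal, self-contained sketch of that vector API (standard Apache Arrow usage, not taken from the commit):

    import org.apache.arrow.memory.RootAllocator;
    import org.apache.arrow.vector.BitVector;

    public class BitVectorDemo {
        public static void main(String[] args) {
            try (RootAllocator allocator = new RootAllocator()) {
                BitVector bits = new BitVector("flags", allocator);
                bits.allocateNew(2);
                bits.set(0, 1);            // true
                bits.set(1, 0);            // false
                bits.setValueCount(2);
                System.out.println(bits.getObject(0)); // true
                bits.close();
            }
        }
    }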
@@ -315,7 +315,7 @@ public class TestPythonTransformProcess {
     }
 
     @Test
-    public void testNumpyTransform() throws Exception {
+    public void testNumpyTransform() {
        PythonTransform pythonTransform = PythonTransform.builder()
                .code("a += 2; b = 'hello world'")
                .returnAllInputs(true)
@@ -334,7 +334,42 @@ public class TestPythonTransformProcess {
        assertFalse(execute.isEmpty());
        assertNotNull(execute.get(0));
        assertNotNull(execute.get(0).get(0));
-       assertEquals("hello world",execute.get(0).get(0).toString());
+       assertNotNull(execute.get(0).get(1));
+       assertEquals(Nd4j.scalar(3).reshape(1, 1),((NDArrayWritable)execute.get(0).get(0)).get());
+       assertEquals("hello world",execute.get(0).get(1).toString());
+    }
+
+    @Test
+    public void testWithSetupRun() throws Exception {
+
+        PythonTransform pythonTransform = PythonTransform.builder()
+                .code("five=None\n" +
+                        "def setup():\n" +
+                        " global five\n"+
+                        " five = 5\n\n" +
+                        "def run(a, b):\n" +
+                        " c = a + b + five\n"+
+                        " return {'c':c}\n\n")
+                .returnAllInputs(true)
+                .setupAndRun(true)
+                .build();
+
+        List<List<Writable>> inputs = new ArrayList<>();
+        inputs.add(Arrays.asList((Writable) new NDArrayWritable(Nd4j.scalar(1).reshape(1,1)),
+                new NDArrayWritable(Nd4j.scalar(2).reshape(1,1))));
+        Schema inputSchema = new Builder()
+                .addColumnNDArray("a",new long[]{1,1})
+                .addColumnNDArray("b", new long[]{1, 1})
+                .build();
+
+        TransformProcess tp = new TransformProcess.Builder(inputSchema)
+                .transform(pythonTransform)
+                .build();
+        List<List<Writable>> execute = LocalTransformExecutor.execute(inputs, tp);
+        assertFalse(execute.isEmpty());
+        assertNotNull(execute.get(0));
+        assertNotNull(execute.get(0).get(0));
+        assertEquals(Nd4j.scalar(8).reshape(1, 1),((NDArrayWritable)execute.get(0).get(3)).get());
     }
 
 }
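For readability, the script embedded in testWithSetupRun concatenates to the following Python, reconstructed verbatim from the string literals above (including their one-space indentation):

    five=None
    def setup():
     global five
     five = 5

    def run(a, b):
     c = a + b + five
     return {'c':c}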
@ -29,6 +29,7 @@ import org.nd4j.nativeblas.NativeOps;
|
||||||
import org.nd4j.nativeblas.NativeOpsHolder;
|
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
|
||||||
|
import static org.nd4j.linalg.api.buffer.DataType.FLOAT;
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@@ -46,6 +47,7 @@ public class NumpyArray {
     private long[] strides;
     private DataType dtype;
     private INDArray nd4jArray;
 
     static {
         //initialize
         Nd4j.scalar(1.0);
@@ -53,32 +55,6 @@ public class NumpyArray {
     }
 
     @Builder
-    public NumpyArray(long address, long[] shape, long strides[], boolean copy,DataType dtype) {
-        this.address = address;
-        this.shape = shape;
-        this.strides = strides;
-        this.dtype = dtype;
-        setND4JArray();
-        if (copy){
-            nd4jArray = nd4jArray.dup();
-            Nd4j.getAffinityManager().ensureLocation(nd4jArray, AffinityManager.Location.HOST);
-            this.address = nd4jArray.data().address();
-
-        }
-    }
-
-    public NumpyArray copy(){
-        return new NumpyArray(nd4jArray.dup());
-    }
-
-    public NumpyArray(long address, long[] shape, long strides[]){
-        this(address, shape, strides, false,DataType.FLOAT);
-    }
-
-    public NumpyArray(long address, long[] shape, long strides[], DataType dtype){
-        this(address, shape, strides, dtype, false);
-    }
-
     public NumpyArray(long address, long[] shape, long strides[], DataType dtype, boolean copy) {
         this.address = address;
         this.shape = shape;
@@ -92,6 +68,21 @@ public class NumpyArray {
         }
     }
 
+
+    public NumpyArray copy() {
+        return new NumpyArray(nd4jArray.dup());
+    }
+
+    public NumpyArray(long address, long[] shape, long strides[]) {
+        this(address, shape, strides, FLOAT, false);
+    }
+
+    public NumpyArray(long address, long[] shape, long strides[], DataType dtype) {
+        this(address, shape, strides, dtype, false);
+    }
+
+
     private void setND4JArray() {
         long size = 1;
         for (long d : shape) {
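Net effect of this refactor: copy() and the convenience constructors survive, but both overloads now delegate to the (address, shape, strides, dtype, copy) constructor, whereas the removed chain targeted the old (copy, dtype) parameter order. A sketch of the resulting defaults (the values below are placeholders for illustration; a real off-heap data pointer is required for actual use):

    long address = 0L; // placeholder: must point at existing off-heap data
    long[] shape = {1, 3};
    long[] strides = {3, 1};

    NumpyArray a = new NumpyArray(address, shape, strides);                  // dtype = FLOAT, copy = false
    NumpyArray b = new NumpyArray(address, shape, strides, DataType.DOUBLE); // copy = false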
@@ -0,0 +1,265 @@
+/*******************************************************************************
+ * Copyright (c) 2019 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.datavec.python;
+
+import org.bytedeco.cpython.PyObject;
+
+import static org.bytedeco.cpython.global.python.*;
+
+/**
+ * Swift-like Python wrapper for Java
+ *
+ * @author Fariz Rahman
+ */
+public class Python {
+
+    /**
+     * Imports a Python module, similar to the Python import statement.
+     *
+     * @param moduleName name of the module to be imported
+     * @return reference to the module object
+     * @throws PythonException if the module cannot be imported
+     */
+    public static PythonObject importModule(String moduleName) throws PythonException {
+        PythonObject module = new PythonObject(PyImport_ImportModule(moduleName));
+        if (module.isNone()) {
+            throw new PythonException("Error importing module: " + moduleName);
+        }
+        return module;
+    }
+
+    public static PythonObject attr(String attrName) {
+        return builtins().attr(attrName);
+    }
+
+    public static PythonObject len(PythonObject pythonObject) {
+        return attr("len").call(pythonObject);
+    }
+
+    public static PythonObject str(PythonObject pythonObject) {
+        return attr("str").call(pythonObject);
+    }
+
+    public static PythonObject str() {
+        return attr("str").call();
+    }
+
+    public static PythonObject strType() {
+        return attr("str");
+    }
+
+    public static PythonObject float_(PythonObject pythonObject) {
+        return attr("float").call(pythonObject);
+    }
+
+    public static PythonObject float_() {
+        return attr("float").call();
+    }
+
+    public static PythonObject floatType() {
+        return attr("float");
+    }
+
+    public static PythonObject bool(PythonObject pythonObject) {
+        return attr("bool").call(pythonObject);
+    }
+
+    public static PythonObject bool() {
+        return attr("bool").call();
+    }
+
+    public static PythonObject boolType() {
+        return attr("bool");
+    }
+
+    public static PythonObject int_(PythonObject pythonObject) {
+        return attr("int").call(pythonObject);
+    }
+
+    public static PythonObject int_() {
+        return attr("int").call();
+    }
+
+    public static PythonObject intType() {
+        return attr("int");
+    }
+
+    public static PythonObject list(PythonObject pythonObject) {
+        return attr("list").call(pythonObject);
+    }
+
+    public static PythonObject list() {
+        return attr("list").call();
+    }
+
+    public static PythonObject listType() {
+        return attr("list");
+    }
+
+    public static PythonObject dict(PythonObject pythonObject) {
+        return attr("dict").call(pythonObject);
+    }
+
+    public static PythonObject dict() {
+        return attr("dict").call();
+    }
+
+    public static PythonObject dictType() {
+        return attr("dict");
+    }
+
+    public static PythonObject set(PythonObject pythonObject) {
+        return attr("set").call(pythonObject);
+    }
+
+    public static PythonObject set() {
+        return attr("set").call();
+    }
+
+    public static PythonObject bytearray(PythonObject pythonObject) {
+        return attr("bytearray").call(pythonObject);
+    }
+
+    public static PythonObject bytearray() {
+        return attr("bytearray").call();
+    }
+
+    public static PythonObject bytearrayType() {
+        return attr("bytearray");
+    }
+
+    public static PythonObject bytes(PythonObject pythonObject) {
+        return attr("bytes").call(pythonObject);
+    }
+
+    public static PythonObject bytes() {
+        return attr("bytes").call();
+    }
+
+    public static PythonObject bytesType() {
+        return attr("bytes");
+    }
+
+    public static PythonObject tuple(PythonObject pythonObject) {
+        return attr("tuple").call(pythonObject);
+    }
+
+    public static PythonObject tuple() {
+        return attr("tuple").call();
+    }
+
+    public static PythonObject Exception(PythonObject pythonObject) {
+        return attr("Exception").call(pythonObject);
+    }
+
+    public static PythonObject Exception() {
+        return attr("Exception").call();
+    }
+
+    public static PythonObject ExceptionType() {
+        return attr("Exception");
+    }
+
+    public static PythonObject tupleType() {
+        return attr("tuple");
+    }
+
+    public static PythonObject globals() {
+        return new PythonObject(PyModule_GetDict(PyImport_ImportModule("__main__")));
+    }
+
+    public static PythonObject type(PythonObject obj) {
+        return attr("type").call(obj);
+    }
+
+    public static boolean isinstance(PythonObject obj, PythonObject... type) {
+        return PyObject_IsInstance(obj.getNativePythonObject(),
+                PyList_AsTuple(new PythonObject(type).getNativePythonObject())) != 0;
+    }
+
+    public static PythonObject eval(String code) {
+        PyObject compiledCode = Py_CompileString(code, "", Py_eval_input);
+        PyObject globals = globals().getNativePythonObject();
+        PyObject locals = Python.dict().getNativePythonObject();
+        return new PythonObject(PyEval_EvalCode(compiledCode, globals, locals));
+    }
+
+    public static PythonObject builtins() {
+        try {
+            return importModule("builtins");
+        } catch (PythonException pe) {
+            throw new IllegalStateException("Unable to import builtins: " + pe); // this should never happen
+        }
+    }
+
+    public static PythonObject None() {
+        return dict().attr("get").call(0);
+    }
+
+    public static PythonObject True() {
+        return boolType().call(1);
+    }
+
+    public static PythonObject False() {
+        return boolType().call(0);
+    }
+
+    public static boolean callable(PythonObject pythonObject) {
+        return PyCallable_Check(pythonObject.getNativePythonObject()) == 1;
+    }
+
+    public static void setContext(String context) throws PythonException {
+        PythonContextManager.setContext(context);
+    }
+
+    public static String getCurrentContext() {
+        return PythonContextManager.getCurrentContext();
+    }
+
+    public static void deleteContext(String context) throws PythonException {
+        PythonContextManager.deleteContext(context);
+    }
+
+    public static void deleteNonMainContexts() {
+        PythonContextManager.deleteNonMainContexts();
+    }
+
+    public static void setMainContext() {
+        PythonContextManager.setMainContext();
+    }
+
+    public static void exec(String code) throws PythonException {
+        PythonExecutioner.exec(code);
+    }
+
+    public static void exec(String code, PythonVariables inputs) throws PythonException {
+        PythonExecutioner.exec(code, inputs);
+    }
+
+    public static void exec(String code, PythonVariables inputs, PythonVariables outputs) throws PythonException {
+        PythonExecutioner.exec(code, inputs, outputs);
+    }
+
+    public static PythonGIL lock() {
+        return PythonGIL.lock();
+    }
+
+}
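A brief usage sketch of this wrapper (illustrative only; it assumes the interpreter has been initialized, which PythonExecutioner's static initializer handles elsewhere in this commit):

    public class PythonWrapperDemo {
        public static void main(String[] args) throws Exception {
            Python.exec("x = 40 + 2");                        // run a statement in the global namespace
            System.out.println(Python.eval("x").toInt());     // 42

            PythonObject math = Python.importModule("math");  // like: import math
            System.out.println(math.attr("pi").toDouble());   // 3.141592653589793

            PythonObject empty = Python.list();               // like: list()
            System.out.println(Python.len(empty).toInt());    // 0
        }
    }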
@@ -16,14 +16,14 @@
 
 package org.datavec.python;
 
-import org.datavec.api.transform.ColumnType;
 import org.datavec.api.transform.condition.Condition;
 import org.datavec.api.transform.schema.Schema;
 import org.datavec.api.writable.*;
 
 import java.util.List;
 
 import static org.datavec.python.PythonUtils.schemaToPythonVariables;
+import static org.nd4j.base.Preconditions.checkNotNull;
+import static org.nd4j.base.Preconditions.checkState;
 
 /**
  * Lets a condition be defined as a python method f that takes no arguments
@@ -41,14 +41,12 @@ public class PythonCondition implements Condition {
 
     public PythonCondition(String pythonCode) {
-        org.nd4j.base.Preconditions.checkNotNull("Python code must not be null!",pythonCode);
-        org.nd4j.base.Preconditions.checkState(pythonCode.length() >= 1,"Python code must not be empty!");
+        checkNotNull("Python code must not be null!", pythonCode);
+        checkState(!pythonCode.isEmpty(), "Python code must not be empty!");
         code = pythonCode;
     }
 
     @Override
     public void setInputSchema(Schema inputSchema) {
         this.inputSchema = inputSchema;
@@ -62,13 +60,11 @@ public class PythonCondition implements Condition {
                 .outputs(pyOuts)
                 .build();
 
-        }
-        catch (Exception e){
+        } catch (Exception e) {
             throw new RuntimeException(e);
         }
 
     }
 
     @Override
@@ -107,11 +103,10 @@ public class PythonCondition implements Condition {
     public boolean condition(List<Writable> list) {
         PythonVariables inputs = getPyInputsFromWritables(list);
         try {
-            PythonExecutioner.exec(pythonTransform.getCode(), inputs, pythonTransform.getOutputs());
+            pythonTransform.getPythonJob().exec(inputs, pythonTransform.getOutputs());
             boolean ret = pythonTransform.getOutputs().getIntValue("out") != 0;
             return ret;
-        }
-        catch (Exception e) {
+        } catch (Exception e) {
             throw new RuntimeException(e);
         }
     }
@@ -138,13 +133,12 @@ public class PythonCondition implements Condition {
         for (int i = 0; i < inputSchema.numColumns(); i++) {
             String name = inputSchema.getName(i);
             Writable w = writables.get(i);
-            PythonVariables.Type pyType = pyInputs.getType(inputSchema.getName(i));
-            switch (pyType){
+            PythonType pyType = pyInputs.getType(inputSchema.getName(i));
+            switch (pyType.getName()) {
                 case INT:
                     if (w instanceof LongWritable) {
                         ret.addInt(name, ((LongWritable) w).get());
-                    }
-                    else {
+                    } else {
                         ret.addInt(name, ((IntWritable) w).get());
                     }
 
@@ -0,0 +1,188 @@
+/*******************************************************************************
+ * Copyright (c) 2019 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.datavec.python;
+
+import java.util.HashSet;
+import java.util.Set;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+/**
+ * Emulates multiple interpreters in a single interpreter.
+ * This works by simply obfuscating/de-obfuscating variable names
+ * such that only the required subset of the global namespace is "visible"
+ * at any given time.
+ * By default, there exists a "main" context emulating the default interpreter,
+ * which cannot be deleted.
+ *
+ * @author Fariz Rahman
+ */
+public class PythonContextManager {
+
+    private static Set<String> contexts = new HashSet<>();
+    private static AtomicBoolean init = new AtomicBoolean(false);
+    private static String currentContext;
+    private static final String MAIN_CONTEXT = "main";
+
+    static {
+        init();
+    }
+
+    private static void init() {
+        if (init.get()) return;
+        new PythonExecutioner();
+        init.set(true);
+        currentContext = MAIN_CONTEXT;
+        contexts.add(currentContext);
+    }
+
+    public static void addContext(String contextName) throws PythonException {
+        if (!validateContextName(contextName)) {
+            throw new PythonException("Invalid context name: " + contextName);
+        }
+        contexts.add(contextName);
+    }
+
+    public static boolean hasContext(String contextName) {
+        return contexts.contains(contextName);
+    }
+
+    public static boolean validateContextName(String s) {
+        if (s.length() == 0) return false;
+        if (!Character.isJavaIdentifierStart(s.charAt(0))) return false;
+        for (int i = 1; i < s.length(); i++)
+            if (!Character.isJavaIdentifierPart(s.charAt(i)))
+                return false;
+        return true;
+    }
+
+    private static String getContextPrefix(String contextName) {
+        return "__collapsed__" + contextName + "__";
+    }
+
+    private static String getCollapsedVarNameForContext(String varName, String contextName) {
+        return getContextPrefix(contextName) + varName;
+    }
+
+    private static String expandCollapsedVarName(String varName, String contextName) {
+        String prefix = "__collapsed__" + contextName + "__";
+        return varName.substring(prefix.length());
+    }
+
+    private static void collapseContext(String contextName) {
+        PythonObject globals = Python.globals();
+        PythonObject keysList = Python.list(globals.attr("keys").call());
+        int numKeys = Python.len(keysList).toInt();
+        for (int i = 0; i < numKeys; i++) {
+            PythonObject key = keysList.get(i);
+            String keyStr = key.toString();
+            if (!((keyStr.startsWith("__") && keyStr.endsWith("__")) || keyStr.startsWith("__collapsed_"))) {
+                String collapsedKey = getCollapsedVarNameForContext(keyStr, contextName);
+                PythonObject val = globals.attr("pop").call(key);
+                globals.set(new PythonObject(collapsedKey), val);
+            }
+        }
+    }
+
+    private static void expandContext(String contextName) {
+        String prefix = getContextPrefix(contextName);
+        PythonObject globals = Python.globals();
+        PythonObject keysList = Python.list(globals.attr("keys").call());
+        int numKeys = Python.len(keysList).toInt();
+        for (int i = 0; i < numKeys; i++) {
+            PythonObject key = keysList.get(i);
+            String keyStr = key.toString();
+            if (keyStr.startsWith(prefix)) {
+                String expandedKey = expandCollapsedVarName(keyStr, contextName);
+                PythonObject val = globals.attr("pop").call(key);
+                globals.set(new PythonObject(expandedKey), val);
+            }
+        }
+    }
+
+    public static void setContext(String contextName) throws PythonException {
+        if (contextName.equals(currentContext)) {
+            return;
+        }
+        if (!hasContext(contextName)) {
+            addContext(contextName);
+        }
+        collapseContext(currentContext);
+        expandContext(contextName);
+        currentContext = contextName;
+    }
+
+    public static void setMainContext() {
+        try {
+            setContext(MAIN_CONTEXT);
+        } catch (PythonException pe) {
+            throw new RuntimeException(pe);
+        }
+    }
+
+    public static String getCurrentContext() {
+        return currentContext;
+    }
+
+    public static void deleteContext(String contextName) throws PythonException {
+        if (contextName.equals(MAIN_CONTEXT)) {
+            throw new PythonException("Can not delete main context!");
+        }
+        if (contextName.equals(currentContext)) {
+            throw new PythonException("Can not delete current context!");
+        }
+        String prefix = getContextPrefix(contextName);
+        PythonObject globals = Python.globals();
+        PythonObject keysList = Python.list(globals.attr("keys").call());
+        int numKeys = Python.len(keysList).toInt();
+        for (int i = 0; i < numKeys; i++) {
+            PythonObject key = keysList.get(i);
+            String keyStr = key.toString();
+            if (keyStr.startsWith(prefix)) {
+                globals.attr("__delitem__").call(key);
+            }
+        }
+        contexts.remove(contextName);
+    }
+
+    public static void deleteNonMainContexts() {
+        try {
+            setContext(MAIN_CONTEXT); // will never fail
+            for (String c : contexts.toArray(new String[0])) {
+                if (!c.equals(MAIN_CONTEXT)) {
+                    deleteContext(c); // will never fail
+                }
+            }
+        } catch (Exception e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    public String[] getContexts() {
+        return contexts.toArray(new String[0]);
+    }
+
+}
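A brief sketch of how the context switching behaves, using only methods defined in this commit (the context name "jobA" is a hypothetical example):

    public class ContextDemo {
        public static void main(String[] args) throws Exception {
            Python.exec("x = 1");                         // runs in the default "main" context
            Python.setContext("jobA");                    // main's globals are collapsed (hidden)
            Python.exec("x = 99");                        // this x exists only inside jobA
            Python.setContext("main");                    // jobA hidden again, main restored
            System.out.println(Python.eval("x").toInt()); // 1, not 99
            Python.deleteNonMainContexts();               // discard jobA and its variables
        }
    }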
@@ -0,0 +1,44 @@
+/*******************************************************************************
+ * Copyright (c) 2019 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.datavec.python;
+
+/**
+ * Thrown when an exception occurs in Python land.
+ */
+public class PythonException extends Exception {
+    public PythonException(String message) {
+        super(message);
+    }
+
+    private static String getExceptionString(PythonObject exception) {
+        if (Python.isinstance(exception, Python.ExceptionType())) {
+            String exceptionClass = Python.type(exception).attr("__name__").toString();
+            String message = exception.toString();
+            return exceptionClass + ": " + message;
+        }
+        return exception.toString();
+    }
+
+    public PythonException(PythonObject exception) {
+        this(getExceptionString(exception));
+    }
+
+    public PythonException(String message, Throwable cause) {
+        super(message, cause);
+    }
+
+    public PythonException(Throwable cause) {
+        super(cause);
+    }
+}
[File diff suppressed because it is too large]
@@ -0,0 +1,68 @@
+/*******************************************************************************
+ * Copyright (c) 2019 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.datavec.python;
+
+import lombok.extern.slf4j.Slf4j;
+import org.bytedeco.cpython.PyThreadState;
+
+import static org.bytedeco.cpython.global.python.*;
+import static org.bytedeco.cpython.global.python.PyEval_RestoreThread;
+import static org.bytedeco.cpython.global.python.PyEval_SaveThread;
+
+
+@Slf4j
+public class PythonGIL implements AutoCloseable {
+    private static PyThreadState mainThreadState;
+
+    static {
+        log.debug("CPython: PyThreadState_Get()");
+        mainThreadState = PyThreadState_Get();
+    }
+
+    private PythonGIL() {
+        acquire();
+    }
+
+    @Override
+    public void close() {
+        release();
+    }
+
+    public static PythonGIL lock() {
+        return new PythonGIL();
+    }
+
+    private static synchronized void acquire() {
+        log.debug("acquireGIL()");
+        log.debug("CPython: PyEval_SaveThread()");
+        mainThreadState = PyEval_SaveThread();
+        log.debug("CPython: PyThreadState_New()");
+        PyThreadState ts = PyThreadState_New(mainThreadState.interp());
+        log.debug("CPython: PyEval_RestoreThread()");
+        PyEval_RestoreThread(ts);
+        log.debug("CPython: PyThreadState_Swap()");
+        PyThreadState_Swap(ts);
+    }
+
+    private static synchronized void release() {
+        log.debug("CPython: PyEval_SaveThread()");
+        PyEval_SaveThread();
+        log.debug("CPython: PyEval_RestoreThread()");
+        PyEval_RestoreThread(mainThreadState);
+    }
+}
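Because PythonGIL implements AutoCloseable, the intended usage is try-with-resources, exactly as PythonJob below does:

    public class GilDemo {
        public static void main(String[] args) throws Exception {
            try (PythonGIL gil = PythonGIL.lock()) {
                // CPython calls on this thread are safe while the GIL is held.
                Python.exec("y = sum(range(10))");
                System.out.println(Python.eval("y").toInt()); // 45
            } // close() releases the GIL automatically
        }
    }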
@@ -0,0 +1,171 @@
+/*******************************************************************************
+ * Copyright (c) 2019 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.datavec.python;
+
+import lombok.Builder;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+
+import javax.annotation.Nonnull;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * PythonJob is the right abstraction for executing multiple python scripts
+ * in a multi thread stateful environment. The setup-and-run mode allows your
+ * "setup" code (imports, model loading etc) to be executed only once.
+ */
+@Data
+@NoArgsConstructor
+public class PythonJob {
+
+    private String code;
+    private String name;
+    private String context;
+    private boolean setupRunMode;
+    private PythonObject runF;
+
+    static {
+        new PythonExecutioner();
+    }
+
+    /**
+     * @param name         Name for the python job.
+     * @param code         Python code.
+     * @param setupRunMode If true, the python code is expected to have two methods: setup(), which takes no arguments,
+     *                     and run(), which takes some or no arguments. The setup() method is executed once,
+     *                     and the run() method is called with the inputs (if any) per transaction, and is expected to
+     *                     return a dictionary mapping from output variable names (str) to output values.
+     *                     If false, the full script is run on each transaction and the output variables are obtained
+     *                     from the global namespace after execution.
+     */
+    @Builder
+    public PythonJob(@Nonnull String name, @Nonnull String code, boolean setupRunMode) throws Exception {
+        this.name = name;
+        this.code = code;
+        this.setupRunMode = setupRunMode;
+        context = "__job_" + name;
+        if (PythonContextManager.hasContext(context)) {
+            throw new PythonException("Unable to create python job " + name + ". Context " + context + " already exists!");
+        }
+        if (setupRunMode) setup();
+    }
+
+    /**
+     * Clears all variables in current context and calls setup()
+     */
+    public void clearState() throws Exception {
+        String context = this.context;
+        PythonContextManager.setContext("main");
+        PythonContextManager.deleteContext(context);
+        this.context = context;
+        setup();
+    }
+
+    public void setup() throws Exception {
+        try (PythonGIL gil = PythonGIL.lock()) {
+            PythonContextManager.setContext(context);
+            PythonObject runF = PythonExecutioner.getVariable("run");
+            if (runF.isNone() || !Python.callable(runF)) {
+                PythonExecutioner.exec(code);
+                runF = PythonExecutioner.getVariable("run");
+            }
+            if (runF.isNone() || !Python.callable(runF)) {
+                throw new PythonException("run() method not found! " +
+                        "If a PythonJob is created with 'setup and run' " +
+                        "mode enabled, the associated python code is " +
+                        "expected to contain a run() method " +
+                        "(with or without arguments).");
+            }
+            this.runF = runF;
+            PythonObject setupF = PythonExecutioner.getVariable("setup");
+            if (!setupF.isNone()) {
+                setupF.call();
+            }
+        }
+    }
+
+    public void exec(PythonVariables inputs, PythonVariables outputs) throws Exception {
+        try (PythonGIL gil = PythonGIL.lock()) {
+            PythonContextManager.setContext(context);
+            if (!setupRunMode) {
+                PythonExecutioner.exec(code, inputs, outputs);
+                return;
+            }
+            PythonExecutioner.setVariables(inputs);
+
+            PythonObject inspect = Python.importModule("inspect");
+            PythonObject getfullargspec = inspect.attr("getfullargspec");
+            PythonObject argspec = getfullargspec.call(runF);
+            PythonObject argsList = argspec.attr("args");
+            PythonObject runargs = Python.dict();
+            int argsCount = Python.len(argsList).toInt();
+            for (int i = 0; i < argsCount; i++) {
+                PythonObject arg = argsList.get(i);
+                PythonObject val = Python.globals().get(arg);
+                if (val.isNone()) {
+                    throw new PythonException("Input value not received for run() argument: " + arg.toString());
+                }
+                runargs.set(arg, val);
+            }
+            PythonObject outDict = runF.callWithKwargs(runargs);
+            Python.globals().attr("update").call(outDict);
+
+            PythonExecutioner.getVariables(outputs);
+            inspect.del();
+            getfullargspec.del();
+            argspec.del();
+            runargs.del();
+        }
+    }
+
+    public PythonVariables execAndReturnAllVariables(PythonVariables inputs) throws Exception {
+        try (PythonGIL gil = PythonGIL.lock()) {
+            PythonContextManager.setContext(context);
+            if (!setupRunMode) {
+                return PythonExecutioner.execAndReturnAllVariables(code, inputs);
+            }
+            PythonExecutioner.setVariables(inputs);
+            PythonObject inspect = Python.importModule("inspect");
+            PythonObject getfullargspec = inspect.attr("getfullargspec");
+            PythonObject argspec = getfullargspec.call(runF);
+            PythonObject argsList = argspec.attr("args");
+            PythonObject runargs = Python.dict();
+            int argsCount = Python.len(argsList).toInt();
+            for (int i = 0; i < argsCount; i++) {
+                PythonObject arg = argsList.get(i);
+                PythonObject val = Python.globals().get(arg);
+                if (val.isNone()) {
+                    throw new PythonException("Input value not received for run() argument: " + arg.toString());
+                }
+                runargs.set(arg, val);
+            }
+            PythonObject outDict = runF.callWithKwargs(runargs);
+            Python.globals().attr("update").call(outDict);
+            inspect.del();
+            getfullargspec.del();
+            argspec.del();
+            runargs.del();
+            return PythonExecutioner.getAllVariables();
+        }
+    }
+
+}
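A usage sketch of the setup-and-run mode. This is illustrative only: the Lombok @Builder on the constructor provides PythonJob.builder(), and the PythonVariables calls mirror those used by PythonCondition above (addInt with a value, getIntValue); the single-argument addInt declaration form is an assumption, as that class's diff was suppressed:

    public class PythonJobDemo {
        public static void main(String[] args) throws Exception {
            PythonJob job = PythonJob.builder()
                    .name("adder")
                    .code("def run(a, b):\n" +
                          "    return {'out': a + b}\n")
                    .setupRunMode(true)
                    .build();

            PythonVariables in = new PythonVariables();
            in.addInt("a", 3);
            in.addInt("b", 4);
            PythonVariables out = new PythonVariables();
            out.addInt("out");                            // assumed: declares the expected output
            job.exec(in, out);
            System.out.println(out.getIntValue("out"));   // 7
        }
    }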
@@ -0,0 +1,554 @@
+/*******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+package org.datavec.python;
+
+import org.bytedeco.cpython.PyObject;
+import org.bytedeco.javacpp.BytePointer;
+import org.bytedeco.javacpp.Pointer;
+import org.json.JSONArray;
+import org.json.JSONObject;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.nativeblas.NativeOpsHolder;
+
+import java.util.*;
+
+import static org.bytedeco.cpython.global.python.*;
+import static org.bytedeco.cpython.global.python.PyObject_SetItem;
+
+/**
+ * Swift-like Python wrapper for Java
+ *
+ * @author Fariz Rahman
+ */
+public class PythonObject {
+    private PyObject nativePythonObject;
+
+    static {
+        new PythonExecutioner();
+    }
+
+    private static Map<String, PythonObject> _getNDArraySerializer() {
+        Map<String, PythonObject> ndarraySerializer = new HashMap<>();
+        PythonObject lambda = Python.eval(
+                "lambda x: " +
+                        "{'address':" +
+                        "x.__array_interface__['data'][0]," +
+                        "'shape':x.shape,'strides':x.strides," +
+                        "'dtype': str(x.dtype),'_is_numpy_array': True}" +
+                        " if str(type(x))== \"<class 'numpy.ndarray'>\" else x");
+        ndarraySerializer.put("default", lambda);
+        return ndarraySerializer;
+    }
+
+    public PythonObject(PyObject pyObject) {
+        nativePythonObject = pyObject;
+    }
+
+    public PythonObject(INDArray npArray) {
+        this(new NumpyArray(npArray));
+    }
+
+    public PythonObject(BytePointer bp) {
+        nativePythonObject = PyByteArray_FromStringAndSize(bp, bp.capacity());
+    }
+
+    public PythonObject(NumpyArray npArray) {
+        PyObject ctypes = PyImport_ImportModule("ctypes");
+        PyObject np = PyImport_ImportModule("numpy");
+        PyObject ctype;
+        switch (npArray.getDtype()) {
+            case DOUBLE:
+                ctype = PyObject_GetAttrString(ctypes, "c_double");
+                break;
+            case FLOAT:
+                ctype = PyObject_GetAttrString(ctypes, "c_float");
+                break;
+            case LONG:
+                ctype = PyObject_GetAttrString(ctypes, "c_int64");
+                break;
+            case INT:
+                ctype = PyObject_GetAttrString(ctypes, "c_int32");
+                break;
+            case SHORT:
+                ctype = PyObject_GetAttrString(ctypes, "c_int16");
+                break;
+            case UINT16:
+                ctype = PyObject_GetAttrString(ctypes, "c_uint16");
+                break;
+            case UINT32:
+                ctype = PyObject_GetAttrString(ctypes, "c_uint32");
+                break;
+            case UINT64:
+                ctype = PyObject_GetAttrString(ctypes, "c_uint64");
+                break;
+            case BOOL:
+                ctype = PyObject_GetAttrString(ctypes, "c_bool");
+                break;
+            case BYTE:
+                ctype = PyObject_GetAttrString(ctypes, "c_byte");
+                break;
+            case UBYTE:
+                ctype = PyObject_GetAttrString(ctypes, "c_ubyte");
+                break;
+            default:
+                throw new RuntimeException("Unsupported dtype: " + npArray.getDtype());
+        }
+
+        PyObject ctypesPointer = PyObject_GetAttrString(ctypes, "POINTER");
+        PyObject argsTuple = PyTuple_New(1);
+        PyTuple_SetItem(argsTuple, 0, ctype);
+        PyObject ptrType = PyObject_Call(ctypesPointer, argsTuple, null);
+
+        PyObject cast = PyObject_GetAttrString(ctypes, "cast");
+        PyObject address = PyLong_FromLong(npArray.getAddress());
+        PyObject argsTuple2 = PyTuple_New(2);
+        PyTuple_SetItem(argsTuple2, 0, address);
+        PyTuple_SetItem(argsTuple2, 1, ptrType);
+        PyObject ptr = PyObject_Call(cast, argsTuple2, null);
+        PyObject shapeTuple = PyTuple_New(npArray.getShape().length);
+        for (int i = 0; i < npArray.getShape().length; i++) {
+            PyObject dim = PyLong_FromLong(npArray.getShape()[i]);
+            PyTuple_SetItem(shapeTuple, i, dim);
+            Py_DecRef(dim);
+        }
+        PyObject ctypesLib = PyObject_GetAttrString(np, "ctypeslib");
+        PyObject asArray = PyObject_GetAttrString(ctypesLib, "as_array");
+        PyObject argsTuple3 = PyTuple_New(2);
+        PyTuple_SetItem(argsTuple3, 0, ptr);
+        PyTuple_SetItem(argsTuple3, 1, shapeTuple);
+        nativePythonObject = PyObject_Call(asArray, argsTuple3, null);
+
+        Py_DecRef(ctypesPointer);
+        Py_DecRef(ctypesLib);
+        Py_DecRef(argsTuple);
+        Py_DecRef(argsTuple2);
+        Py_DecRef(argsTuple3);
+        Py_DecRef(cast);
+        Py_DecRef(asArray);
+    }
+
+    /*---primitive constructors---*/
+    public PyObject getNativePythonObject() {
+        return nativePythonObject;
+    }
+
+    public PythonObject(String data) {
+        nativePythonObject = PyUnicode_FromString(data);
+    }
+
+    public PythonObject(int data) {
+        nativePythonObject = PyLong_FromLong((long) data);
+    }
+
+    public PythonObject(long data) {
+        nativePythonObject = PyLong_FromLong(data);
+    }
+
+    public PythonObject(double data) {
+        nativePythonObject = PyFloat_FromDouble(data);
+    }
+
+    public PythonObject(boolean data) {
+        nativePythonObject = PyBool_FromLong(data ? 1 : 0);
+    }
+
+    private static PythonObject j2pyObject(Object item) {
+        if (item instanceof PythonObject) {
+            return (PythonObject) item;
+        } else if (item instanceof PyObject) {
+            return new PythonObject((PyObject) item);
+        } else if (item instanceof INDArray) {
+            return new PythonObject((INDArray) item);
+        } else if (item instanceof NumpyArray) {
+            return new PythonObject((NumpyArray) item);
+        } else if (item instanceof List) {
+            return new PythonObject((List) item);
+        } else if (item instanceof Object[]) {
+            return new PythonObject((Object[]) item);
+        } else if (item instanceof Map) {
+            return new PythonObject((Map) item);
+        } else if (item instanceof String) {
+            return new PythonObject((String) item);
+        } else if (item instanceof Double) {
+            return new PythonObject((Double) item);
+        } else if (item instanceof Float) {
+            return new PythonObject((Float) item);
+        } else if (item instanceof Long) {
+            return new PythonObject((Long) item);
+        } else if (item instanceof Integer) {
+            return new PythonObject((Integer) item);
+        } else if (item instanceof Boolean) {
+            return new PythonObject((Boolean) item);
+        } else if (item instanceof Pointer) {
+            return new PythonObject(new BytePointer((Pointer) item));
+        } else {
+            throw new RuntimeException("Unsupported item in list: " + item);
+        }
+    }
+
+    public PythonObject(Object[] data) {
+        PyObject pyList = PyList_New((long) data.length);
+        for (int i = 0; i < data.length; i++) {
+            PyList_SetItem(pyList, i, j2pyObject(data[i]).nativePythonObject);
+        }
+        nativePythonObject = pyList;
+    }
+
+    public PythonObject(List data) {
+        PyObject pyList = PyList_New((long) data.size());
+        for (int i = 0; i < data.size(); i++) {
+            PyList_SetItem(pyList, i, j2pyObject(data.get(i)).nativePythonObject);
+        }
+        nativePythonObject = pyList;
+    }
+
+    public PythonObject(Map data) {
+        PyObject pyDict = PyDict_New();
+        for (Object k : data.keySet()) {
+            PythonObject pyKey;
+            if (k instanceof PythonObject) {
+                pyKey = (PythonObject) k;
+            } else if (k instanceof String) {
+                pyKey = new PythonObject((String) k);
+            } else if (k instanceof Double) {
+                pyKey = new PythonObject((Double) k);
+            } else if (k instanceof Float) {
+                pyKey = new PythonObject((Float) k);
+            } else if (k instanceof Long) {
+                pyKey = new PythonObject((Long) k);
+            } else if (k instanceof Integer) {
+                pyKey = new PythonObject((Integer) k);
+            } else if (k instanceof Boolean) {
+                pyKey = new PythonObject((Boolean) k);
+            } else {
+                throw new RuntimeException("Unsupported key in map: " + k.getClass());
+            }
+            Object v = data.get(k);
+            PythonObject pyVal;
+            if (v instanceof PythonObject) {
+                pyVal = (PythonObject) v;
+            } else if (v instanceof PyObject) {
+                pyVal = new PythonObject((PyObject) v);
+            } else if (v instanceof INDArray) {
+                pyVal = new PythonObject((INDArray) v);
+            } else if (v instanceof NumpyArray) {
+                pyVal = new PythonObject((NumpyArray) v);
+            } else if (v instanceof Map) {
+                pyVal = new PythonObject((Map) v);
+            } else if (v instanceof List) {
+                pyVal = new PythonObject((List) v);
+            } else if (v instanceof String) {
+                pyVal = new PythonObject((String) v);
+            } else if (v instanceof Double) {
+                pyVal = new PythonObject((Double) v);
+            } else if (v instanceof Float) {
+                pyVal = new PythonObject((Float) v);
+            } else if (v instanceof Long) {
+                pyVal = new PythonObject((Long) v);
+            } else if (v instanceof Integer) {
+                pyVal = new PythonObject((Integer) v);
+            } else if (v instanceof Boolean) {
+                pyVal = new PythonObject((Boolean) v);
+            } else {
+                throw new RuntimeException("Unsupported value in map: " + k.getClass());
+            }
+            PyDict_SetItem(pyDict, pyKey.nativePythonObject, pyVal.nativePythonObject);
+        }
+        nativePythonObject = pyDict;
+    }
+
+    /*------*/
+
+    private static String pyObjectToString(PyObject pyObject) {
+        PyObject repr = PyObject_Str(pyObject);
+        PyObject str = PyUnicode_AsEncodedString(repr, "utf-8", "~E~");
+        String jstr = PyBytes_AsString(str).getString();
+        Py_DecRef(repr);
+        Py_DecRef(str);
+        return jstr;
+    }
+
+    public String toString() {
+        return pyObjectToString(nativePythonObject);
+    }
+
+    public double toDouble() {
+        return PyFloat_AsDouble(nativePythonObject);
+    }
+
+    public float toFloat() {
+        return (float) PyFloat_AsDouble(nativePythonObject);
+    }
+
+    public int toInt() {
+        return (int) PyLong_AsLong(nativePythonObject);
+    }
+
+    public long toLong() {
+        return PyLong_AsLong(nativePythonObject);
+    }
+
+    public boolean toBoolean() {
+        if (isNone()) return false;
+        return toInt() != 0;
+    }
+
+    public NumpyArray toNumpy() {
+        PyObject arrInterface = PyObject_GetAttrString(nativePythonObject, "__array_interface__"); // borrowed reference; DO NOT Py_DecRef() !
+        PyObject data = PyDict_GetItemString(arrInterface, "data");
+        PyObject pyAddress = PyTuple_GetItem(data, 0);
+        long address = PyLong_AsLong(pyAddress);
+        PyObject pyDtype = PyObject_GetAttrString(nativePythonObject, "dtype");
+        PyObject pyDtypeName = PyObject_GetAttrString(pyDtype, "name");
+        String dtypeName = pyObjectToString(pyDtypeName);
+        Py_DecRef(pyDtype);
+        Py_DecRef(pyDtypeName);
+        PyObject shape = PyObject_GetAttrString(nativePythonObject, "shape");
+        PyObject strides = PyObject_GetAttrString(nativePythonObject, "strides");
+        int ndim = (int) PyObject_Size(shape);
+        long[] jshape = new long[ndim];
+        long[] jstrides = new long[ndim];
+        for (int i = 0; i < ndim; i++) {
+            jshape[i] = PyLong_AsLong(PyTuple_GetItem(shape, i));
+            jstrides[i] = PyLong_AsLong(PyTuple_GetItem(strides, i));
+        }
+        Py_DecRef(shape);
+        Py_DecRef(strides);
+        DataType dtype;
+        if (dtypeName.equals("float64")) {
+            dtype = DataType.DOUBLE;
+        } else if (dtypeName.equals("float32")) {
+            dtype = DataType.FLOAT;
+        } else if (dtypeName.equals("int16")) {
+            dtype = DataType.SHORT;
+        } else if (dtypeName.equals("int32")) {
+            dtype = DataType.INT;
+        } else if (dtypeName.equals("int64")) {
+            dtype = DataType.LONG;
+        } else {
+            throw new RuntimeException("Unsupported array type " + dtypeName + ".");
+        }
+        return new NumpyArray(address, jshape, jstrides, dtype);
+    }
+
+    public PythonObject attr(String attr) {
+        return new PythonObject(PyObject_GetAttrString(nativePythonObject, attr));
+    }
+
+    public PythonObject call(Object... args) {
+        if (args.length > 0 && args[args.length - 1] instanceof Map) {
+            List<Object> args2 = new ArrayList<>();
+            for (int i = 0; i < args.length - 1; i++) {
+                args2.add(args[i]);
+            }
+            return call(args2, (Map) args[args.length - 1]);
+        }
+        if (args.length == 0) {
+            return new PythonObject(PyObject_CallObject(nativePythonObject, null));
+        }
+        PyObject tuple = PyTuple_New(args.length); // leaky; tuple may contain borrowed references, so can not be de-allocated.
+        for (int i = 0; i < args.length; i++) {
+            PyTuple_SetItem(tuple, i, j2pyObject(args[i]).nativePythonObject);
+        }
+        PythonObject ret = new PythonObject(PyObject_Call(nativePythonObject, tuple, null));
+        return ret;
+    }
+
+    public PythonObject callWithArgs(PythonObject args) {
+        PyObject tuple = PyList_AsTuple(args.nativePythonObject);
+        return new PythonObject(PyObject_Call(nativePythonObject, tuple, null));
+    }
+
+    public PythonObject callWithKwargs(PythonObject kwargs) {
+        PyObject tuple = PyTuple_New(0);
+        return new PythonObject(PyObject_Call(nativePythonObject, tuple, kwargs.nativePythonObject));
+    }
+
+    public PythonObject callWithArgsAndKwargs(PythonObject args, PythonObject kwargs) {
+        PyObject tuple = PyList_AsTuple(args.nativePythonObject);
+        PyObject dict = kwargs.nativePythonObject;
+        return new PythonObject(PyObject_Call(nativePythonObject, tuple, dict));
+    }
+
+    public PythonObject call(Map kwargs) {
+        PyObject dict = new PythonObject(kwargs).nativePythonObject;
+        PyObject tuple = PyTuple_New(0);
+        return new PythonObject(PyObject_Call(nativePythonObject, tuple, dict));
+    }
+
+    public PythonObject call(List args) {
+        PyObject tuple = PyList_AsTuple(new PythonObject(args).nativePythonObject);
+        return new PythonObject(PyObject_Call(nativePythonObject, tuple, null));
+    }
+
+    public PythonObject call(List args, Map kwargs) {
+        PyObject tuple = PyList_AsTuple(new PythonObject(args).nativePythonObject);
+        PyObject dict = new PythonObject(kwargs).nativePythonObject;
+        return new PythonObject(PyObject_Call(nativePythonObject, tuple, dict));
+    }
+
+    private PythonObject get(PyObject key) {
+        return new PythonObject(
+                PyObject_GetItem(nativePythonObject, key)
+        );
+    }
+
+    public PythonObject get(PythonObject key) {
+        return get(key.nativePythonObject);
+    }
+
+    public PythonObject get(int key) {
+        return get(PyLong_FromLong((long) key));
+    }
+
+    public PythonObject get(long key) {
+        return new PythonObject(
+                PyObject_GetItem(nativePythonObject, PyLong_FromLong(key))
+        );
+    }
+
+    public PythonObject get(double key) {
+        return new PythonObject(
+                PyObject_GetItem(nativePythonObject, PyFloat_FromDouble(key))
+        );
+    }
+
+    public PythonObject get(String key) {
+        return get(new PythonObject(key));
+    }
+
+    public void set(PythonObject key, PythonObject value) {
+        PyObject_SetItem(nativePythonObject, key.nativePythonObject, value.nativePythonObject);
+    }
+
+    public void del() {
+        Py_DecRef(nativePythonObject);
+        nativePythonObject = null;
+    }
+
+    public JSONArray toJSONArray() throws PythonException {
+        PythonObject json = Python.importModule("json");
+        PythonObject serialized = json.attr("dumps").call(this, _getNDArraySerializer());
+        String jsonString = serialized.toString();
+        return new JSONArray(jsonString);
+    }
+
+    public JSONObject toJSONObject() throws PythonException {
+        PythonObject json = Python.importModule("json");
+        PythonObject serialized = json.attr("dumps").call(this, _getNDArraySerializer());
+        String jsonString = serialized.toString();
+        return new JSONObject(jsonString);
+    }
+
+    public List toList() throws PythonException {
+        List list = new ArrayList();
+        int n = Python.len(this).toInt();
+        for (int i = 0; i < n; i++) {
|
||||||
|
PythonObject o = get(i);
|
||||||
|
if (Python.isinstance(o, Python.strType())) {
|
||||||
|
list.add(o.toString());
|
||||||
|
} else if (Python.isinstance(o, Python.intType())) {
|
||||||
|
list.add(o.toLong());
|
||||||
|
} else if (Python.isinstance(o, Python.floatType())) {
|
||||||
|
list.add(o.toDouble());
|
||||||
|
} else if (Python.isinstance(o, Python.boolType())) {
|
||||||
|
list.add(o);
|
||||||
|
} else if (Python.isinstance(o, Python.listType(), Python.tupleType())) {
|
||||||
|
list.add(o.toList());
|
||||||
|
} else if (Python.isinstance(o, Python.importModule("numpy").attr("ndarray"))) {
|
||||||
|
list.add(o.toNumpy().getNd4jArray());
|
||||||
|
} else if (Python.isinstance(o, Python.dictType())) {
|
||||||
|
list.add(o.toMap());
|
||||||
|
} else {
|
||||||
|
throw new RuntimeException("Error while converting python" +
|
||||||
|
" list to java List: Unable to serialize python " +
|
||||||
|
"object of type " + Python.type(this).toString());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return list;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Map toMap() throws PythonException{
|
||||||
|
Map map = new HashMap();
|
||||||
|
List keys = Python.list(attr("keys").call()).toList();
|
||||||
|
List values = Python.list(attr("values").call()).toList();
|
||||||
|
for (int i = 0; i < keys.size(); i++) {
|
||||||
|
map.put(keys.get(i), values.get(i));
|
||||||
|
}
|
||||||
|
return map;
|
||||||
|
}
|
||||||
|
|
||||||
|
public BytePointer toBytePointer() throws PythonException{
|
||||||
|
if (Python.isinstance(this, Python.bytesType())){
|
||||||
|
PyObject byteArray = PyByteArray_FromObject(nativePythonObject);
|
||||||
|
return PyByteArray_AsString(byteArray);
|
||||||
|
|
||||||
|
}
|
||||||
|
else if (Python.isinstance(this, Python.bytearrayType())){
|
||||||
|
return PyByteArray_AsString(nativePythonObject);
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
PyObject ctypes = PyImport_ImportModule("ctypes");
|
||||||
|
PyObject cArrType = PyObject_GetAttrString(ctypes, "Array");
|
||||||
|
if (PyObject_IsInstance(nativePythonObject, cArrType) != 0){
|
||||||
|
PyObject cVoidP = PyObject_GetAttrString(ctypes, "c_void_p");
|
||||||
|
PyObject cast = PyObject_GetAttrString(ctypes, "cast");
|
||||||
|
PyObject argsTuple = PyTuple_New(2);
|
||||||
|
PyTuple_SetItem(argsTuple, 0, nativePythonObject);
|
||||||
|
PyTuple_SetItem(argsTuple, 1, cVoidP);
|
||||||
|
PyObject voidPtr = PyObject_Call(cast, argsTuple, null);
|
||||||
|
PyObject pyAddress = PyObject_GetAttrString(voidPtr, "value");
|
||||||
|
long address = PyLong_AsLong(pyAddress);
|
||||||
|
long size = PyObject_Size(nativePythonObject);
|
||||||
|
Py_DecRef(ctypes);
|
||||||
|
Py_DecRef(cArrType);
|
||||||
|
Py_DecRef(argsTuple);
|
||||||
|
Py_DecRef(voidPtr);
|
||||||
|
Py_DecRef(pyAddress);
|
||||||
|
Pointer ptr = NativeOpsHolder.getInstance().getDeviceNativeOps().pointerForAddress(address);
|
||||||
|
ptr = ptr.limit(size);
|
||||||
|
ptr = ptr.capacity(size);
|
||||||
|
return new BytePointer(ptr);
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
throw new PythonException("Expected bytes, bytearray or ctypesArray. Received " + Python.type(this).toString());
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
public boolean isNone() {
|
||||||
|
return nativePythonObject == null;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
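Taken together, the accessors above let Java code drive arbitrary CPython objects without touching the C API directly: attr() wraps PyObject_GetAttrString, the call() overloads build the argument tuple and kwargs dict, and the to*() methods convert results back to Java. A minimal sketch of how the pieces compose, assuming the interpreter is initialized and the GIL is held (handled elsewhere in datavec-python); the `builtins`/`sorted` names are purely illustrative:

    import java.util.Arrays;
    import java.util.List;

    class PythonObjectSketch {
        static List sortWithPython() throws PythonException {
            PythonObject builtins = Python.importModule("builtins");
            PythonObject sorted = builtins.attr("sorted");             // PyObject_GetAttrString under the hood
            PythonObject result = sorted.call(Arrays.asList(3, 1, 2)); // List -> tuple -> PyObject_Call
            List javaList = result.toList();                           // element-wise conversion back to Java
            result.del();                                              // Py_DecRef once we are done
            return javaList;
        }
    }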
@ -24,10 +24,13 @@ import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.Transform;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.writable.*;
import org.json.JSONPropertyIgnore;
import org.nd4j.base.Preconditions;
import org.nd4j.jackson.objectmapper.holder.ObjectMapperHolder;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;

import java.io.IOException;
import java.io.InputStream;

@ -58,6 +61,7 @@ public class PythonTransform implements Transform {
    private String outputDict;
    private boolean returnAllVariables;
    private boolean setupAndRun = false;
    private PythonJob pythonJob;

    @Builder

@ -78,23 +82,12 @@ public class PythonTransform implements Transform {
        this.inputs = inputs;
        if (outputs != null)
            this.outputs = outputs;

        if (name != null)
            this.name = name;
        if (outputDict != null) {
            this.outputDict = outputDict;
            this.outputs = new PythonVariables();
            this.outputs.addDict(outputDict);

            String helpers;
            try(InputStream is = new ClassPathResource("pythonexec/serialize_array.py").getInputStream()) {
                helpers = IOUtils.toString(is, Charset.defaultCharset());
            }catch (IOException e){
                throw new RuntimeException("Error reading python code");
            }
            this.code += "\n\n" + helpers;
            this.code += "\n" + outputDict + " = __recursive_serialize_dict(" + outputDict + ")";
        }

        try {

@ -114,6 +107,16 @@ public class PythonTransform implements Transform {
        } catch (Exception e) {
            throw new IllegalStateException(e);
        }
        try{
            pythonJob = PythonJob.builder()
                    .name("a" + UUID.randomUUID().toString().replace("-", "_"))
                    .code(code)
                    .setupRunMode(setupAndRun)
                    .build();
        }
        catch(Exception e){
            throw new IllegalStateException("Error creating python job: " + e);
        }
    }

@ -158,43 +161,26 @@ public class PythonTransform implements Transform {
    }

    @Override
    public List<Writable> map(List<Writable> writables) {
        PythonVariables pyInputs = getPyInputsFromWritables(writables);
        Preconditions.checkNotNull(pyInputs, "Inputs must not be null!");

        try {
            if (returnAllVariables) {
                if (setupAndRun){
                return getWritablesFromPyOutputs(pythonJob.execAndReturnAllVariables(pyInputs));
                    return getWritablesFromPyOutputs(PythonExecutioner.execWithSetupRunAndReturnAllVariables(code, pyInputs));
                }
                return getWritablesFromPyOutputs(PythonExecutioner.execAndReturnAllVariables(code, pyInputs));
            }

            if (outputDict != null) {
                if (setupAndRun) {
                pythonJob.exec(pyInputs, outputs);
                    PythonExecutioner.execWithSetupAndRun(this, pyInputs);
                }else{
                    PythonExecutioner.exec(this, pyInputs);
                }
                PythonVariables out = PythonUtils.expandInnerDict(outputs, outputDict);
                return getWritablesFromPyOutputs(out);
            }
            else {
                if (setupAndRun) {
                    PythonExecutioner.execWithSetupAndRun(code, pyInputs, outputs);
            } else {
                PythonExecutioner.exec(code, pyInputs, outputs);
                pythonJob.exec(pyInputs, outputs);
                }

                return getWritablesFromPyOutputs(outputs);
            }

        }
        } catch (Exception e) {
        catch (Exception e){
            throw new RuntimeException(e);
        }
    }

@ -208,6 +194,7 @@ public class PythonTransform implements Transform {
    public String outputColumnName() {
        return outputColumnNames()[0];
    }

    @Override
    public String[] columnNames() {
        return outputs.getVariables();

@ -229,22 +216,19 @@ public class PythonTransform implements Transform {
        for (String name : inputs.getVariables()) {
            int colIdx = inputSchema.getIndexOfColumn(name);
            Writable w = writables.get(colIdx);
            PythonVariables.Type pyType = inputs.getType(name);
            PythonType pyType = inputs.getType(name);
            switch (pyType){
            switch (pyType.getName()) {
                case INT:
                    if (w instanceof LongWritable) {
                        ret.addInt(name, ((LongWritable) w).get());
                    }
                    } else {
                    else{
                        ret.addInt(name, ((IntWritable) w).get());
                    }
                    break;
                case FLOAT:
                    if (w instanceof DoubleWritable) {
                        ret.addFloat(name, ((DoubleWritable) w).get());
                    }
                    } else {
                    else{
                        ret.addFloat(name, ((FloatWritable) w).get());
                    }
                    break;

@ -254,6 +238,9 @@ public class PythonTransform implements Transform {
                case NDARRAY:
                    ret.addNDArray(name, ((NDArrayWritable) w).get());
                    break;
                case BOOL:
                    ret.addBool(name, ((BooleanWritable) w).get());
                    break;
                default:
                    throw new RuntimeException("Unsupported input type:" + pyType);
            }

@ -269,8 +256,8 @@ public class PythonTransform implements Transform {
        Schema.Builder schemaBuilder = new Schema.Builder();
        for (int i = 0; i < varNames.length; i++) {
            String name = varNames[i];
            PythonVariables.Type pyType = pyOuts.getType(name);
            PythonType pyType = pyOuts.getType(name);
            switch (pyType){
            switch (pyType.getName()) {
                case INT:
                    schemaBuilder.addColumnLong(name);
                    break;

@ -283,11 +270,14 @@ public class PythonTransform implements Transform {
                    schemaBuilder.addColumnString(name);
                    break;
                case NDARRAY:
                    NumpyArray arr = pyOuts.getNDArrayValue(name);
                    INDArray arr = pyOuts.getNDArrayValue(name);
                    schemaBuilder.addColumnNDArray(name, arr.getShape());
                    schemaBuilder.addColumnNDArray(name, arr.shape());
                    break;
                case BOOL:
                    schemaBuilder.addColumnBoolean(name);
                    break;
                default:
                    throw new IllegalStateException("Unable to support type " + pyType.name());
                    throw new IllegalStateException("Unable to support type " + pyType.getName());
            }
        }
        this.outputSchema = schemaBuilder.build();

@ -295,9 +285,9 @@ public class PythonTransform implements Transform {

        for (int i = 0; i < varNames.length; i++) {
            String name = varNames[i];
            PythonVariables.Type pyType = pyOuts.getType(name);
            PythonType pyType = pyOuts.getType(name);

            switch (pyType){
            switch (pyType.getName()) {
                case INT:
                    out.add(new LongWritable(pyOuts.getIntValue(name)));
                    break;

@ -308,8 +298,8 @@ public class PythonTransform implements Transform {
                    out.add(new Text(pyOuts.getStrValue(name)));
                    break;
                case NDARRAY:
                    NumpyArray arr = pyOuts.getNDArrayValue(name);
                    INDArray arr = pyOuts.getNDArrayValue(name);
                    out.add(new NDArrayWritable(arr.getNd4jArray()));
                    out.add(new NDArrayWritable(arr));
                    break;
                case DICT:
                    Map<?, ?> dictValue = pyOuts.getDictValue(name);

@ -327,21 +317,22 @@ public class PythonTransform implements Transform {
                    }
                    break;
                case LIST:
                    Object[] listValue = pyOuts.getListValue(name);
                    Object[] listValue = pyOuts.getListValue(name).toArray();
                    try {
                        out.add(new Text(ObjectMapperHolder.getJsonMapper().writeValueAsString(listValue)));
                    } catch (JsonProcessingException e) {
                        throw new IllegalStateException("Unable to serialize list value " + name + " to json!");
                    }
                    break;
                case BOOL:
                    out.add(new BooleanWritable(pyOuts.getBooleanValue(name)));
                    break;
                default:
                    throw new IllegalStateException("Unable to support type " + pyType.name());
                    throw new IllegalStateException("Unable to support type " + pyType.getName());
            }
        }
        return out;
    }

}
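The upshot of the PythonTransform changes: every execution path in map() now funnels through the single PythonJob built once in the constructor, instead of re-submitting the raw code string to PythonExecutioner on each call. A flattened restatement of the new dispatch, for reading convenience only (names as in the diff above; exception wrapping elided):

    // not part of the commit -- just the new map() control flow, linearized
    if (returnAllVariables) {
        // one call hands back every variable left in the python namespace
        return getWritablesFromPyOutputs(pythonJob.execAndReturnAllVariables(pyInputs));
    }
    pythonJob.exec(pyInputs, outputs);                          // same job serves setup/run and plain mode
    PythonVariables resolved = (outputDict != null)
            ? PythonUtils.expandInnerDict(outputs, outputDict)  // unpack the declared dict output
            : outputs;
    return getWritablesFromPyOutputs(resolved);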
@ -0,0 +1,238 @@
/*******************************************************************************
 * Copyright (c) 2020 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.datavec.python;

import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.linalg.api.ndarray.INDArray;

import java.util.Arrays;
import java.util.List;
import java.util.Map;

import static org.datavec.python.Python.importModule;


/**
 *
 * @param <T> Corresponding Java type for the Python type
 */
public abstract class PythonType<T> {

    public abstract T toJava(PythonObject pythonObject) throws PythonException;
    private final TypeName typeName;

    enum TypeName{
        STR,
        INT,
        FLOAT,
        BOOL,
        LIST,
        DICT,
        NDARRAY,
        BYTES
    }
    private PythonType(TypeName typeName){
        this.typeName = typeName;
    }
    public TypeName getName(){return typeName;}
    public String toString(){
        return getName().name();
    }
    public static PythonType valueOf(String typeName) throws PythonException{
        try{
            TypeName.valueOf(typeName); // validate against the TypeName enum; calling valueOf on the String itself was a no-op
        } catch (IllegalArgumentException iae){
            throw new PythonException("Invalid python type: " + typeName, iae);
        }
        try{
            return (PythonType)PythonType.class.getField(typeName).get(null); // shouldn't fail
        } catch (Exception e){
            throw new RuntimeException(e);
        }
    }
    public static PythonType valueOf(TypeName typeName){
        try{
            return valueOf(typeName.name()); // shouldn't fail
        }catch (PythonException pe){
            throw new RuntimeException(pe);
        }
    }

    /**
     * Since multiple java types can map to the same python type,
     * this method "normalizes" all supported incoming objects to T
     *
     * @param object object to be converted to type T
     * @return
     */
    public T convert(Object object) throws PythonException {
        return (T) object;
    }

    public static final PythonType<String> STR = new PythonType<String>(TypeName.STR) {
        @Override
        public String toJava(PythonObject pythonObject) throws PythonException {
            if (!Python.isinstance(pythonObject, Python.strType())) {
                throw new PythonException("Expected variable to be str, but was " + Python.type(pythonObject));
            }
            return pythonObject.toString();
        }

        @Override
        public String convert(Object object) {
            return object.toString();
        }
    };

    public static final PythonType<Long> INT = new PythonType<Long>(TypeName.INT) {
        @Override
        public Long toJava(PythonObject pythonObject) throws PythonException {
            if (!Python.isinstance(pythonObject, Python.intType())) {
                throw new PythonException("Expected variable to be int, but was " + Python.type(pythonObject));
            }
            return pythonObject.toLong();
        }

        @Override
        public Long convert(Object object) throws PythonException {
            if (object instanceof Number) {
                return ((Number) object).longValue();
            }
            throw new PythonException("Unable to cast " + object + " to Long.");
        }
    };

    public static final PythonType<Double> FLOAT = new PythonType<Double>(TypeName.FLOAT) {
        @Override
        public Double toJava(PythonObject pythonObject) throws PythonException {
            if (!Python.isinstance(pythonObject, Python.floatType())) {
                throw new PythonException("Expected variable to be float, but was " + Python.type(pythonObject));
            }
            return pythonObject.toDouble();
        }

        @Override
        public Double convert(Object object) throws PythonException {
            if (object instanceof Number) {
                return ((Number) object).doubleValue();
            }
            throw new PythonException("Unable to cast " + object + " to Double.");
        }
    };

    public static final PythonType<Boolean> BOOL = new PythonType<Boolean>(TypeName.BOOL) {
        @Override
        public Boolean toJava(PythonObject pythonObject) throws PythonException {
            if (!Python.isinstance(pythonObject, Python.boolType())) {
                throw new PythonException("Expected variable to be bool, but was " + Python.type(pythonObject));
            }
            return pythonObject.toBoolean();
        }

        @Override
        public Boolean convert(Object object) throws PythonException {
            if (object instanceof Number) {
                return ((Number) object).intValue() != 0;
            } else if (object instanceof Boolean) {
                return (Boolean) object;
            }
            throw new PythonException("Unable to cast " + object + " to Boolean.");
        }
    };

    public static final PythonType<List> LIST = new PythonType<List>(TypeName.LIST) {
        @Override
        public List toJava(PythonObject pythonObject) throws PythonException {
            if (!Python.isinstance(pythonObject, Python.listType())) {
                throw new PythonException("Expected variable to be list, but was " + Python.type(pythonObject));
            }
            return pythonObject.toList();
        }

        @Override
        public List convert(Object object) throws PythonException {
            if (object instanceof java.util.List) {
                return (List) object;
            } else if (object instanceof org.json.JSONArray) {
                org.json.JSONArray jsonArray = (org.json.JSONArray) object;
                return jsonArray.toList();
            } else if (object instanceof Object[]) {
                return Arrays.asList((Object[]) object);
            }
            throw new PythonException("Unable to cast " + object + " to List.");
        }
    };

    public static final PythonType<Map> DICT = new PythonType<Map>(TypeName.DICT) {
        @Override
        public Map toJava(PythonObject pythonObject) throws PythonException {
            if (!Python.isinstance(pythonObject, Python.dictType())) {
                throw new PythonException("Expected variable to be dict, but was " + Python.type(pythonObject));
            }
            return pythonObject.toMap();
        }

        @Override
        public Map convert(Object object) throws PythonException {
            if (object instanceof Map) {
                return (Map) object;
            }
            throw new PythonException("Unable to cast " + object + " to Map.");
        }
    };

    public static final PythonType<INDArray> NDARRAY = new PythonType<INDArray>(TypeName.NDARRAY) {
        @Override
        public INDArray toJava(PythonObject pythonObject) throws PythonException {
            PythonObject np = importModule("numpy");
            if (!Python.isinstance(pythonObject, np.attr("ndarray"), np.attr("generic"))) {
                throw new PythonException("Expected variable to be numpy.ndarray, but was " + Python.type(pythonObject));
            }
            return pythonObject.toNumpy().getNd4jArray();
        }

        @Override
        public INDArray convert(Object object) throws PythonException {
            if (object instanceof INDArray) {
                return (INDArray) object;
            } else if (object instanceof NumpyArray) {
                return ((NumpyArray) object).getNd4jArray();
            }
            throw new PythonException("Unable to cast " + object + " to INDArray.");
        }
    };

    public static final PythonType<BytePointer> BYTES = new PythonType<BytePointer>(TypeName.BYTES) {
        @Override
        public BytePointer toJava(PythonObject pythonObject) throws PythonException {
            return pythonObject.toBytePointer();
        }

        @Override
        public BytePointer convert(Object object) throws PythonException {
            if (object instanceof BytePointer) {
                return (BytePointer) object;
            } else if (object instanceof Pointer) {
                return new BytePointer((Pointer) object);
            }
            throw new PythonException("Unable to cast " + object + " to BytePointer.");
        }
    };
}
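Each PythonType singleton pairs a strict toJava() check (the python object must already have the matching type) with a permissive convert() that normalizes incoming Java values. A short illustration of the convert() behaviour, assuming the constants defined above (convert may throw PythonException, so callers declare it):

    import java.util.List;

    class PythonTypeSketch {
        static void demo() throws PythonException {
            Long l = PythonType.INT.convert(42);                       // any Number widens to Long
            Double d = PythonType.FLOAT.convert(1);                    // Number -> doubleValue() == 1.0
            Boolean b = PythonType.BOOL.convert(0);                    // Number -> intValue() != 0, so false
            List items = PythonType.LIST.convert(new Object[]{1, 2});  // Object[] -> Arrays.asList
            PythonType byName = PythonType.valueOf("NDARRAY");         // reflective lookup of the public field
        }
    }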
@ -24,28 +24,30 @@ public class PythonUtils {
     * Create a {@link Schema}
     * from {@link PythonVariables}.
     * Types are mapped to types of the same name.
     *
     * @param input the input {@link PythonVariables}
     * @return the output {@link Schema}
     */
    public static Schema fromPythonVariables(PythonVariables input) {
        Schema.Builder schemaBuilder = new Schema.Builder();
        Preconditions.checkState(input.getVariables() != null && input.getVariables().length > 0, "Input must have variables. Found none.");
        for(Map.Entry<String,PythonVariables.Type> entry : input.getVars().entrySet()) {
        for (String varName: input.getVariables()) {
            switch(entry.getValue()) {
            switch (input.getType(varName).getName()) {
                case INT:
                    schemaBuilder.addColumnInteger(entry.getKey());
                    schemaBuilder.addColumnInteger(varName);
                    break;
                case STR:
                    schemaBuilder.addColumnString(entry.getKey());
                    schemaBuilder.addColumnString(varName);
                    break;
                case FLOAT:
                    schemaBuilder.addColumnFloat(entry.getKey());
                    schemaBuilder.addColumnFloat(varName);
                    break;
                case NDARRAY:
                    schemaBuilder.addColumnNDArray(entry.getKey(),null);
                    schemaBuilder.addColumnNDArray(varName, null);
                    break;
                case BOOL:
                    schemaBuilder.addColumn(new BooleanMetaData(entry.getKey()));
                    schemaBuilder.addColumn(new BooleanMetaData(varName));
            }
        }

@ -56,6 +58,7 @@ public class PythonUtils {
     * Create a {@link Schema} from an input
     * {@link PythonVariables}
     * Types are mapped to types of the same name
     *
     * @param input the input schema
     * @return the output python variables.
     */

@ -66,24 +69,25 @@ public class PythonUtils {
            ColumnType columnType = input.getType(i);
            switch (columnType) {
                case NDArray:
                    ret.add(currColumnName, PythonVariables.Type.NDARRAY);
                    ret.add(currColumnName, PythonType.NDARRAY);
                    break;
                case Boolean:
                    ret.add(currColumnName, PythonVariables.Type.BOOL);
                    ret.add(currColumnName, PythonType.BOOL);
                    break;
                case Categorical:
                case String:
                    ret.add(currColumnName, PythonVariables.Type.STR);
                    ret.add(currColumnName, PythonType.STR);
                    break;
                case Double:
                case Float:
                    ret.add(currColumnName, PythonVariables.Type.FLOAT);
                    ret.add(currColumnName, PythonType.FLOAT);
                    break;
                case Integer:
                case Long:
                    ret.add(currColumnName, PythonVariables.Type.INT);
                    ret.add(currColumnName, PythonType.INT);
                    break;
                case Bytes:
                    ret.add(currColumnName, PythonType.BYTES);
                    break;
                case Time:
                    throw new UnsupportedOperationException("Unable to process dates with python yet.");

@ -92,9 +96,11 @@ public class PythonUtils {

        return ret;
    }

    /**
     * Convert a {@link Schema}
     * to {@link PythonVariables}
     *
     * @param schema the input schema
     * @return the output {@link PythonVariables} where each
     * name in the map is associated with a column name in the schema.

@ -122,6 +128,9 @@ public class PythonUtils {
                case NDArray:
                    pyVars.addNDArray(colName);
                    break;
                case Boolean:
                    pyVars.addBool(colName);
                    break;
                default:
                    throw new Exception("Unsupported python input type: " + colType.toString());
            }

@ -136,20 +145,15 @@ public class PythonUtils {
        DataType dtype;
        if (dtypeName.equals("float64")) {
            dtype = DataType.DOUBLE;
        }
        } else if (dtypeName.equals("float32")) {
        else if (dtypeName.equals("float32")){
            dtype = DataType.FLOAT;
        }
        } else if (dtypeName.equals("int16")) {
        else if (dtypeName.equals("int16")){
            dtype = DataType.SHORT;
        }
        } else if (dtypeName.equals("int32")) {
        else if (dtypeName.equals("int32")){
            dtype = DataType.INT;
        }
        } else if (dtypeName.equals("int64")) {
        else if (dtypeName.equals("int64")){
            dtype = DataType.LONG;
        }
        } else {
        else{
            throw new RuntimeException("Unsupported array type " + dtypeName + ".");
        }
        List shapeList = (List) map.get("shape");

@ -164,7 +168,7 @@ public class PythonUtils {
            stride[i] = (Long) strideList.get(i);
        }
        long address = (Long) map.get("address");
        NumpyArray numpyArray = new NumpyArray(address, shape, stride, true,dtype);
        NumpyArray numpyArray = new NumpyArray(address, shape, stride, dtype, true);
        return numpyArray;
    }

@ -179,34 +183,26 @@ public class PythonUtils {
                if (map.containsKey("_is_numpy_array")) {
                    pyvars2.addNDArray(subkey, mapToNumpyArray(map));
                }
                } else {
                else{
                    pyvars2.addDict(subkey, (Map) value);
                }
            }
            } else if (value instanceof List) {
            else if (value instanceof List){
                pyvars2.addList(subkey, ((List) value).toArray());
            }
            } else if (value instanceof String) {
            else if (value instanceof String){
                System.out.println((String) value);
                pyvars2.addStr(subkey, (String) value);
            }
            } else if (value instanceof Integer || value instanceof Long) {
            else if (value instanceof Integer || value instanceof Long) {
                Number number = (Number) value;
                pyvars2.addInt(subkey, number.intValue());
            }
            } else if (value instanceof Float || value instanceof Double) {
            else if (value instanceof Float || value instanceof Double) {
                Number number = (Number) value;
                pyvars2.addFloat(subkey, number.doubleValue());
            }
            } else if (value instanceof NumpyArray) {
            else if (value instanceof NumpyArray){
                pyvars2.addNDArray(subkey, (NumpyArray) value);
            }
            } else if (value == null) {
            else if (value == null){
                pyvars2.addStr(subkey, "None"); // FixMe
            }
            } else {
            else{
                throw new RuntimeException("Unsupported type!" + value);
            }
        }

@ -233,15 +229,15 @@ public class PythonUtils {
                JSONObject jsonobj2 = (JSONObject) value;
                if (jsonobj2.has("_is_numpy_array")) {
                    value = jsonToNumpyArray(jsonobj2);
                }
                } else {
                else{
                    value = toMap(jsonobj2);
                }
            }
            map.put(key, value);
        } return map;
        }
        return map;
    }

@ -270,35 +266,30 @@ public class PythonUtils {
        DataType dtype;
        if (dtypeName.equals("float64")) {
            dtype = DataType.DOUBLE;
        }
        } else if (dtypeName.equals("float32")) {
        else if (dtypeName.equals("float32")){
            dtype = DataType.FLOAT;
        }
        } else if (dtypeName.equals("int16")) {
        else if (dtypeName.equals("int16")){
            dtype = DataType.SHORT;
        }
        } else if (dtypeName.equals("int32")) {
        else if (dtypeName.equals("int32")){
            dtype = DataType.INT;
        }
        } else if (dtypeName.equals("int64")) {
        else if (dtypeName.equals("int64")){
            dtype = DataType.LONG;
        }
        } else {
        else{
            throw new RuntimeException("Unsupported array type " + dtypeName + ".");
        }
        List shapeList = (List)map.get("shape");
        List shapeList = map.getJSONArray("shape").toList();
        long[] shape = new long[shapeList.size()];
        for (int i = 0; i < shape.length; i++) {
            shape[i] = (Long)shapeList.get(i);
            shape[i] = ((Number) shapeList.get(i)).longValue();
        }
        List strideList = (List)map.get("stride");
        List strideList = map.getJSONArray("stride").toList(); // key assumed to be "stride"; both sides originally read "shape" here, which looked like a copy-paste slip
        long[] stride = new long[strideList.size()];
        for (int i = 0; i < stride.length; i++) {
            stride[i] = (Long)strideList.get(i);
            stride[i] = ((Number) strideList.get(i)).longValue();
        }
        long address = (Long)map.get("address");
        long address = ((Number) map.get("address")).longValue();
        NumpyArray numpyArray = new NumpyArray(address, shape, stride, true,dtype);
        NumpyArray numpyArray = new NumpyArray(address, shape, stride, dtype, true);
        return numpyArray;
    }

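mapToNumpyArray and jsonToNumpyArray reconstruct the same off-heap array from a small metadata payload. Sketched below as the map the java side is assumed to receive from the python serializer; the shape/stride/address fields appear in the code above, while the dtype key and the concrete values are placeholders:

    import java.util.Arrays;
    import java.util.HashMap;
    import java.util.Map;

    class NumpyPayloadSketch {
        static Map<String, Object> payload(long nativeAddress) {
            Map<String, Object> numpyMap = new HashMap<>();
            numpyMap.put("_is_numpy_array", true);
            numpyMap.put("dtype", "float32");              // key assumed; one of float64/float32/int16/int32/int64
            numpyMap.put("shape", Arrays.asList(2L, 3L));  // read back as numeric entries
            numpyMap.put("stride", Arrays.asList(3L, 1L));
            numpyMap.put("address", nativeAddress);        // raw pointer into the numpy buffer
            return numpyMap;
        }
    }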
@ -17,13 +17,19 @@
|
||||||
package org.datavec.python;
|
package org.datavec.python;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import org.bytedeco.javacpp.BytePointer;
|
||||||
|
import org.bytedeco.javacpp.Pointer;
|
||||||
import org.json.JSONObject;
|
import org.json.JSONObject;
|
||||||
import org.json.JSONArray;
|
import org.json.JSONArray;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.nativeblas.NativeOpsHolder;
|
||||||
|
|
||||||
import java.io.Serializable;
|
import java.io.Serializable;
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Holds python variable names, types and values.
|
* Holds python variable names, types and values.
|
||||||
* Also handles mapping from java types to python types.
|
* Also handles mapping from java types to python types.
|
||||||
|
@ -34,40 +40,30 @@ import java.util.*;
|
||||||
@lombok.Data
|
@lombok.Data
|
||||||
public class PythonVariables implements java.io.Serializable {
|
public class PythonVariables implements java.io.Serializable {
|
||||||
|
|
||||||
public enum Type{
|
|
||||||
BOOL,
|
|
||||||
STR,
|
|
||||||
INT,
|
|
||||||
FLOAT,
|
|
||||||
NDARRAY,
|
|
||||||
LIST,
|
|
||||||
FILE,
|
|
||||||
DICT
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
private java.util.Map<String, String> strVariables = new java.util.LinkedHashMap<>();
|
private java.util.Map<String, String> strVariables = new java.util.LinkedHashMap<>();
|
||||||
private java.util.Map<String, Long> intVariables = new java.util.LinkedHashMap<>();
|
private java.util.Map<String, Long> intVariables = new java.util.LinkedHashMap<>();
|
||||||
private java.util.Map<String, Double> floatVariables = new java.util.LinkedHashMap<>();
|
private java.util.Map<String, Double> floatVariables = new java.util.LinkedHashMap<>();
|
||||||
private java.util.Map<String, Boolean> boolVariables = new java.util.LinkedHashMap<>();
|
private java.util.Map<String, Boolean> boolVariables = new java.util.LinkedHashMap<>();
|
||||||
private java.util.Map<String, NumpyArray> ndVars = new java.util.LinkedHashMap<>();
|
private java.util.Map<String, INDArray> ndVars = new java.util.LinkedHashMap<>();
|
||||||
private java.util.Map<String, Object[]> listVariables = new java.util.LinkedHashMap<>();
|
private java.util.Map<String, List> listVariables = new java.util.LinkedHashMap<>();
|
||||||
private java.util.Map<String, String> fileVariables = new java.util.LinkedHashMap<>();
|
private java.util.Map<String, BytePointer> bytesVariables = new java.util.LinkedHashMap<>();
|
||||||
private java.util.Map<String, java.util.Map<?, ?>> dictVariables = new java.util.LinkedHashMap<>();
|
private java.util.Map<String, java.util.Map<?, ?>> dictVariables = new java.util.LinkedHashMap<>();
|
||||||
private java.util.Map<String, Type> vars = new java.util.LinkedHashMap<>();
|
private java.util.Map<String, PythonType.TypeName> vars = new java.util.LinkedHashMap<>();
|
||||||
private java.util.Map<Type, java.util.Map> maps = new java.util.LinkedHashMap<>();
|
private java.util.Map<PythonType.TypeName, java.util.Map> maps = new java.util.LinkedHashMap<>();
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns a copy of the variable
|
* Returns a copy of the variable
|
||||||
* schema in this array without the values
|
* schema in this array without the values
|
||||||
|
*
|
||||||
* @return an empty variables clone
|
* @return an empty variables clone
|
||||||
* with no values
|
* with no values
|
||||||
*/
|
*/
|
||||||
public PythonVariables copySchema() {
|
public PythonVariables copySchema() {
|
||||||
PythonVariables ret = new PythonVariables();
|
PythonVariables ret = new PythonVariables();
|
||||||
for (String varName : getVariables()) {
|
for (String varName : getVariables()) {
|
||||||
Type type = getType(varName);
|
PythonType type = getType(varName);
|
||||||
ret.add(varName, type);
|
ret.add(varName, type);
|
||||||
}
|
}
|
||||||
return ret;
|
return ret;
|
||||||
|
@ -77,21 +73,19 @@ public class PythonVariables implements java.io.Serializable {
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
public PythonVariables() {
|
public PythonVariables() {
|
||||||
maps.put(PythonVariables.Type.BOOL, boolVariables);
|
maps.put(PythonType.TypeName.BOOL, boolVariables);
|
||||||
maps.put(PythonVariables.Type.STR, strVariables);
|
maps.put(PythonType.TypeName.STR, strVariables);
|
||||||
maps.put(PythonVariables.Type.INT, intVariables);
|
maps.put(PythonType.TypeName.INT, intVariables);
|
||||||
maps.put(PythonVariables.Type.FLOAT, floatVariables);
|
maps.put(PythonType.TypeName.FLOAT, floatVariables);
|
||||||
maps.put(PythonVariables.Type.NDARRAY, ndVars);
|
maps.put(PythonType.TypeName.NDARRAY, ndVars);
|
||||||
maps.put(PythonVariables.Type.LIST, listVariables);
|
maps.put(PythonType.TypeName.LIST, listVariables);
|
||||||
maps.put(PythonVariables.Type.FILE, fileVariables);
|
maps.put(PythonType.TypeName.DICT, dictVariables);
|
||||||
maps.put(PythonVariables.Type.DICT, dictVariables);
|
maps.put(PythonType.TypeName.BYTES, bytesVariables);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
|
||||||
* @return true if there are no variables.
|
* @return true if there are no variables.
|
||||||
*/
|
*/
|
||||||
public boolean isEmpty() {
|
public boolean isEmpty() {
|
||||||
|
@ -100,12 +94,11 @@ public class PythonVariables implements java.io.Serializable {
|
||||||
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
|
||||||
* @param name Name of the variable
|
* @param name Name of the variable
|
||||||
* @param type Type of the variable
|
* @param type Type of the variable
|
||||||
*/
|
*/
|
||||||
public void add(String name, Type type){
|
public void add(String name, PythonType type) {
|
||||||
switch (type){
|
switch (type.getName()) {
|
||||||
case BOOL:
|
case BOOL:
|
||||||
addBool(name);
|
addBool(name);
|
||||||
break;
|
break;
|
||||||
|
@ -124,21 +117,21 @@ public class PythonVariables implements java.io.Serializable {
|
||||||
case LIST:
|
case LIST:
|
||||||
addList(name);
|
addList(name);
|
||||||
break;
|
break;
|
||||||
case FILE:
|
|
||||||
addFile(name);
|
|
||||||
break;
|
|
||||||
case DICT:
|
case DICT:
|
||||||
addDict(name);
|
addDict(name);
|
||||||
|
break;
|
||||||
|
case BYTES:
|
||||||
|
addBytes(name);
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
|
||||||
* @param name name of the variable
|
* @param name name of the variable
|
||||||
* @param type type of the variable
|
* @param type type of the variable
|
||||||
* @param value value of the variable (must be instance of expected type)
|
* @param value value of the variable (must be instance of expected type)
|
||||||
*/
|
*/
|
||||||
public void add(String name, Type type, Object value) {
|
public void add(String name, PythonType type, Object value) throws PythonException {
|
||||||
add(name, type);
|
add(name, type);
|
||||||
setValue(name, value);
|
setValue(name, value);
|
||||||
}
|
}
|
||||||
|
@ -148,10 +141,11 @@ public class PythonVariables implements java.io.Serializable {
|
||||||
* Add a null variable to
|
* Add a null variable to
|
||||||
* the set of variables
|
* the set of variables
|
||||||
* to describe the type but no value
|
* to describe the type but no value
|
||||||
|
*
|
||||||
* @param name the field to add
|
* @param name the field to add
|
||||||
*/
|
*/
|
||||||
public void addDict(String name) {
|
public void addDict(String name) {
|
||||||
vars.put(name, PythonVariables.Type.DICT);
|
vars.put(name, PythonType.TypeName.DICT);
|
||||||
dictVariables.put(name, null);
|
dictVariables.put(name, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -159,10 +153,11 @@ public class PythonVariables implements java.io.Serializable {
|
||||||
* Add a null variable to
|
* Add a null variable to
|
||||||
* the set of variables
|
* the set of variables
|
||||||
* to describe the type but no value
|
* to describe the type but no value
|
||||||
|
*
|
||||||
* @param name the field to add
|
* @param name the field to add
|
||||||
*/
|
*/
|
||||||
public void addBool(String name) {
|
public void addBool(String name) {
|
||||||
vars.put(name, PythonVariables.Type.BOOL);
|
vars.put(name, PythonType.TypeName.BOOL);
|
||||||
boolVariables.put(name, null);
|
boolVariables.put(name, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -170,10 +165,11 @@ public class PythonVariables implements java.io.Serializable {
|
||||||
* Add a null variable to
|
* Add a null variable to
|
||||||
* the set of variables
|
* the set of variables
|
||||||
* to describe the type but no value
|
* to describe the type but no value
|
||||||
|
*
|
||||||
* @param name the field to add
|
* @param name the field to add
|
||||||
*/
|
*/
|
||||||
public void addStr(String name) {
|
public void addStr(String name) {
|
||||||
vars.put(name, PythonVariables.Type.STR);
|
vars.put(name, PythonType.TypeName.STR);
|
||||||
strVariables.put(name, null);
|
strVariables.put(name, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -181,10 +177,11 @@ public class PythonVariables implements java.io.Serializable {
|
||||||
* Add a null variable to
|
* Add a null variable to
|
||||||
* the set of variables
|
* the set of variables
|
||||||
* to describe the type but no value
|
* to describe the type but no value
|
||||||
|
*
|
||||||
* @param name the field to add
|
* @param name the field to add
|
||||||
*/
|
*/
|
||||||
public void addInt(String name) {
|
public void addInt(String name) {
|
||||||
vars.put(name, PythonVariables.Type.INT);
|
vars.put(name, PythonType.TypeName.INT);
|
||||||
intVariables.put(name, null);
|
intVariables.put(name, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -192,10 +189,11 @@ public class PythonVariables implements java.io.Serializable {
|
||||||
* Add a null variable to
|
* Add a null variable to
|
||||||
* the set of variables
|
* the set of variables
|
||||||
* to describe the type but no value
|
* to describe the type but no value
|
||||||
|
*
|
||||||
* @param name the field to add
|
* @param name the field to add
|
||||||
*/
|
*/
|
||||||
public void addFloat(String name) {
|
public void addFloat(String name) {
|
||||||
vars.put(name, PythonVariables.Type.FLOAT);
|
vars.put(name, PythonType.TypeName.FLOAT);
|
||||||
floatVariables.put(name, null);
|
floatVariables.put(name, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -203,10 +201,11 @@ public class PythonVariables implements java.io.Serializable {
|
||||||
* Add a null variable to
|
* Add a null variable to
|
||||||
* the set of variables
|
* the set of variables
|
||||||
* to describe the type but no value
|
* to describe the type but no value
|
||||||
|
*
|
||||||
* @param name the field to add
|
* @param name the field to add
|
||||||
*/
|
*/
|
||||||
public void addNDArray(String name) {
|
public void addNDArray(String name) {
|
||||||
vars.put(name, PythonVariables.Type.NDARRAY);
|
vars.put(name, PythonType.TypeName.NDARRAY);
|
||||||
ndVars.put(name, null);
|
ndVars.put(name, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -214,87 +213,83 @@ public class PythonVariables implements java.io.Serializable {
|
||||||
* Add a null variable to
|
* Add a null variable to
|
||||||
* the set of variables
|
* the set of variables
|
||||||
* to describe the type but no value
|
* to describe the type but no value
|
||||||
|
*
|
||||||
* @param name the field to add
|
* @param name the field to add
|
||||||
*/
|
*/
|
||||||
public void addList(String name) {
|
public void addList(String name) {
|
||||||
vars.put(name, PythonVariables.Type.LIST);
|
vars.put(name, PythonType.TypeName.LIST);
|
||||||
listVariables.put(name, null);
|
listVariables.put(name, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Add a null variable to
|
|
||||||
* the set of variables
|
|
||||||
* to describe the type but no value
|
|
||||||
* @param name the field to add
|
|
||||||
*/
|
|
||||||
public void addFile(String name){
|
|
||||||
vars.put(name, PythonVariables.Type.FILE);
|
|
||||||
fileVariables.put(name, null);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Add a boolean variable to
|
* Add a boolean variable to
|
||||||
* the set of variables
|
* the set of variables
|
||||||
|
*
|
||||||
* @param name the field to add
|
* @param name the field to add
|
||||||
* @param value the value to add
|
* @param value the value to add
|
||||||
*/
|
*/
|
||||||
public void addBool(String name, boolean value) {
|
public void addBool(String name, boolean value) {
|
||||||
vars.put(name, PythonVariables.Type.BOOL);
|
vars.put(name, PythonType.TypeName.BOOL);
|
||||||
boolVariables.put(name, value);
|
boolVariables.put(name, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Add a string variable to
|
* Add a string variable to
|
||||||
* the set of variables
|
* the set of variables
|
||||||
|
*
|
||||||
* @param name the field to add
|
* @param name the field to add
|
||||||
* @param value the value to add
|
* @param value the value to add
|
||||||
*/
|
*/
|
||||||
public void addStr(String name, String value) {
|
public void addStr(String name, String value) {
|
||||||
vars.put(name, PythonVariables.Type.STR);
|
vars.put(name, PythonType.TypeName.STR);
|
||||||
strVariables.put(name, value);
|
strVariables.put(name, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Add an int variable to
|
* Add an int variable to
|
||||||
* the set of variables
|
* the set of variables
|
||||||
|
*
|
||||||
* @param name the field to add
|
* @param name the field to add
|
||||||
* @param value the value to add
|
* @param value the value to add
|
||||||
*/
|
*/
|
||||||
public void addInt(String name, int value) {
|
public void addInt(String name, int value) {
|
||||||
vars.put(name, PythonVariables.Type.INT);
|
vars.put(name, PythonType.TypeName.INT);
|
||||||
intVariables.put(name, (long) value);
|
intVariables.put(name, (long) value);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Add a long variable to
|
* Add a long variable to
|
||||||
* the set of variables
|
* the set of variables
|
||||||
|
*
|
||||||
* @param name the field to add
|
* @param name the field to add
|
||||||
* @param value the value to add
|
* @param value the value to add
|
||||||
*/
|
*/
|
||||||
public void addInt(String name, long value) {
|
public void addInt(String name, long value) {
|
||||||
vars.put(name, PythonVariables.Type.INT);
|
vars.put(name, PythonType.TypeName.INT);
|
||||||
intVariables.put(name, value);
|
intVariables.put(name, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Add a double variable to
|
* Add a double variable to
|
||||||
* the set of variables
|
* the set of variables
|
||||||
|
*
|
||||||
* @param name the field to add
|
* @param name the field to add
|
||||||
* @param value the value to add
|
* @param value the value to add
|
||||||
*/
|
*/
|
||||||
public void addFloat(String name, double value) {
|
public void addFloat(String name, double value) {
|
||||||
vars.put(name, PythonVariables.Type.FLOAT);
|
vars.put(name, PythonType.TypeName.FLOAT);
|
||||||
floatVariables.put(name, value);
|
floatVariables.put(name, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Add a float variable to
|
* Add a float variable to
|
||||||
* the set of variables
|
* the set of variables
|
||||||
|
*
|
||||||
* @param name the field to add
|
* @param name the field to add
|
||||||
* @param value the value to add
|
* @param value the value to add
|
||||||
*/
|
*/
|
||||||
public void addFloat(String name, float value) {
|
public void addFloat(String name, float value) {
|
||||||
vars.put(name, PythonVariables.Type.FLOAT);
|
vars.put(name, PythonType.TypeName.FLOAT);
|
||||||
floatVariables.put(name, (double) value);
|
floatVariables.put(name, (double) value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -302,11 +297,25 @@ public class PythonVariables implements java.io.Serializable {
|
||||||
* Add a null variable to
|
* Add a null variable to
|
||||||
* the set of variables
|
* the set of variables
|
||||||
* to describe the type but no value
|
* to describe the type but no value
|
||||||
|
*
|
||||||
* @param name the field to add
|
* @param name the field to add
|
||||||
* @param value the value to add
|
* @param value the value to add
|
||||||
*/
|
*/
|
||||||
public void addNDArray(String name, NumpyArray value) {
|
public void addNDArray(String name, NumpyArray value) {
|
||||||
vars.put(name, PythonVariables.Type.NDARRAY);
|
vars.put(name, PythonType.TypeName.NDARRAY);
|
||||||
|
ndVars.put(name, value.getNd4jArray());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Add an NDArray variable to
|
||||||
|
* the set of variables
|
||||||
|
* with the given value
|
||||||
|
*
|
||||||
|
* @param name the field to add
|
||||||
|
* @param value the value to add
|
||||||
|
*/
|
||||||
|
public void addNDArray(String name, org.nd4j.linalg.api.ndarray.INDArray value) {
|
||||||
|
vars.put(name, PythonType.TypeName.NDARRAY);
|
||||||
ndVars.put(name, value);
|
ndVars.put(name, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -314,117 +323,63 @@ public class PythonVariables implements java.io.Serializable {
|
||||||
* Add a null variable to
|
* Add a list variable to
|
||||||
* the set of variables
|
* the set of variables
|
||||||
* to describe the type but no value
|
* with the given value
|
||||||
* @param name the field to add
|
*
|
||||||
* @param value the value to add
|
|
||||||
*/
|
|
||||||
public void addNDArray(String name, org.nd4j.linalg.api.ndarray.INDArray value) {
|
|
||||||
vars.put(name, PythonVariables.Type.NDARRAY);
|
|
||||||
ndVars.put(name, new NumpyArray(value));
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Add a null variable to
|
|
||||||
* the set of variables
|
|
||||||
* to describe the type but no value
|
|
||||||
* @param name the field to add
|
* @param name the field to add
|
||||||
* @param value the value to add
|
* @param value the value to add
|
||||||
*/
|
*/
|
||||||
public void addList(String name, Object[] value) {
|
public void addList(String name, Object[] value) {
|
||||||
vars.put(name, PythonVariables.Type.LIST);
|
vars.put(name, PythonType.TypeName.LIST);
|
||||||
listVariables.put(name, value);
|
listVariables.put(name, Arrays.asList(value));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Add a null variable to
|
* Add a dict variable to
|
||||||
* the set of variables
|
* the set of variables
|
||||||
* to describe the type but no value
|
* with the given value
|
||||||
* @param name the field to add
|
*
|
||||||
* @param value the value to add
|
|
||||||
*/
|
|
||||||
public void addFile(String name, String value) {
|
|
||||||
vars.put(name, PythonVariables.Type.FILE);
|
|
||||||
fileVariables.put(name, value);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Add a null variable to
|
|
||||||
* the set of variables
|
|
||||||
* to describe the type but no value
|
|
||||||
* @param name the field to add
|
* @param name the field to add
|
||||||
* @param value the value to add
|
* @param value the value to add
|
||||||
*/
|
*/
|
||||||
public void addDict(String name, java.util.Map value) {
|
public void addDict(String name, java.util.Map value) {
|
||||||
vars.put(name, PythonVariables.Type.DICT);
|
vars.put(name, PythonType.TypeName.DICT);
|
||||||
dictVariables.put(name, value);
|
dictVariables.put(name, value);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public void addBytes(String name){
|
||||||
|
vars.put(name, PythonType.TypeName.BYTES);
|
||||||
|
bytesVariables.put(name, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void addBytes(String name, BytePointer value){
|
||||||
|
vars.put(name, PythonType.TypeName.BYTES);
|
||||||
|
bytesVariables.put(name, value);
|
||||||
|
}
|
||||||
|
|
||||||
|
// public void addBytes(String name, ByteBuffer value){
|
||||||
|
// Pointer ptr = NativeOpsHolder.getInstance().getDeviceNativeOps().pointerForAddress(value.address());
|
||||||
|
// BytePointer bp = new BytePointer(ptr);
|
||||||
|
// addBytes(name, bp);
|
||||||
|
// }
|
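The commented-out ByteBuffer overload above goes through a native address lookup; JavaCPP's BytePointer can also wrap a direct buffer. A hedged sketch of how it could read (not part of this commit; assumes a direct ByteBuffer):

public void addBytes(String name, java.nio.ByteBuffer value) {
    // BytePointer can wrap a direct ByteBuffer without the address round-trip
    addBytes(name, new BytePointer(value));
}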
||||||
/**
|
/**
|
||||||
*
|
|
||||||
* @param name name of the variable
|
* @param name name of the variable
|
||||||
* @param value new value for the variable
|
* @param value new value for the variable
|
||||||
*/
|
*/
|
||||||
public void setValue(String name, Object value) {
|
public void setValue(String name, Object value) throws PythonException {
|
||||||
Type type = vars.get(name);
|
PythonType.TypeName type = vars.get(name);
|
||||||
if (type == PythonVariables.Type.BOOL){
|
maps.get(type).put(name, PythonType.valueOf(type).convert(value));
|
||||||
boolVariables.put(name, (Boolean)value);
|
|
||||||
}
|
|
||||||
else if (type == PythonVariables.Type.INT){
|
|
||||||
Number number = (Number) value;
|
|
||||||
intVariables.put(name, number.longValue());
|
|
||||||
}
|
|
||||||
else if (type == PythonVariables.Type.FLOAT){
|
|
||||||
Number number = (Number) value;
|
|
||||||
floatVariables.put(name, number.doubleValue());
|
|
||||||
}
|
|
||||||
else if (type == PythonVariables.Type.NDARRAY){
|
|
||||||
if (value instanceof NumpyArray){
|
|
||||||
ndVars.put(name, (NumpyArray)value);
|
|
||||||
}
|
|
||||||
else if (value instanceof org.nd4j.linalg.api.ndarray.INDArray) {
|
|
||||||
ndVars.put(name, new NumpyArray((org.nd4j.linalg.api.ndarray.INDArray) value));
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
throw new RuntimeException("Unsupported type: " + value.getClass().toString());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else if (type == PythonVariables.Type.LIST) {
|
|
||||||
if (value instanceof java.util.List) {
|
|
||||||
value = ((java.util.List) value).toArray();
|
|
||||||
listVariables.put(name, (Object[]) value);
|
|
||||||
}
|
|
||||||
else if(value instanceof org.json.JSONArray) {
|
|
||||||
org.json.JSONArray jsonArray = (org.json.JSONArray) value;
|
|
||||||
Object[] copyArr = new Object[jsonArray.length()];
|
|
||||||
for(int i = 0; i < copyArr.length; i++) {
|
|
||||||
copyArr[i] = jsonArray.get(i);
|
|
||||||
}
|
|
||||||
listVariables.put(name, copyArr);
|
|
||||||
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
listVariables.put(name, (Object[]) value);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else if(type == PythonVariables.Type.DICT) {
|
|
||||||
dictVariables.put(name,(java.util.Map<?,?>) value);
|
|
||||||
}
|
|
||||||
else if (type == PythonVariables.Type.FILE){
|
|
||||||
fileVariables.put(name, (String)value);
|
|
||||||
}
|
|
||||||
else{
|
|
||||||
strVariables.put(name, (String)value);
|
|
||||||
}
|
|
||||||
}
|
}
|
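With this rewrite, setValue routes every type through the same per-type map and PythonType conversion, so callers no longer pre-coerce values. A short usage sketch (note that setValue now declares PythonException):

PythonVariables pyVars = new PythonVariables();
pyVars.addInt("a", 1);
pyVars.setValue("a", 7);           // coerced via PythonType.valueOf(INT).convert(...)
long a = pyVars.getIntValue("a");  // 7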
||||||
|
|
||||||
/**
|
/**
|
||||||
* Do a general object lookup.
|
* Do a general object lookup.
|
||||||
* The look up will happen relative to the {@link Type}
|
* The look up will happen relative to the {@link PythonType}
|
||||||
* of the variable.
|
* of the variable.
|
||||||
|
*
|
||||||
* @param name the name of the variable to get
|
* @param name the name of the variable to get
|
||||||
* @return the value for the variable with the given name
|
* @return the value for the variable with the given name
|
||||||
*/
|
*/
|
||||||
public Object getValue(String name) {
|
public Object getValue(String name) {
|
||||||
Type type = vars.get(name);
|
PythonType.TypeName type = vars.get(name);
|
||||||
java.util.Map map = maps.get(type);
|
java.util.Map map = maps.get(type);
|
||||||
return map.get(name);
|
return map.get(name);
|
||||||
}
|
}
|
||||||
|
@ -432,6 +387,7 @@ public class PythonVariables implements java.io.Serializable {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns a boolean variable with the given name.
|
* Returns a boolean variable with the given name.
|
||||||
|
*
|
||||||
* @param name the variable name to get the value for
|
* @param name the variable name to get the value for
|
||||||
* @return the retrieved boolean value
|
* @return the retrieved boolean value
|
||||||
*/
|
*/
|
||||||
|
@ -440,7 +396,6 @@ public class PythonVariables implements java.io.Serializable {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
|
||||||
* @param name the variable name
|
* @param name the variable name
|
||||||
* @return the dictionary value
|
* @return the dictionary value
|
||||||
*/
|
*/
|
||||||
|
@ -449,7 +404,7 @@ public class PythonVariables implements java.io.Serializable {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
/**
|
*
|
||||||
*
|
*
|
||||||
* @param name the variable name
|
* @param name the variable name
|
||||||
* @return the string value
|
* @return the string value
|
||||||
|
@ -459,7 +414,6 @@ public class PythonVariables implements java.io.Serializable {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
|
||||||
* @param name the variable name
|
* @param name the variable name
|
||||||
* @return the long value
|
* @return the long value
|
||||||
*/
|
*/
|
||||||
|
@ -468,7 +422,6 @@ public class PythonVariables implements java.io.Serializable {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
|
||||||
* @param name the variable name
|
* @param name the variable name
|
||||||
* @return the float value
|
* @return the float value
|
||||||
*/
|
*/
|
||||||
|
@ -477,43 +430,44 @@ public class PythonVariables implements java.io.Serializable {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
|
||||||
* @param name the variable name
|
* @param name the variable name
|
||||||
* @return the numpy array value
|
* @return the nd4j array value
|
||||||
*/
|
*/
|
||||||
public NumpyArray getNDArrayValue(String name){
|
public INDArray getNDArrayValue(String name) {
|
||||||
return ndVars.get(name);
|
return ndVars.get(name);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
|
||||||
* @param name the variable name
|
* @param name the variable name
|
||||||
* @return the list value as an object array
|
* @return the list value as a java.util.List
|
||||||
*/
|
*/
|
||||||
public Object[] getListValue(String name){
|
public List getListValue(String name) {
|
||||||
return listVariables.get(name);
|
return listVariables.get(name);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
*
|
|
||||||
* @param name the variable name
|
* @param name the variable name
|
||||||
* @return the value of the given file name
|
* @return the bytes value as a BytePointer
|
||||||
*/
|
*/
|
||||||
public String getFileValue(String name){
|
public BytePointer getBytesValue(String name) {
    return bytesVariables.get(name);
}
|
||||||
return fileVariables.get(name);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the type for the given variable name
|
* Returns the type for the given variable name
|
||||||
|
*
|
||||||
* @param name the name of the variable to get the type for
|
* @param name the name of the variable to get the type for
|
||||||
* @return the type for the given variable
|
* @return the type for the given variable
|
||||||
*/
|
*/
|
||||||
public Type getType(String name){
|
public PythonType getType(String name){
|
||||||
return vars.get(name);
|
try{
|
||||||
|
return PythonType.valueOf(vars.get(name)); // will never fail
|
||||||
|
}catch (Exception e)
|
||||||
|
{
|
||||||
|
throw new RuntimeException(e);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get all the variables present as a string array
|
* Get all the variables present as a string array
|
||||||
|
*
|
||||||
* @return the variable names for this variable set
|
* @return the variable names for this variable set
|
||||||
*/
|
*/
|
||||||
public String[] getVariables() {
|
public String[] getVariables() {
|
||||||
|
@ -524,6 +478,7 @@ public class PythonVariables implements java.io.Serializable {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This variables set as its json representation (an array of json objects)
|
* This variables set as its json representation (an array of json objects)
|
||||||
|
*
|
||||||
* @return the json array output
|
* @return the json array output
|
||||||
*/
|
*/
|
||||||
public org.json.JSONArray toJSON() {
|
public org.json.JSONArray toJSON() {
|
||||||
|
@ -542,13 +497,14 @@ public class PythonVariables implements java.io.Serializable {
|
||||||
* Create a schema from a map.
|
* Create a schema from a map.
|
||||||
* This is an empty PythonVariables
|
* This is an empty PythonVariables
|
||||||
* that just contains names and types with no values
|
* that just contains names and types with no values
|
||||||
|
*
|
||||||
* @param inputTypes the input types to convert
|
* @param inputTypes the input types to convert
|
||||||
* @return the schema from the given map
|
* @return the schema from the given map
|
||||||
*/
|
*/
|
||||||
public static PythonVariables schemaFromMap(java.util.Map<String,String> inputTypes) {
|
public static PythonVariables schemaFromMap(java.util.Map<String, String> inputTypes) throws Exception{
|
||||||
PythonVariables ret = new PythonVariables();
|
PythonVariables ret = new PythonVariables();
|
||||||
for (java.util.Map.Entry<String, String> entry : inputTypes.entrySet()) {
|
for (java.util.Map.Entry<String, String> entry : inputTypes.entrySet()) {
|
||||||
ret.add(entry.getKey(), PythonVariables.Type.valueOf(entry.getValue()));
|
ret.add(entry.getKey(), PythonType.valueOf(entry.getValue()));
|
||||||
}
|
}
|
||||||
|
|
||||||
return ret;
|
return ret;
|
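schemaFromMap now resolves type names through PythonType.valueOf and declares a checked exception. A hedged usage sketch:

Map<String, String> types = new HashMap<>();
types.put("x", "NDARRAY");
types.put("msg", "STR");
// schemaFromMap is declared throws Exception after this change
PythonVariables schema = PythonVariables.schemaFromMap(types); // names and types only, no values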
||||||
|
@ -557,6 +513,7 @@ public class PythonVariables implements java.io.Serializable {
|
||||||
/**
|
/**
|
||||||
* Get the python variable state relative to the
|
* Get the python variable state relative to the
|
||||||
* input json array
|
* input json array
|
||||||
|
*
|
||||||
* @param jsonArray the input json array
|
* @param jsonArray the input json array
|
||||||
* @return the python variables based on the input json array
|
* @return the python variables based on the input json array
|
||||||
*/
|
*/
|
||||||
|
@ -566,30 +523,7 @@ public class PythonVariables implements java.io.Serializable {
|
||||||
org.json.JSONObject input = (org.json.JSONObject) jsonArray.get(i);
|
org.json.JSONObject input = (org.json.JSONObject) jsonArray.get(i);
|
||||||
String varName = (String) input.get("name");
|
String varName = (String) input.get("name");
|
||||||
String varType = (String) input.get("type");
|
String varType = (String) input.get("type");
|
||||||
if (varType.equals("BOOL")) {
|
pyvars.maps.get(PythonType.TypeName.valueOf(varType)).put(varName, null);
|
||||||
pyvars.addBool(varName);
|
|
||||||
}
|
|
||||||
else if (varType.equals("INT")) {
|
|
||||||
pyvars.addInt(varName);
|
|
||||||
}
|
|
||||||
else if (varType.equals("FlOAT")){
|
|
||||||
pyvars.addFloat(varName);
|
|
||||||
}
|
|
||||||
else if (varType.equals("STR")) {
|
|
||||||
pyvars.addStr(varName);
|
|
||||||
}
|
|
||||||
else if (varType.equals("LIST")) {
|
|
||||||
pyvars.addList(varName);
|
|
||||||
}
|
|
||||||
else if (varType.equals("FILE")){
|
|
||||||
pyvars.addFile(varName);
|
|
||||||
}
|
|
||||||
else if (varType.equals("NDARRAY")) {
|
|
||||||
pyvars.addNDArray(varName);
|
|
||||||
}
|
|
||||||
else if(varType.equals("DICT")) {
|
|
||||||
pyvars.addDict(varName);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return pyvars;
|
return pyvars;
|
||||||
|
|
|
@ -1,5 +0,0 @@
|
||||||
#See: https://stackoverflow.com/questions/3543833/how-do-i-clear-all-variables-in-the-middle-of-a-python-script
|
|
||||||
import sys
|
|
||||||
this = sys.modules[__name__]
|
|
||||||
for n in dir():
|
|
||||||
if n[0]!='_': delattr(this, n)
|
|
|
@ -1 +0,0 @@
|
||||||
loc = {}
|
|
|
@ -1,20 +0,0 @@
|
||||||
|
|
||||||
def __is_numpy_array(x):
|
|
||||||
return str(type(x))== "<class 'numpy.ndarray'>"
|
|
||||||
|
|
||||||
def maybe_serialize_ndarray_metadata(x):
|
|
||||||
return serialize_ndarray_metadata(x) if __is_numpy_array(x) else x
|
|
||||||
|
|
||||||
|
|
||||||
def serialize_ndarray_metadata(x):
|
|
||||||
return {"address": x.__array_interface__['data'][0],
|
|
||||||
"shape": x.shape,
|
|
||||||
"strides": x.strides,
|
|
||||||
"dtype": str(x.dtype),
|
|
||||||
"_is_numpy_array": True} if __is_numpy_array(x) else x
|
|
||||||
|
|
||||||
|
|
||||||
def is_json_ready(key, value):
|
|
||||||
return key is not 'f2' and not inspect.ismodule(value) \
|
|
||||||
and not hasattr(value, '__call__')
|
|
||||||
|
|
|
@ -1,202 +0,0 @@
|
||||||
#patch
|
|
||||||
|
|
||||||
"""Implementation of __array_function__ overrides from NEP-18."""
|
|
||||||
import collections
|
|
||||||
import functools
|
|
||||||
import os
|
|
||||||
|
|
||||||
from numpy.core._multiarray_umath import (
|
|
||||||
add_docstring, implement_array_function, _get_implementing_args)
|
|
||||||
from numpy.compat._inspect import getargspec
|
|
||||||
|
|
||||||
|
|
||||||
ENABLE_ARRAY_FUNCTION = bool(
|
|
||||||
int(os.environ.get('NUMPY_EXPERIMENTAL_ARRAY_FUNCTION', 0)))
|
|
||||||
|
|
||||||
|
|
||||||
ARRAY_FUNCTION_ENABLED = ENABLE_ARRAY_FUNCTION # backward compat
|
|
||||||
|
|
||||||
|
|
||||||
_add_docstring = add_docstring
|
|
||||||
|
|
||||||
|
|
||||||
def add_docstring(*args):
|
|
||||||
try:
|
|
||||||
_add_docstring(*args)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
add_docstring(
|
|
||||||
implement_array_function,
|
|
||||||
"""
|
|
||||||
Implement a function with checks for __array_function__ overrides.
|
|
||||||
|
|
||||||
All arguments are required, and can only be passed by position.
|
|
||||||
|
|
||||||
Arguments
|
|
||||||
---------
|
|
||||||
implementation : function
|
|
||||||
Function that implements the operation on NumPy array without
|
|
||||||
overrides when called like ``implementation(*args, **kwargs)``.
|
|
||||||
public_api : function
|
|
||||||
Function exposed by NumPy's public API originally called like
|
|
||||||
``public_api(*args, **kwargs)`` on which arguments are now being
|
|
||||||
checked.
|
|
||||||
relevant_args : iterable
|
|
||||||
Iterable of arguments to check for __array_function__ methods.
|
|
||||||
args : tuple
|
|
||||||
Arbitrary positional arguments originally passed into ``public_api``.
|
|
||||||
kwargs : dict
|
|
||||||
Arbitrary keyword arguments originally passed into ``public_api``.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
Result from calling ``implementation()`` or an ``__array_function__``
|
|
||||||
method, as appropriate.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
TypeError : if no implementation is found.
|
|
||||||
""")
|
|
||||||
|
|
||||||
|
|
||||||
# exposed for testing purposes; used internally by implement_array_function
|
|
||||||
add_docstring(
|
|
||||||
_get_implementing_args,
|
|
||||||
"""
|
|
||||||
Collect arguments on which to call __array_function__.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
relevant_args : iterable of array-like
|
|
||||||
Iterable of possibly array-like arguments to check for
|
|
||||||
__array_function__ methods.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
Sequence of arguments with __array_function__ methods, in the order in
|
|
||||||
which they should be called.
|
|
||||||
""")
|
|
||||||
|
|
||||||
|
|
||||||
ArgSpec = collections.namedtuple('ArgSpec', 'args varargs keywords defaults')
|
|
||||||
|
|
||||||
|
|
||||||
def verify_matching_signatures(implementation, dispatcher):
|
|
||||||
"""Verify that a dispatcher function has the right signature."""
|
|
||||||
implementation_spec = ArgSpec(*getargspec(implementation))
|
|
||||||
dispatcher_spec = ArgSpec(*getargspec(dispatcher))
|
|
||||||
|
|
||||||
if (implementation_spec.args != dispatcher_spec.args or
|
|
||||||
implementation_spec.varargs != dispatcher_spec.varargs or
|
|
||||||
implementation_spec.keywords != dispatcher_spec.keywords or
|
|
||||||
(bool(implementation_spec.defaults) !=
|
|
||||||
bool(dispatcher_spec.defaults)) or
|
|
||||||
(implementation_spec.defaults is not None and
|
|
||||||
len(implementation_spec.defaults) !=
|
|
||||||
len(dispatcher_spec.defaults))):
|
|
||||||
raise RuntimeError('implementation and dispatcher for %s have '
|
|
||||||
'different function signatures' % implementation)
|
|
||||||
|
|
||||||
if implementation_spec.defaults is not None:
|
|
||||||
if dispatcher_spec.defaults != (None,) * len(dispatcher_spec.defaults):
|
|
||||||
raise RuntimeError('dispatcher functions can only use None for '
|
|
||||||
'default argument values')
|
|
||||||
|
|
||||||
|
|
||||||
def set_module(module):
|
|
||||||
"""Decorator for overriding __module__ on a function or class.
|
|
||||||
|
|
||||||
Example usage::
|
|
||||||
|
|
||||||
@set_module('numpy')
|
|
||||||
def example():
|
|
||||||
pass
|
|
||||||
|
|
||||||
assert example.__module__ == 'numpy'
|
|
||||||
"""
|
|
||||||
def decorator(func):
|
|
||||||
if module is not None:
|
|
||||||
func.__module__ = module
|
|
||||||
return func
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
|
|
||||||
def array_function_dispatch(dispatcher, module=None, verify=True,
|
|
||||||
docs_from_dispatcher=False):
|
|
||||||
"""Decorator for adding dispatch with the __array_function__ protocol.
|
|
||||||
|
|
||||||
See NEP-18 for example usage.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
dispatcher : callable
|
|
||||||
Function that when called like ``dispatcher(*args, **kwargs)`` with
|
|
||||||
arguments from the NumPy function call returns an iterable of
|
|
||||||
array-like arguments to check for ``__array_function__``.
|
|
||||||
module : str, optional
|
|
||||||
__module__ attribute to set on new function, e.g., ``module='numpy'``.
|
|
||||||
By default, module is copied from the decorated function.
|
|
||||||
verify : bool, optional
|
|
||||||
If True, verify that the dispatcher and decorated
|
|
||||||
function signatures match exactly: all required and optional arguments
|
|
||||||
should appear in order with the same names, but the default values for
|
|
||||||
all optional arguments should be ``None``. Only disable verification
|
|
||||||
if the dispatcher's signature needs to deviate for some particular
|
|
||||||
reason, e.g., because the function has a signature like
|
|
||||||
``func(*args, **kwargs)``.
|
|
||||||
docs_from_dispatcher : bool, optional
|
|
||||||
If True, copy docs from the dispatcher function onto the dispatched
|
|
||||||
function, rather than from the implementation. This is useful for
|
|
||||||
functions defined in C, which otherwise don't have docstrings.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
Function suitable for decorating the implementation of a NumPy function.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if not ENABLE_ARRAY_FUNCTION:
|
|
||||||
# __array_function__ requires an explicit opt-in for now
|
|
||||||
def decorator(implementation):
|
|
||||||
if module is not None:
|
|
||||||
implementation.__module__ = module
|
|
||||||
if docs_from_dispatcher:
|
|
||||||
add_docstring(implementation, dispatcher.__doc__)
|
|
||||||
return implementation
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
def decorator(implementation):
|
|
||||||
if verify:
|
|
||||||
verify_matching_signatures(implementation, dispatcher)
|
|
||||||
|
|
||||||
if docs_from_dispatcher:
|
|
||||||
add_docstring(implementation, dispatcher.__doc__)
|
|
||||||
|
|
||||||
@functools.wraps(implementation)
|
|
||||||
def public_api(*args, **kwargs):
|
|
||||||
relevant_args = dispatcher(*args, **kwargs)
|
|
||||||
return implement_array_function(
|
|
||||||
implementation, public_api, relevant_args, args, kwargs)
|
|
||||||
|
|
||||||
if module is not None:
|
|
||||||
public_api.__module__ = module
|
|
||||||
|
|
||||||
# TODO: remove this when we drop Python 2 support (functools.wraps)
|
|
||||||
# adds __wrapped__ automatically in later versions)
|
|
||||||
public_api.__wrapped__ = implementation
|
|
||||||
|
|
||||||
return public_api
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
|
|
||||||
def array_function_from_dispatcher(
|
|
||||||
implementation, module=None, verify=True, docs_from_dispatcher=True):
|
|
||||||
"""Like array_function_dispatcher, but with function arguments flipped."""
|
|
||||||
|
|
||||||
def decorator(dispatcher):
|
|
||||||
return array_function_dispatch(
|
|
||||||
dispatcher, module, verify=verify,
|
|
||||||
docs_from_dispatcher=docs_from_dispatcher)(implementation)
|
|
||||||
return decorator
|
|
|
@ -1,172 +0,0 @@
|
||||||
#patch 1
|
|
||||||
|
|
||||||
"""
|
|
||||||
========================
|
|
||||||
Random Number Generation
|
|
||||||
========================
|
|
||||||
|
|
||||||
==================== =========================================================
|
|
||||||
Utility functions
|
|
||||||
==============================================================================
|
|
||||||
random_sample Uniformly distributed floats over ``[0, 1)``.
|
|
||||||
random Alias for `random_sample`.
|
|
||||||
bytes Uniformly distributed random bytes.
|
|
||||||
random_integers Uniformly distributed integers in a given range.
|
|
||||||
permutation Randomly permute a sequence / generate a random sequence.
|
|
||||||
shuffle Randomly permute a sequence in place.
|
|
||||||
seed Seed the random number generator.
|
|
||||||
choice Random sample from 1-D array.
|
|
||||||
|
|
||||||
==================== =========================================================
|
|
||||||
|
|
||||||
==================== =========================================================
|
|
||||||
Compatibility functions
|
|
||||||
==============================================================================
|
|
||||||
rand Uniformly distributed values.
|
|
||||||
randn Normally distributed values.
|
|
||||||
ranf Uniformly distributed floating point numbers.
|
|
||||||
randint Uniformly distributed integers in a given range.
|
|
||||||
==================== =========================================================
|
|
||||||
|
|
||||||
==================== =========================================================
|
|
||||||
Univariate distributions
|
|
||||||
==============================================================================
|
|
||||||
beta Beta distribution over ``[0, 1]``.
|
|
||||||
binomial Binomial distribution.
|
|
||||||
chisquare :math:`\\chi^2` distribution.
|
|
||||||
exponential Exponential distribution.
|
|
||||||
f F (Fisher-Snedecor) distribution.
|
|
||||||
gamma Gamma distribution.
|
|
||||||
geometric Geometric distribution.
|
|
||||||
gumbel Gumbel distribution.
|
|
||||||
hypergeometric Hypergeometric distribution.
|
|
||||||
laplace Laplace distribution.
|
|
||||||
logistic Logistic distribution.
|
|
||||||
lognormal Log-normal distribution.
|
|
||||||
logseries Logarithmic series distribution.
|
|
||||||
negative_binomial Negative binomial distribution.
|
|
||||||
noncentral_chisquare Non-central chi-square distribution.
|
|
||||||
noncentral_f Non-central F distribution.
|
|
||||||
normal Normal / Gaussian distribution.
|
|
||||||
pareto Pareto distribution.
|
|
||||||
poisson Poisson distribution.
|
|
||||||
power Power distribution.
|
|
||||||
rayleigh Rayleigh distribution.
|
|
||||||
triangular Triangular distribution.
|
|
||||||
uniform Uniform distribution.
|
|
||||||
vonmises Von Mises circular distribution.
|
|
||||||
wald Wald (inverse Gaussian) distribution.
|
|
||||||
weibull Weibull distribution.
|
|
||||||
zipf Zipf's distribution over ranked data.
|
|
||||||
==================== =========================================================
|
|
||||||
|
|
||||||
==================== =========================================================
|
|
||||||
Multivariate distributions
|
|
||||||
==============================================================================
|
|
||||||
dirichlet Multivariate generalization of Beta distribution.
|
|
||||||
multinomial Multivariate generalization of the binomial distribution.
|
|
||||||
multivariate_normal Multivariate generalization of the normal distribution.
|
|
||||||
==================== =========================================================
|
|
||||||
|
|
||||||
==================== =========================================================
|
|
||||||
Standard distributions
|
|
||||||
==============================================================================
|
|
||||||
standard_cauchy Standard Cauchy-Lorentz distribution.
|
|
||||||
standard_exponential Standard exponential distribution.
|
|
||||||
standard_gamma Standard Gamma distribution.
|
|
||||||
standard_normal Standard normal distribution.
|
|
||||||
standard_t Standard Student's t-distribution.
|
|
||||||
==================== =========================================================
|
|
||||||
|
|
||||||
==================== =========================================================
|
|
||||||
Internal functions
|
|
||||||
==============================================================================
|
|
||||||
get_state Get tuple representing internal state of generator.
|
|
||||||
set_state Set state of generator.
|
|
||||||
==================== =========================================================
|
|
||||||
|
|
||||||
"""
|
|
||||||
from __future__ import division, absolute_import, print_function
|
|
||||||
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
'beta',
|
|
||||||
'binomial',
|
|
||||||
'bytes',
|
|
||||||
'chisquare',
|
|
||||||
'choice',
|
|
||||||
'dirichlet',
|
|
||||||
'exponential',
|
|
||||||
'f',
|
|
||||||
'gamma',
|
|
||||||
'geometric',
|
|
||||||
'get_state',
|
|
||||||
'gumbel',
|
|
||||||
'hypergeometric',
|
|
||||||
'laplace',
|
|
||||||
'logistic',
|
|
||||||
'lognormal',
|
|
||||||
'logseries',
|
|
||||||
'multinomial',
|
|
||||||
'multivariate_normal',
|
|
||||||
'negative_binomial',
|
|
||||||
'noncentral_chisquare',
|
|
||||||
'noncentral_f',
|
|
||||||
'normal',
|
|
||||||
'pareto',
|
|
||||||
'permutation',
|
|
||||||
'poisson',
|
|
||||||
'power',
|
|
||||||
'rand',
|
|
||||||
'randint',
|
|
||||||
'randn',
|
|
||||||
'random_integers',
|
|
||||||
'random_sample',
|
|
||||||
'rayleigh',
|
|
||||||
'seed',
|
|
||||||
'set_state',
|
|
||||||
'shuffle',
|
|
||||||
'standard_cauchy',
|
|
||||||
'standard_exponential',
|
|
||||||
'standard_gamma',
|
|
||||||
'standard_normal',
|
|
||||||
'standard_t',
|
|
||||||
'triangular',
|
|
||||||
'uniform',
|
|
||||||
'vonmises',
|
|
||||||
'wald',
|
|
||||||
'weibull',
|
|
||||||
'zipf'
|
|
||||||
]
|
|
||||||
|
|
||||||
with warnings.catch_warnings():
|
|
||||||
warnings.filterwarnings("ignore", message="numpy.ndarray size changed")
|
|
||||||
try:
|
|
||||||
from .mtrand import *
|
|
||||||
# Some aliases:
|
|
||||||
ranf = random = sample = random_sample
|
|
||||||
__all__.extend(['ranf', 'random', 'sample'])
|
|
||||||
except:
|
|
||||||
warnings.warn("numpy.random is not available when using multiple interpreters!")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def __RandomState_ctor():
|
|
||||||
"""Return a RandomState instance.
|
|
||||||
|
|
||||||
This function exists solely to assist (un)pickling.
|
|
||||||
|
|
||||||
Note that the state of the RandomState returned here is irrelevant, as this function's
|
|
||||||
entire purpose is to return a newly allocated RandomState whose state pickle can set.
|
|
||||||
Consequently the RandomState returned by this function is a freshly allocated copy
|
|
||||||
with a seed=0.
|
|
||||||
|
|
||||||
See https://github.com/numpy/numpy/issues/4763 for a detailed discussion
|
|
||||||
|
|
||||||
"""
|
|
||||||
return RandomState(seed=0)
|
|
||||||
|
|
||||||
from numpy._pytesttester import PytestTester
|
|
||||||
test = PytestTester(__name__)
|
|
||||||
del PytestTester
|
|
|
@ -3,13 +3,13 @@ import traceback
|
||||||
import json
|
import json
|
||||||
import inspect
|
import inspect
|
||||||
|
|
||||||
|
__python_exception__ = ""
|
||||||
try:
|
try:
|
||||||
|
|
||||||
pass
|
pass
|
||||||
sys.stdout.flush()
|
sys.stdout.flush()
|
||||||
sys.stderr.flush()
|
sys.stderr.flush()
|
||||||
except Exception as ex:
|
except Exception as ex:
|
||||||
|
__python_exception__ = ex
|
||||||
try:
|
try:
|
||||||
exc_info = sys.exc_info()
|
exc_info = sys.exc_info()
|
||||||
finally:
|
finally:
|
||||||
|
|
|
@ -1,50 +0,0 @@
|
||||||
def __is_numpy_array(x):
|
|
||||||
return str(type(x))== "<class 'numpy.ndarray'>"
|
|
||||||
|
|
||||||
def __maybe_serialize_ndarray_metadata(x):
|
|
||||||
return __serialize_ndarray_metadata(x) if __is_numpy_array(x) else x
|
|
||||||
|
|
||||||
|
|
||||||
def __serialize_ndarray_metadata(x):
|
|
||||||
return {"address": x.__array_interface__['data'][0],
|
|
||||||
"shape": x.shape,
|
|
||||||
"strides": x.strides,
|
|
||||||
"dtype": str(x.dtype),
|
|
||||||
"_is_numpy_array": True} if __is_numpy_array(x) else x
|
|
||||||
|
|
||||||
|
|
||||||
def __serialize_list(x):
|
|
||||||
import json
|
|
||||||
return json.dumps(__recursive_serialize_list(x))
|
|
||||||
|
|
||||||
|
|
||||||
def __serialize_dict(x):
|
|
||||||
import json
|
|
||||||
return json.dumps(__recursive_serialize_dict(x))
|
|
||||||
|
|
||||||
def __recursive_serialize_list(x):
|
|
||||||
out = []
|
|
||||||
for i in x:
|
|
||||||
if __is_numpy_array(i):
|
|
||||||
out.append(__serialize_ndarray_metadata(i))
|
|
||||||
elif isinstance(i, (list, tuple)):
|
|
||||||
out.append(__recursive_serialize_list(i))
|
|
||||||
elif isinstance(i, dict):
|
|
||||||
out.append(__recursive_serialize_dict(i))
|
|
||||||
else:
|
|
||||||
out.append(i)
|
|
||||||
return out
|
|
||||||
|
|
||||||
def __recursive_serialize_dict(x):
|
|
||||||
out = {}
|
|
||||||
for k in x:
|
|
||||||
v = x[k]
|
|
||||||
if __is_numpy_array(v):
|
|
||||||
out[k] = __serialize_ndarray_metadata(v)
|
|
||||||
elif isinstance(v, (list, tuple)):
|
|
||||||
out[k] = __recursive_serialize_list(v)
|
|
||||||
elif isinstance(v, dict):
|
|
||||||
out[k] = __recursive_serialize_dict(v)
|
|
||||||
else:
|
|
||||||
out[k] = v
|
|
||||||
return out
|
|
|
@ -0,0 +1,87 @@
|
||||||
|
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.datavec.python;
|
||||||
|
|
||||||
|
|
||||||
|
import org.junit.Assert;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
import javax.annotation.concurrent.NotThreadSafe;
|
||||||
|
|
||||||
|
@NotThreadSafe
|
||||||
|
public class TestPythonContextManager {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testInt() throws Exception{
|
||||||
|
Python.setContext("context1");
|
||||||
|
Python.exec("a = 1");
|
||||||
|
Python.setContext("context2");
|
||||||
|
Python.exec("a = 2");
|
||||||
|
Python.setContext("context3");
|
||||||
|
Python.exec("a = 3");
|
||||||
|
|
||||||
|
|
||||||
|
Python.setContext("context1");
|
||||||
|
Assert.assertEquals(1, PythonExecutioner.getVariable("a").toInt());
|
||||||
|
|
||||||
|
Python.setContext("context2");
|
||||||
|
Assert.assertEquals(2, PythonExecutioner.getVariable("a").toInt());
|
||||||
|
|
||||||
|
Python.setContext("context3");
|
||||||
|
Assert.assertEquals(3, PythonExecutioner.getVariable("a").toInt());
|
||||||
|
|
||||||
|
PythonContextManager.deleteNonMainContexts();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testNDArray() throws Exception{
|
||||||
|
Python.setContext("context1");
|
||||||
|
Python.exec("import numpy as np");
|
||||||
|
Python.exec("a = np.zeros((3,2)) + 1");
|
||||||
|
|
||||||
|
Python.setContext("context2");
|
||||||
|
Python.exec("import numpy as np");
|
||||||
|
Python.exec("a = np.zeros((3,2)) + 2");
|
||||||
|
|
||||||
|
Python.setContext("context3");
|
||||||
|
Python.exec("import numpy as np");
|
||||||
|
Python.exec("a = np.zeros((3,2)) + 3");
|
||||||
|
|
||||||
|
Python.setContext("context1");
|
||||||
|
Python.exec("a += 1");
|
||||||
|
|
||||||
|
Python.setContext("context2");
|
||||||
|
Python.exec("a += 2");
|
||||||
|
|
||||||
|
Python.setContext("context3");
|
||||||
|
Python.exec("a += 3");
|
||||||
|
|
||||||
|
INDArray arr = Nd4j.create(DataType.DOUBLE, 3, 2);
|
||||||
|
Python.setContext("context1");
|
||||||
|
Assert.assertEquals(arr.add(2), PythonExecutioner.getVariable("a").toNumpy().getNd4jArray());
|
||||||
|
|
||||||
|
Python.setContext("context2");
|
||||||
|
Assert.assertEquals(arr.add(4), PythonExecutioner.getVariable("a").toNumpy().getNd4jArray());
|
||||||
|
|
||||||
|
Python.setContext("context3");
|
||||||
|
Assert.assertEquals(arr.add(6), PythonExecutioner.getVariable("a").toNumpy().getNd4jArray());
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
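The new context-manager test exercises per-context isolation: each named context keeps its own interpreter globals. A typical lifecycle, using only calls that appear in these tests:

Python.setContext("scratch");                  // switch to (or lazily create) an isolated context
Python.exec("tmp = 42");
Python.setMainContext();                       // return to the default context
PythonContextManager.deleteNonMainContexts();  // drop everything except "main"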
@ -0,0 +1,64 @@
|
||||||
|
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.datavec.python;
|
||||||
|
|
||||||
|
|
||||||
|
import lombok.var;
|
||||||
|
import org.json.JSONArray;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import static org.junit.Assert.assertArrayEquals;
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
|
@javax.annotation.concurrent.NotThreadSafe
|
||||||
|
public class TestPythonDict {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testPythonDictFromMap() throws Exception{
|
||||||
|
Map<Object, Object> map = new HashMap<>();
|
||||||
|
map.put("a", 1);
|
||||||
|
map.put("b", "a");
|
||||||
|
map.put("1", Arrays.asList(1, 2, 3, "4", Arrays.asList("x", 2.3)));
|
||||||
|
Map<Object, Object> innerMap = new HashMap<>();
|
||||||
|
innerMap.put("k", 32);
|
||||||
|
map.put("inner", innerMap);
|
||||||
|
map.put("ndarray", Nd4j.linspace(1, 4, 4));
|
||||||
|
innerMap.put("ndarray", Nd4j.linspace(5, 8, 4));
|
||||||
|
PythonObject dict = new PythonObject(map);
|
||||||
|
assertEquals(map.size(), Python.len(dict).toInt());
|
||||||
|
assertEquals("{'a': 1, '1': [1, 2, 3, '4', ['" +
|
||||||
|
"x', 2.3]], 'b': 'a', 'inner': {'k': 32," +
|
||||||
|
" 'ndarray': array([5., 6., 7., 8.], dty" +
|
||||||
|
"pe=float32)}, 'ndarray': array([1., 2., " +
|
||||||
|
"3., 4.], dtype=float32)}",
|
||||||
|
dict.toString());
|
||||||
|
Map map2 = dict.toMap();
|
||||||
|
PythonObject dict2 = new PythonObject(map2);
|
||||||
|
assertEquals(dict.toString(), dict2.toString());
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
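TestPythonDict round-trips a Java Map through a Python dict and back. A smaller sketch of the same conversion (the toString assertion in the test depends on dict insertion order, so it is brittle outside the fixed inserts used there):

Map<Object, Object> m = new HashMap<>();
m.put("k", 32);
PythonObject dict = new PythonObject(m);
int size = Python.len(dict).toInt(); // 1
Map roundTripped = dict.toMap();     // back to a Java Map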
@ -1,75 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.datavec.python;
|
|
||||||
|
|
||||||
|
|
||||||
import org.junit.Assert;
|
|
||||||
import org.junit.Test;
|
|
||||||
|
|
||||||
@javax.annotation.concurrent.NotThreadSafe
|
|
||||||
public class TestPythonExecutionSandbox {
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testInt(){
|
|
||||||
PythonExecutioner.setInterpreter("interp1");
|
|
||||||
PythonExecutioner.exec("a = 1");
|
|
||||||
PythonExecutioner.setInterpreter("interp2");
|
|
||||||
PythonExecutioner.exec("a = 2");
|
|
||||||
PythonExecutioner.setInterpreter("interp3");
|
|
||||||
PythonExecutioner.exec("a = 3");
|
|
||||||
|
|
||||||
|
|
||||||
PythonExecutioner.setInterpreter("interp1");
|
|
||||||
Assert.assertEquals(1, PythonExecutioner.evalInteger("a"));
|
|
||||||
|
|
||||||
PythonExecutioner.setInterpreter("interp2");
|
|
||||||
Assert.assertEquals(2, PythonExecutioner.evalInteger("a"));
|
|
||||||
|
|
||||||
PythonExecutioner.setInterpreter("interp3");
|
|
||||||
Assert.assertEquals(3, PythonExecutioner.evalInteger("a"));
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testNDArray(){
|
|
||||||
PythonExecutioner.setInterpreter("main");
|
|
||||||
PythonExecutioner.exec("import numpy as np");
|
|
||||||
PythonExecutioner.exec("a = np.zeros(5)");
|
|
||||||
|
|
||||||
PythonExecutioner.setInterpreter("main");
|
|
||||||
//PythonExecutioner.exec("import numpy as np");
|
|
||||||
PythonExecutioner.exec("a = np.zeros(5)");
|
|
||||||
|
|
||||||
PythonExecutioner.setInterpreter("main");
|
|
||||||
PythonExecutioner.exec("a += 2");
|
|
||||||
|
|
||||||
PythonExecutioner.setInterpreter("main");
|
|
||||||
PythonExecutioner.exec("a += 3");
|
|
||||||
|
|
||||||
PythonExecutioner.setInterpreter("main");
|
|
||||||
//PythonExecutioner.exec("import numpy as np");
|
|
||||||
// PythonExecutioner.exec("a = np.zeros(5)");
|
|
||||||
|
|
||||||
PythonExecutioner.setInterpreter("main");
|
|
||||||
Assert.assertEquals(25, PythonExecutioner.evalNdArray("a").getNd4jArray().sum().getDouble(), 1e-5);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testNumpyRandom(){
|
|
||||||
PythonExecutioner.setInterpreter("main");
|
|
||||||
PythonExecutioner.exec("import numpy as np; print(np.random.randint(5))");
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -15,13 +15,17 @@
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
package org.datavec.python;
|
package org.datavec.python;
|
||||||
|
|
||||||
|
import org.bytedeco.javacpp.BytePointer;
|
||||||
import org.junit.Assert;
|
import org.junit.Assert;
|
||||||
|
import org.junit.Ignore;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
import static org.junit.Assert.assertEquals;
|
||||||
|
import static org.junit.Assert.fail;
|
||||||
|
|
||||||
|
|
||||||
@javax.annotation.concurrent.NotThreadSafe
|
@javax.annotation.concurrent.NotThreadSafe
|
||||||
|
@ -29,8 +33,8 @@ public class TestPythonExecutioner {
|
||||||
|
|
||||||
|
|
||||||
@org.junit.Test
|
@org.junit.Test
|
||||||
public void testPythonSysVersion() {
|
public void testPythonSysVersion() throws PythonException {
|
||||||
PythonExecutioner.exec("import sys; print(sys.version)");
|
Python.exec("import sys; print(sys.version)");
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -46,7 +50,7 @@ public class TestPythonExecutioner {
|
||||||
|
|
||||||
String code = "z = x + ' ' + y";
|
String code = "z = x + ' ' + y";
|
||||||
|
|
||||||
PythonExecutioner.exec(code, pyInputs, pyOutputs);
|
Python.exec(code, pyInputs, pyOutputs);
|
||||||
|
|
||||||
String z = pyOutputs.getStrValue("z");
|
String z = pyOutputs.getStrValue("z");
|
||||||
|
|
||||||
|
@ -68,7 +72,7 @@ public class TestPythonExecutioner {
|
||||||
pyOutputs.addInt("z");
|
pyOutputs.addInt("z");
|
||||||
|
|
||||||
|
|
||||||
PythonExecutioner.exec(code, pyInputs, pyOutputs);
|
Python.exec(code, pyInputs, pyOutputs);
|
||||||
|
|
||||||
long z = pyOutputs.getIntValue("z");
|
long z = pyOutputs.getIntValue("z");
|
||||||
|
|
||||||
|
@ -92,9 +96,9 @@ public class TestPythonExecutioner {
|
||||||
pyOutputs.addList("z");
|
pyOutputs.addList("z");
|
||||||
|
|
||||||
|
|
||||||
PythonExecutioner.exec(code, pyInputs, pyOutputs);
|
Python.exec(code, pyInputs, pyOutputs);
|
||||||
|
|
||||||
Object[] z = pyOutputs.getListValue("z");
|
Object[] z = pyOutputs.getListValue("z").toArray();
|
||||||
|
|
||||||
Assert.assertEquals(z.length, x.length + y.length);
|
Assert.assertEquals(z.length, x.length + y.length);
|
||||||
|
|
||||||
|
@ -103,8 +107,7 @@ public class TestPythonExecutioner {
|
||||||
Number xNum = (Number) x[i];
|
Number xNum = (Number) x[i];
|
||||||
Number zNum = (Number) z[i];
|
Number zNum = (Number) z[i];
|
||||||
Assert.assertEquals(xNum.intValue(), zNum.intValue());
|
Assert.assertEquals(xNum.intValue(), zNum.intValue());
|
||||||
}
|
} else {
|
||||||
else {
|
|
||||||
Assert.assertEquals(x[i], z[i]);
|
Assert.assertEquals(x[i], z[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -114,8 +117,7 @@ public class TestPythonExecutioner {
|
||||||
Number yNum = (Number) y[i];
|
Number yNum = (Number) y[i];
|
||||||
Number zNum = (Number) z[x.length + i];
|
Number zNum = (Number) z[x.length + i];
|
||||||
Assert.assertEquals(yNum.intValue(), zNum.intValue());
|
Assert.assertEquals(yNum.intValue(), zNum.intValue());
|
||||||
}
|
} else {
|
||||||
else {
|
|
||||||
Assert.assertEquals(y[i], z[x.length + i]);
|
Assert.assertEquals(y[i], z[x.length + i]);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
@ -135,8 +137,8 @@ public class TestPythonExecutioner {
|
||||||
|
|
||||||
String code = "z = x + y";
|
String code = "z = x + y";
|
||||||
|
|
||||||
PythonExecutioner.exec(code, pyInputs, pyOutputs);
|
Python.exec(code, pyInputs, pyOutputs);
|
||||||
INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray();
|
INDArray z = pyOutputs.getNDArrayValue("z");
|
||||||
|
|
||||||
Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);
|
Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);
|
||||||
|
|
||||||
|
@ -144,8 +146,9 @@ public class TestPythonExecutioner {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testTensorflowCustomAnaconda() {
|
@Ignore
|
||||||
PythonExecutioner.exec("import tensorflow as tf");
|
public void testTensorflowCustomAnaconda() throws PythonException {
|
||||||
|
Python.exec("import tensorflow as tf");
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
@ -159,8 +162,8 @@ public class TestPythonExecutioner {
|
||||||
|
|
||||||
String code = "z = x + y";
|
String code = "z = x + y";
|
||||||
|
|
||||||
PythonExecutioner.exec(code, pyInputs, pyOutputs);
|
Python.exec(code, pyInputs, pyOutputs);
|
||||||
INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray();
|
INDArray z = pyOutputs.getNDArrayValue("z");
|
||||||
|
|
||||||
Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);
|
Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);
|
||||||
}
|
}
|
||||||
|
@ -176,8 +179,8 @@ public class TestPythonExecutioner {
|
||||||
|
|
||||||
String code = "z = x + y";
|
String code = "z = x + y";
|
||||||
|
|
||||||
PythonExecutioner.exec(code, pyInputs, pyOutputs);
|
Python.exec(code, pyInputs, pyOutputs);
|
||||||
INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray();
|
INDArray z = pyOutputs.getNDArrayValue("z");
|
||||||
|
|
||||||
Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);
|
Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);
|
||||||
}
|
}
|
||||||
|
@ -194,8 +197,8 @@ public class TestPythonExecutioner {
|
||||||
|
|
||||||
String code = "z = x + y";
|
String code = "z = x + y";
|
||||||
|
|
||||||
PythonExecutioner.exec(code, pyInputs, pyOutputs);
|
Python.exec(code, pyInputs, pyOutputs);
|
||||||
INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray();
|
INDArray z = pyOutputs.getNDArrayValue("z");
|
||||||
|
|
||||||
Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);
|
Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);
|
||||||
|
|
||||||
|
@ -212,12 +215,91 @@ public class TestPythonExecutioner {
|
||||||
|
|
||||||
String code = "z = x + y";
|
String code = "z = x + y";
|
||||||
|
|
||||||
PythonExecutioner.exec(code, pyInputs, pyOutputs);
|
Python.exec(code, pyInputs, pyOutputs);
|
||||||
INDArray z = pyOutputs.getNDArrayValue("z").getNd4jArray();
|
INDArray z = pyOutputs.getNDArrayValue("z");
|
||||||
|
|
||||||
Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);
|
Assert.assertEquals(6.0, z.sum().getDouble(0), 1e-5);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testByteBufferInput() throws Exception{
|
||||||
|
//ByteBuffer buff = ByteBuffer.allocateDirect(3);
|
||||||
|
INDArray buff = Nd4j.zeros(new int[]{3}, DataType.BYTE);
|
||||||
|
buff.putScalar(0, 97); // a
|
||||||
|
buff.putScalar(1, 98); // b
|
||||||
|
buff.putScalar(2, 99); // c
|
||||||
|
|
||||||
|
|
||||||
|
PythonVariables pyInputs = new PythonVariables();
|
||||||
|
pyInputs.addBytes("buff", new BytePointer(buff.data().pointer()));
|
||||||
|
|
||||||
|
PythonVariables pyOutputs= new PythonVariables();
|
||||||
|
pyOutputs.addStr("out");
|
||||||
|
|
||||||
|
String code = "out = buff.decode()";
|
||||||
|
Python.exec(code, pyInputs, pyOutputs);
|
||||||
|
Assert.assertEquals("abc", pyOutputs.getStrValue("out"));
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testByteBufferOutputNoCopy() throws Exception{
|
||||||
|
INDArray buff = Nd4j.zeros(new int[]{3}, DataType.BYTE);
|
||||||
|
buff.putScalar(0, 97); // a
|
||||||
|
buff.putScalar(1, 98); // b
|
||||||
|
buff.putScalar(2, 99); // c
|
||||||
|
|
||||||
|
|
||||||
|
PythonVariables pyInputs = new PythonVariables();
|
||||||
|
pyInputs.addBytes("buff", new BytePointer(buff.data().pointer()));
|
||||||
|
|
||||||
|
PythonVariables pyOutputs = new PythonVariables();
|
||||||
|
pyOutputs.addBytes("buff"); // same name as input, because inplace update
|
||||||
|
|
||||||
|
String code = "buff[0]=99\nbuff[1]=98\nbuff[2]=97";
|
||||||
|
Python.exec(code, pyInputs, pyOutputs);
|
||||||
|
Assert.assertEquals("cba", pyOutputs.getBytesValue("buff").getString());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testByteBufferOutputWithCopy() throws Exception{
|
||||||
|
INDArray buff = Nd4j.zeros(new int[]{3}, DataType.BYTE);
|
||||||
|
buff.putScalar(0, 97); // a
|
||||||
|
buff.putScalar(1, 98); // b
|
||||||
|
buff.putScalar(2, 99); // c
|
||||||
|
|
||||||
|
|
||||||
|
PythonVariables pyInputs = new PythonVariables();
|
||||||
|
pyInputs.addBytes("buff", new BytePointer(buff.data().pointer()));
|
||||||
|
|
||||||
|
PythonVariables pyOutputs = new PythonVariables();
|
||||||
|
pyOutputs.addBytes("out");
|
||||||
|
|
||||||
|
String code = "buff[0]=99\nbuff[1]=98\nbuff[2]=97\nout=bytes(buff)";
|
||||||
|
Python.exec(code, pyInputs, pyOutputs);
|
||||||
|
Assert.assertEquals("cba", pyOutputs.getBytesValue("out").getString());
|
||||||
|
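The two byte-buffer tests above contrast in-place and copied outputs: registering the output under the same name as the input reuses the BytePointer that wraps the INDArray's data, while bytes(buff) detaches a fresh copy. A hedged check of the aliasing this implies (assuming the pointer really is shared, as the no-copy test suggests):

// After the no-copy run, writes made in Python should be visible through the
// original INDArray, whose data pointer backed the BytePointer input.
Assert.assertEquals(99, buff.getInt(0)); // Python wrote buff[0] = 99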
}
|
||||||
|
@Test
|
||||||
|
public void testBadCode() throws Exception{
|
||||||
|
Python.setContext("badcode");
|
||||||
|
PythonVariables pyInputs = new PythonVariables();
|
||||||
|
PythonVariables pyOutputs = new PythonVariables();
|
||||||
|
|
||||||
|
pyInputs.addNDArray("x", Nd4j.zeros(DataType.LONG, 2, 3));
|
||||||
|
pyInputs.addNDArray("y", Nd4j.ones(DataType.LONG, 2, 3));
|
||||||
|
pyOutputs.addNDArray("z");
|
||||||
|
|
||||||
|
String code = "z = x + a";
|
||||||
|
|
||||||
|
try{
|
||||||
|
Python.exec(code, pyInputs, pyOutputs);
|
||||||
|
fail("No exception thrown");
|
||||||
|
} catch (PythonException pe ){
|
||||||
|
Assert.assertEquals("NameError: name 'a' is not defined", pe.getMessage());
|
||||||
|
}
|
||||||
|
|
||||||
|
Python.setMainContext();
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
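testBadCode shows the new error surface: Python failures now arrive as a typed PythonException carrying the interpreter's message instead of an opaque runtime failure. A hedged sketch of caller-side handling:

try {
    Python.exec("z = x + a", pyInputs, pyOutputs);
} catch (PythonException pe) {
    // e.g. "NameError: name 'a' is not defined"
    System.err.println("Python step failed: " + pe.getMessage());
}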
@ -0,0 +1,326 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.datavec.python;

import org.junit.Assert;
import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import static org.junit.Assert.assertEquals;

@javax.annotation.concurrent.NotThreadSafe
public class TestPythonJob {

    @Test
    public void testPythonJobBasic() throws Exception {
        PythonContextManager.deleteNonMainContexts();

        String code = "c = a + b";
        PythonJob job = new PythonJob("job1", code, false);

        PythonVariables inputs = new PythonVariables();
        inputs.addInt("a", 2);
        inputs.addInt("b", 3);

        PythonVariables outputs = new PythonVariables();
        outputs.addInt("c");

        job.exec(inputs, outputs);

        assertEquals(5L, (long) outputs.getIntValue("c"));

        inputs = new PythonVariables();
        inputs.addFloat("a", 3.0);
        inputs.addFloat("b", 4.0);

        outputs = new PythonVariables();
        outputs.addFloat("c");

        job.exec(inputs, outputs);

        assertEquals(7.0, outputs.getFloatValue("c"), 1e-5);

        inputs = new PythonVariables();
        inputs.addNDArray("a", Nd4j.zeros(3, 2).add(4));
        inputs.addNDArray("b", Nd4j.zeros(3, 2).add(5));

        outputs = new PythonVariables();
        outputs.addNDArray("c");

        job.exec(inputs, outputs);

        assertEquals(Nd4j.zeros(3, 2).add(9), outputs.getNDArrayValue("c"));
    }

    @Test
    public void testPythonJobReturnAllVariables() throws Exception {
        PythonContextManager.deleteNonMainContexts();

        String code = "c = a + b";
        PythonJob job = new PythonJob("job1", code, false);

        PythonVariables inputs = new PythonVariables();
        inputs.addInt("a", 2);
        inputs.addInt("b", 3);

        PythonVariables outputs = job.execAndReturnAllVariables(inputs);

        assertEquals(5L, (long) outputs.getIntValue("c"));

        inputs = new PythonVariables();
        inputs.addFloat("a", 3.0);
        inputs.addFloat("b", 4.0);

        outputs = job.execAndReturnAllVariables(inputs);

        assertEquals(7.0, outputs.getFloatValue("c"), 1e-5);

        inputs = new PythonVariables();
        inputs.addNDArray("a", Nd4j.zeros(3, 2).add(4));
        inputs.addNDArray("b", Nd4j.zeros(3, 2).add(5));

        outputs = job.execAndReturnAllVariables(inputs);

        assertEquals(Nd4j.zeros(3, 2).add(9), outputs.getNDArrayValue("c"));
    }

    @Test
    public void testMultiplePythonJobsParallel() throws Exception {
        PythonContextManager.deleteNonMainContexts();

        String code1 = "c = a + b";
        PythonJob job1 = new PythonJob("job1", code1, false);

        String code2 = "c = a - b";
        PythonJob job2 = new PythonJob("job2", code2, false);

        PythonVariables inputs = new PythonVariables();
        inputs.addInt("a", 2);
        inputs.addInt("b", 3);

        PythonVariables outputs = new PythonVariables();
        outputs.addInt("c");

        job1.exec(inputs, outputs);

        assertEquals(5L, (long) outputs.getIntValue("c"));

        job2.exec(inputs, outputs);

        assertEquals(-1L, (long) outputs.getIntValue("c"));

        inputs = new PythonVariables();
        inputs.addFloat("a", 3.0);
        inputs.addFloat("b", 4.0);

        outputs = new PythonVariables();
        outputs.addFloat("c");

        job1.exec(inputs, outputs);

        assertEquals(7.0, outputs.getFloatValue("c"), 1e-5);

        job2.exec(inputs, outputs);

        assertEquals(-1L, outputs.getFloatValue("c"), 1e-5);

        inputs = new PythonVariables();
        inputs.addNDArray("a", Nd4j.zeros(3, 2).add(4));
        inputs.addNDArray("b", Nd4j.zeros(3, 2).add(5));

        outputs = new PythonVariables();
        outputs.addNDArray("c");

        job1.exec(inputs, outputs);

        assertEquals(Nd4j.zeros(3, 2).add(9), outputs.getNDArrayValue("c"));

        job2.exec(inputs, outputs);

        assertEquals(Nd4j.zeros(3, 2).sub(1), outputs.getNDArrayValue("c"));
    }

    @Test
    public void testPythonJobSetupRun() throws Exception {
        PythonContextManager.deleteNonMainContexts();

        String code = "five=None\n" +
                "def setup():\n" +
                " global five\n" +
                " five = 5\n\n" +
                "def run(a, b):\n" +
                " c = a + b + five\n" +
                " return {'c':c}\n\n";
        PythonJob job = new PythonJob("job1", code, true);

        PythonVariables inputs = new PythonVariables();
        inputs.addInt("a", 2);
        inputs.addInt("b", 3);

        PythonVariables outputs = new PythonVariables();
        outputs.addInt("c");

        job.exec(inputs, outputs);

        assertEquals(10L, (long) outputs.getIntValue("c"));

        inputs = new PythonVariables();
        inputs.addFloat("a", 3.0);
        inputs.addFloat("b", 4.0);

        outputs = new PythonVariables();
        outputs.addFloat("c");

        job.exec(inputs, outputs);

        assertEquals(12.0, outputs.getFloatValue("c"), 1e-5);

        inputs = new PythonVariables();
        inputs.addNDArray("a", Nd4j.zeros(3, 2).add(4));
        inputs.addNDArray("b", Nd4j.zeros(3, 2).add(5));

        outputs = new PythonVariables();
        outputs.addNDArray("c");

        job.exec(inputs, outputs);

        assertEquals(Nd4j.zeros(3, 2).add(14), outputs.getNDArrayValue("c"));
    }

    @Test
    public void testPythonJobSetupRunAndReturnAllVariables() throws Exception {
        PythonContextManager.deleteNonMainContexts();

        String code = "five=None\n" +
                "def setup():\n" +
                " global five\n" +
                " five = 5\n\n" +
                "def run(a, b):\n" +
                " c = a + b + five\n" +
                " return {'c':c}\n\n";
        PythonJob job = new PythonJob("job1", code, true);

        PythonVariables inputs = new PythonVariables();
        inputs.addInt("a", 2);
        inputs.addInt("b", 3);

        PythonVariables outputs = job.execAndReturnAllVariables(inputs);

        assertEquals(10L, (long) outputs.getIntValue("c"));

        inputs = new PythonVariables();
        inputs.addFloat("a", 3.0);
        inputs.addFloat("b", 4.0);

        outputs = job.execAndReturnAllVariables(inputs);

        assertEquals(12.0, outputs.getFloatValue("c"), 1e-5);

        inputs = new PythonVariables();
        inputs.addNDArray("a", Nd4j.zeros(3, 2).add(4));
        inputs.addNDArray("b", Nd4j.zeros(3, 2).add(5));

        outputs = job.execAndReturnAllVariables(inputs);

        assertEquals(Nd4j.zeros(3, 2).add(14), outputs.getNDArrayValue("c"));
    }

    @Test
    public void testMultiplePythonJobsSetupRunParallel() throws Exception {
        PythonContextManager.deleteNonMainContexts();

        String code1 = "five=None\n" +
                "def setup():\n" +
                " global five\n" +
                " five = 5\n\n" +
                "def run(a, b):\n" +
                " c = a + b + five\n" +
                " return {'c':c}\n\n";
        PythonJob job1 = new PythonJob("job1", code1, true);

        String code2 = "five=None\n" +
                "def setup():\n" +
                " global five\n" +
                " five = 5\n\n" +
                "def run(a, b):\n" +
                " c = a + b - five\n" +
                " return {'c':c}\n\n";
        PythonJob job2 = new PythonJob("job2", code2, true);

        PythonVariables inputs = new PythonVariables();
        inputs.addInt("a", 2);
        inputs.addInt("b", 3);

        PythonVariables outputs = new PythonVariables();
        outputs.addInt("c");

        job1.exec(inputs, outputs);

        assertEquals(10L, (long) outputs.getIntValue("c"));

        job2.exec(inputs, outputs);

        assertEquals(0L, (long) outputs.getIntValue("c"));

        inputs = new PythonVariables();
        inputs.addFloat("a", 3.0);
        inputs.addFloat("b", 4.0);

        outputs = new PythonVariables();
        outputs.addFloat("c");

        job1.exec(inputs, outputs);

        assertEquals(12.0, outputs.getFloatValue("c"), 1e-5);

        job2.exec(inputs, outputs);

        assertEquals(2.0, outputs.getFloatValue("c"), 1e-5);

        inputs = new PythonVariables();
        inputs.addNDArray("a", Nd4j.zeros(3, 2).add(4));
        inputs.addNDArray("b", Nd4j.zeros(3, 2).add(5));

        outputs = new PythonVariables();
        outputs.addNDArray("c");

        job1.exec(inputs, outputs);

        assertEquals(Nd4j.zeros(3, 2).add(14), outputs.getNDArrayValue("c"));

        job2.exec(inputs, outputs);

        assertEquals(Nd4j.zeros(3, 2).add(4), outputs.getNDArrayValue("c"));
    }

}
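As exercised above, PythonJob's third constructor argument toggles setup/run mode: with false the code string is executed directly against the bound variables, with true the string must define setup() (run once) and run(...) (run on every exec call, returning a dict of outputs). A minimal usage sketch, relying only on the calls these tests already make (the job and variable names are illustrative):

    PythonContextManager.deleteNonMainContexts();

    PythonJob sumJob = new PythonJob("sum-job",
            "offset=None\n" +
            "def setup():\n" +
            " global offset\n" +
            " offset = 5\n\n" +
            "def run(a, b):\n" +
            " return {'c': a + b + offset}\n",
            true);   // true => setup() once, then run() on every exec

    PythonVariables in = new PythonVariables();
    in.addInt("a", 1);
    in.addInt("b", 2);
    PythonVariables out = new PythonVariables();
    out.addInt("c");

    sumJob.exec(in, out);    // run(1, 2) -> 1 + 2 + 5
    assertEquals(8L, (long) out.getIntValue("c"));
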
@ -0,0 +1,108 @@
/*******************************************************************************
 * Copyright (c) 2019 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.datavec.python;

import lombok.var;
import org.json.JSONArray;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.util.*;

import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;

@javax.annotation.concurrent.NotThreadSafe
public class TestPythonList {

    @Test
    public void testPythonListFromIntArray() {
        PythonObject pyList = new PythonObject(new Integer[]{1, 2, 3, 4, 5});
        pyList.attr("append").call(6);
        pyList.attr("append").call(7);
        pyList.attr("append").call(8);
        assertEquals(8, Python.len(pyList).toInt());
        for (int i = 0; i < 8; i++) {
            assertEquals(i + 1, pyList.get(i).toInt());
        }
    }

    @Test
    public void testPythonListFromLongArray() {
        PythonObject pyList = new PythonObject(new Long[]{1L, 2L, 3L, 4L, 5L});
        pyList.attr("append").call(6);
        pyList.attr("append").call(7);
        pyList.attr("append").call(8);
        assertEquals(8, Python.len(pyList).toInt());
        for (int i = 0; i < 8; i++) {
            assertEquals(i + 1, pyList.get(i).toInt());
        }
    }

    @Test
    public void testPythonListFromDoubleArray() {
        PythonObject pyList = new PythonObject(new Double[]{1., 2., 3., 4., 5.});
        pyList.attr("append").call(6);
        pyList.attr("append").call(7);
        pyList.attr("append").call(8);
        assertEquals(8, Python.len(pyList).toInt());
        for (int i = 0; i < 8; i++) {
            assertEquals(i + 1, pyList.get(i).toInt());
            assertEquals((double) i + 1, pyList.get(i).toDouble(), 1e-5);
        }
    }

    @Test
    public void testPythonListFromStringArray() {
        PythonObject pyList = new PythonObject(new String[]{"abcd", "efg"});
        pyList.attr("append").call("hijk");
        pyList.attr("append").call("lmnop");
        assertEquals("abcdefghijklmnop", new PythonObject("").attr("join").call(pyList).toString());
    }

    @Test
    public void testPythonListFromMixedArray() throws Exception {
        Map<Object, Object> map = new HashMap<>();
        map.put(1, "a");
        map.put("a", Arrays.asList("a", "b", "c"));
        map.put("arr", Nd4j.linspace(1, 4, 4));
        Object[] objs = new Object[]{
                1, 2, "a", 3f, 4L, 5.0, Arrays.asList(10,
                20, "b", 30f, 40L, 50.0, map
                ), map
        };
        PythonObject pyList = new PythonObject(objs);
        System.out.println(pyList.toString());
        String expectedStr = "[1, 2, 'a', 3.0, 4, 5.0, [10" +
                ", 20, 'b', 30.0, 40, 50.0, {'arr': array([1.," +
                " 2., 3., 4.], dtype=float32), 1: 'a', 'a': [" +
                "'a', 'b', 'c']}], {'arr': array([1., 2., 3.," +
                " 4.], dtype=float32), 1: 'a', 'a': ['a', 'b', 'c']}]";
        assertEquals(expectedStr, pyList.toString());
        List objs2 = pyList.toList();
        PythonObject pyList2 = new PythonObject(objs2);
        assertEquals(pyList.toString(), pyList2.toString());
    }

}
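The attr(...).call(...) chaining used for append and join above reaches any method on the wrapped Python object. A small sketch along the same lines (it assumes a zero-argument call() overload exists, which these tests do not themselves exercise):

    PythonObject pyList = new PythonObject(new Integer[]{3, 1, 2});
    pyList.attr("sort").call();      // in-place list.sort(), assuming a no-arg call()
    assertEquals(1, pyList.get(0).toInt());
    assertEquals(3, pyList.get(2).toInt());
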
@ -1,27 +0,0 @@
package org.datavec.python;

import org.junit.Test;

import static org.junit.Assert.assertEquals;

@javax.annotation.concurrent.NotThreadSafe
public class TestPythonSetupAndRun {
    @Test
    public void testPythonWithSetupAndRun() throws Exception {
        String code = "def setup():" +
                "global counter;counter=0\n" +
                "def run(step):" +
                "global counter;" +
                "counter+=step;" +
                "return {\"counter\":counter}";
        PythonVariables pyInputs = new PythonVariables();
        pyInputs.addInt("step", 2);
        PythonVariables pyOutputs = new PythonVariables();
        pyOutputs.addInt("counter");
        PythonExecutioner.execWithSetupAndRun(code, pyInputs, pyOutputs);
        assertEquals((long) pyOutputs.getIntValue("counter"), 2L);
        pyInputs.addInt("step", 3);
        PythonExecutioner.execWithSetupAndRun(code, pyInputs, pyOutputs);
        assertEquals((long) pyOutputs.getIntValue("counter"), 5L);
    }
}
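The deleted executioner-level test overlaps with the setup/run coverage added in TestPythonJob above. If the stateful-counter case still needs a home, a roughly equivalent sketch under the PythonJob API (names illustrative; it assumes interpreter state persists across exec calls, as the setup/run tests suggest):

    PythonJob counterJob = new PythonJob("counter-job",
            "counter=None\n" +
            "def setup():\n" +
            " global counter\n" +
            " counter = 0\n\n" +
            "def run(step):\n" +
            " global counter\n" +
            " counter += step\n" +
            " return {'counter': counter}\n",
            true);

    PythonVariables in = new PythonVariables();
    in.addInt("step", 2);
    PythonVariables out = new PythonVariables();
    out.addInt("counter");

    counterJob.exec(in, out);   // counter == 2
    counterJob.exec(in, out);   // counter == 4, if state persists across calls
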
@ -22,11 +22,14 @@

 package org.datavec.python;

+import org.bytedeco.javacpp.BytePointer;
 import org.junit.Test;
+import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;

 import java.util.Arrays;
 import java.util.Collections;
+import java.util.List;

 import static junit.framework.TestCase.assertNotNull;
 import static junit.framework.TestCase.assertNull;

@ -36,59 +39,50 @@ import static org.junit.Assert.assertTrue;

 public class TestPythonVariables {

     @Test
-    public void testImportNumpy(){
-        Nd4j.scalar(1.0);
-        System.out.println(System.getProperty("org.bytedeco.openblas.load"));
-        PythonExecutioner.exec("import numpy as np");
-    }
-
-    @Test
-    public void testDataAssociations() {
+    public void testDataAssociations() throws PythonException {
         PythonVariables pythonVariables = new PythonVariables();
-        PythonVariables.Type[] types = {
-                PythonVariables.Type.INT,
-                PythonVariables.Type.FLOAT,
-                PythonVariables.Type.STR,
-                PythonVariables.Type.BOOL,
-                PythonVariables.Type.DICT,
-                PythonVariables.Type.LIST,
-                PythonVariables.Type.LIST,
-                PythonVariables.Type.FILE,
-                PythonVariables.Type.NDARRAY
+        PythonType[] types = {
+                PythonType.INT,
+                PythonType.FLOAT,
+                PythonType.STR,
+                PythonType.BOOL,
+                PythonType.DICT,
+                PythonType.LIST,
+                PythonType.LIST,
+                PythonType.NDARRAY,
+                PythonType.BYTES
         };

-        NumpyArray npArr = new NumpyArray(Nd4j.scalar(1.0));
+        INDArray arr = Nd4j.scalar(1.0);
+        BytePointer bp = new BytePointer(arr.data().pointer());
         Object[] values = {
                 1L,1.0,"1",true, Collections.singletonMap("1",1),
-                new Object[]{1}, Arrays.asList(1),"type", npArr
+                new Object[]{1}, Arrays.asList(1), arr, bp
         };

         Object[] expectedValues = {
                 1L,1.0,"1",true, Collections.singletonMap("1",1),
-                new Object[]{1}, new Object[]{1},"type", npArr
+                Arrays.asList(1), Arrays.asList(1), arr, bp
         };

         for(int i = 0; i < types.length; i++) {
-            testInsertGet(pythonVariables,types[i].name() + i,values[i],types[i],expectedValues[i]);
+            testInsertGet(pythonVariables,types[i].getName().name() + i,values[i],types[i],expectedValues[i]);
         }

         assertEquals(types.length,pythonVariables.getVariables().length);

     }

-    private void testInsertGet(PythonVariables pythonVariables,String key,Object value,PythonVariables.Type type,Object expectedValue) {
+    private void testInsertGet(PythonVariables pythonVariables,String key,Object value,PythonType type,Object expectedValue) throws PythonException {
         pythonVariables.add(key, type);
         assertNull(pythonVariables.getValue(key));
         pythonVariables.setValue(key,value);
         assertNotNull(pythonVariables.getValue(key));
         Object actualValue = pythonVariables.getValue(key);
         if (expectedValue instanceof Object[]){
-            assertTrue(actualValue instanceof Object[]);
-            Object[] actualArr = (Object[])actualValue;
+            assertTrue(actualValue instanceof List);
+            Object[] actualArr = ((List)actualValue).toArray();
             Object[] expectedArr = (Object[])expectedValue;
             assertArrayEquals(expectedArr, actualArr);
         }
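For readers following the API change: the reworked test pins down a declare/set/get cycle against the new PythonType constants, including the BYTES type backed by a JavaCPP BytePointer. The shape of that cycle, extracted as a sketch (the variable name is illustrative):

    PythonVariables vars = new PythonVariables();
    vars.add("payload", PythonType.BYTES);        // declare name + type
    // getValue returns null until a value is bound
    INDArray arr = Nd4j.scalar(1.0);
    vars.setValue("payload", new BytePointer(arr.data().pointer()));
    Object value = vars.getValue("payload");      // now non-null
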
@ -44,7 +44,7 @@ public class TestSerde {
         String yaml = y.serialize(t);
         String json = j.serialize(t);

-        Transform t2 = y.deserializeTransform(json);
+        Transform t2 = y.deserializeTransform(yaml);
         Transform t3 = j.deserializeTransform(json);
         assertEquals(t, t2);
         assertEquals(t, t3);
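The one-line fix above matters: the YAML round trip was previously fed the JSON string, so t2 exercised the JSON path twice and the YAML deserializer was never actually tested. A generic guard against this class of mistake, sketched with java.util.function types (the helper name is illustrative):

    static void assertRoundTrips(Transform t,
                                 java.util.function.Function<Transform, String> serialize,
                                 java.util.function.Function<String, Transform> deserialize) {
        // keeps each serializer paired with its own output
        assertEquals(t, deserialize.apply(serialize.apply(t)));
    }
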
@ -148,7 +148,7 @@
            <artifactId>maven-surefire-plugin</artifactId>
            <version>${maven-surefire-plugin.version}</version>
            <configuration>
-               <argLine>-Ddtype=float</argLine>
+               <argLine>-Ddtype=float -Dfile.encoding=UTF-8</argLine>
                <!--
                By default: Surefire will set the classpath based on the manifest. Because tests are not included
                in the JAR, any tests that rely on class path scanning for resources in the tests directory will not
@ -56,6 +56,11 @@ public class GradientCheckTestsMasking extends BaseDL4JTest {
        Nd4j.setDataType(DataType.DOUBLE);
    }

+   @Override
+   public long getTimeoutMilliseconds() {
+       return 90000L;
+   }
+
    private static class GradientCheckSimpleScenario {
        private final ILossFunction lf;
        private final Activation act;

@ -159,9 +164,8 @@ public class GradientCheckTestsMasking extends BaseDL4JTest {
                .updater(new NoOp())
                .dataType(DataType.DOUBLE)
                .dist(new NormalDistribution(0, 1.0)).seed(12345L).list()
-               .layer(0, new GravesBidirectionalLSTM.Builder().nIn(nIn).nOut(layerSize)
-                       .activation(Activation.TANH).build())
-               .layer(1, new GravesBidirectionalLSTM.Builder().nIn(layerSize).nOut(layerSize)
+               .layer(0, new SimpleRnn.Builder().nIn(nIn).nOut(2).activation(Activation.TANH).build())
+               .layer(1, new GravesBidirectionalLSTM.Builder().nIn(2).nOut(layerSize)
                        .activation(Activation.TANH).build())
                .layer(2, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                        .activation(Activation.SOFTMAX).nIn(layerSize).nOut(nOut).build())
@ -390,24 +394,24 @@ public class GradientCheckTestsMasking extends BaseDL4JTest {
        Nd4j.getRandom().setSeed(12345);
        //Idea: RNN input, global pooling, OutputLayer - with "per example" mask arrays

-       int mb = 10;
+       int mb = 4;
        int tsLength = 5;
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .dataType(DataType.DOUBLE)
                .weightInit(new NormalDistribution(0,2))
                .updater(new NoOp())
                .list()
-               .layer(new LSTM.Builder().nIn(10).nOut(10).build())
+               .layer(new LSTM.Builder().nIn(3).nOut(3).build())
                .layer(new GlobalPoolingLayer.Builder().poolingType(PoolingType.AVG).build())
-               .layer(new OutputLayer.Builder().nIn(10).nOut(10).activation(Activation.SOFTMAX).build())
-               .setInputType(InputType.recurrent(10))
+               .layer(new OutputLayer.Builder().nIn(3).nOut(3).activation(Activation.SOFTMAX).build())
+               .setInputType(InputType.recurrent(3))
                .build();

        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

-       INDArray f = Nd4j.rand(new int[]{mb, 10, tsLength});
-       INDArray l = TestUtils.randomOneHot(mb, 10);
+       INDArray f = Nd4j.rand(new int[]{mb, 3, tsLength});
+       INDArray l = TestUtils.randomOneHot(mb, 3);
        INDArray lm = TestUtils.randomBernoulli(mb, 1);

        assertTrue(lm.sumNumber().intValue() > 0);

@ -449,18 +453,18 @@ public class GradientCheckTestsMasking extends BaseDL4JTest {
                .updater(new NoOp())
                .graphBuilder()
                .addInputs("in")
-               .layer("0", new LSTM.Builder().nIn(10).nOut(10).build(), "in")
+               .layer("0", new LSTM.Builder().nIn(3).nOut(3).build(), "in")
                .layer("1", new GlobalPoolingLayer.Builder().poolingType(PoolingType.AVG).build(), "0")
-               .layer("out", new OutputLayer.Builder().nIn(10).nOut(10).activation(Activation.SOFTMAX).build(), "1")
+               .layer("out", new OutputLayer.Builder().nIn(3).nOut(3).activation(Activation.SOFTMAX).build(), "1")
                .setOutputs("out")
-               .setInputTypes(InputType.recurrent(10))
+               .setInputTypes(InputType.recurrent(3))
                .build();

        ComputationGraph net = new ComputationGraph(conf);
        net.init();

-       INDArray f = Nd4j.rand(new int[]{mb, 10, tsLength});
-       INDArray l = TestUtils.randomOneHot(mb, 10);
+       INDArray f = Nd4j.rand(new int[]{mb, 3, tsLength});
+       INDArray l = TestUtils.randomOneHot(mb, 3);
        INDArray lm = TestUtils.randomBernoulli(mb, 1);

        assertTrue(lm.sumNumber().intValue() > 0);
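The size reductions here are about gradient-check cost: a numerical check perturbs every parameter, so runtime grows with the parameter count. Under the usual LSTM parameter-count formula (an assumption about the layer's exact internals), shrinking the layers from 10 units to 3 cuts the work by roughly an order of magnitude:

    P_{\mathrm{LSTM}} \approx 4\,(n_{\mathrm{in}} n_{\mathrm{out}} + n_{\mathrm{out}}^2 + n_{\mathrm{out}}),
    \qquad
    \frac{P(10,10)}{P(3,3)} = \frac{4(100+100+10)}{4(9+9+3)} = \frac{840}{84} = 10

and the smaller minibatch (10 down to 4) shrinks the per-perturbation forward-pass cost on top of that.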
@ -37,7 +37,7 @@
        <groupId>org.apache.maven.plugins</groupId>
        <artifactId>maven-surefire-plugin</artifactId>
        <configuration>
-         <argLine>-Ddtype=float -Xmx8g -Dtest.solr.allowed.securerandom=NativePRNG</argLine>
+         <argLine>-Ddtype=float -Dfile.encoding=UTF-8 -Xmx8g -Dtest.solr.allowed.securerandom=NativePRNG</argLine>
          <includes>
            <!-- Default setting only runs tests that start/end with "Test" -->
            <include>*.java</include>
@ -18,6 +18,8 @@ package org.deeplearning4j.nn.dataimport.solr.client.solrj.io.stream;

 import com.carrotsearch.randomizedtesting.annotations.ThreadLeakFilters;
 import com.carrotsearch.randomizedtesting.ThreadFilter;
+
+import java.security.SecureRandom;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;

@ -44,6 +46,18 @@ import org.nd4j.rng.deallocator.NativeRandomDeallocator;
 })
 public class TupleStreamDataSetIteratorTest extends SolrCloudTestCase {

+  static {
+    /*
+    This is a hack around the backend-dependent nature of secure random implementations:
+    though we can set the secure random algorithm in our pom.xml files (via maven surefire and test.solr.allowed.securerandom),
+    there isn't a mechanism that is completely platform independent.
+    Setting it there (for example, to NativePRNG) makes the tests pass on some platforms like Linux but fail on some JVMs on Windows.
+    For testing purposes, we don't need strict guarantees around RNG, hence we don't want to enforce the RNG algorithm.
+    */
+    String algorithm = new SecureRandom().getAlgorithm();
+    System.setProperty("test.solr.allowed.securerandom", algorithm);
+  }
+
   public static class PrivateDeallocatorThreadsFilter implements ThreadFilter {
     /**
      * Reject deallocator threads over whose cleanup this test has no control.
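The static block simply asks the JVM which algorithm its default SecureRandom resolves to and whitelists exactly that, so Solr's RNG check always agrees with the platform. A standalone probe of the same call:

    import java.security.SecureRandom;

    public class SecureRandomProbe {
        public static void main(String[] args) {
            // Prints e.g. NativePRNG on Linux; other JVMs report different
            // algorithms, which is why the tests whitelist whatever the
            // runtime returns instead of hard-coding one.
            System.out.println(new SecureRandom().getAlgorithm());
        }
    }
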
@ -66,7 +66,7 @@ public class DeepWalkGradientCheck extends BaseDL4JTest {
        for (int i = 0; i < 7; i++) {
            INDArray vector = deepWalk.getVertexVector(i);
            assertArrayEquals(new long[] {vectorSize}, vector.shape());
-           System.out.println(Arrays.toString(vector.dup().data().asFloat()));
+//            System.out.println(Arrays.toString(vector.dup().data().asFloat()));
        }

        GraphWalkIterator<String> iter = new RandomWalkIterator<>(graph, 8);

@ -182,10 +182,10 @@ public class DeepWalkGradientCheck extends BaseDL4JTest {

            if (relError > MAX_REL_ERROR && absErr > MIN_ABS_ERROR)
                fail(msg);
-           else
-               System.out.println(msg);
+//            else
+//                System.out.println(msg);
        }
-       System.out.println();
+//        System.out.println();
    }

 }

@ -216,7 +216,7 @@ public class DeepWalkGradientCheck extends BaseDL4JTest {
        for (int i = 0; i < nVertices; i++) {
            INDArray vector = deepWalk.getVertexVector(i);
            assertArrayEquals(new long[] {vectorSize}, vector.shape());
-           System.out.println(Arrays.toString(vector.dup().data().asFloat()));
+//            System.out.println(Arrays.toString(vector.dup().data().asFloat()));
        }

        GraphWalkIterator<String> iter = new RandomWalkIterator<>(graph, 10);

@ -295,8 +295,8 @@ public class DeepWalkGradientCheck extends BaseDL4JTest {

        if (relError > MAX_REL_ERROR && absErr > minAbsError)
            fail(msg);
-       else
-           System.out.println(msg);
+//        else
+//            System.out.println(msg);
    }
 }
@ -33,7 +33,7 @@
        <groupId>org.apache.maven.plugins</groupId>
        <artifactId>maven-surefire-plugin</artifactId>
        <configuration>
-         <argLine>-Ddtype=float -Xmx8g -Dtest.solr.allowed.securerandom=NativePRNG</argLine>
+         <argLine>-Ddtype=float -Dfile.encoding=UTF-8 -Xmx8g -Dtest.solr.allowed.securerandom=NativePRNG</argLine>
          <includes>
            <!-- Default setting only runs tests that start/end with "Test" -->
            <include>*.java</include>
@ -18,6 +18,7 @@ package org.deeplearning4j.nn.modelexport.solr.handler;

 import java.io.File;
 import java.nio.file.Path;
+import java.security.SecureRandom;

 import com.carrotsearch.randomizedtesting.ThreadFilter;
 import com.carrotsearch.randomizedtesting.annotations.ThreadLeakFilters;

@ -49,6 +50,19 @@ import org.nd4j.rng.deallocator.NativeRandomDeallocator;
 })
 public class ModelTupleStreamIntegrationTest extends SolrCloudTestCase {

+  static {
+    /*
+    This is a hack around the backend-dependent nature of secure random implementations:
+    though we can set the secure random algorithm in our pom.xml files (via maven surefire and test.solr.allowed.securerandom),
+    there isn't a mechanism that is completely platform independent.
+    Setting it there (for example, to NativePRNG) makes the tests pass on some platforms like Linux but fail on some JVMs on Windows.
+    For testing purposes, we don't need strict guarantees around RNG, hence we don't want to enforce the RNG algorithm.
+    */
+    String algorithm = new SecureRandom().getAlgorithm();
+    System.setProperty("test.solr.allowed.securerandom", algorithm);
+  }
+
   public static class PrivateDeallocatorThreadsFilter implements ThreadFilter {
     /**
      * Reject deallocator threads over whose cleanup this test has no control.
@ -19,6 +19,7 @@ package org.deeplearning4j.nn.modelexport.solr.handler;

 import java.io.File;
 import java.nio.file.Files;
 import java.nio.file.Path;
+import java.security.SecureRandom;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;

@ -58,6 +59,18 @@ import org.nd4j.linalg.lossfunctions.LossFunctions;

 public class ModelTupleStreamTest {

+  static {
+    /*
+    This is a hack around the backend-dependent nature of secure random implementations:
+    though we can set the secure random algorithm in our pom.xml files (via maven surefire and test.solr.allowed.securerandom),
+    there isn't a mechanism that is completely platform independent.
+    Setting it there (for example, to NativePRNG) makes the tests pass on some platforms like Linux but fail on some JVMs on Windows.
+    For testing purposes, we don't need strict guarantees around RNG, hence we don't want to enforce the RNG algorithm.
+    */
+    String algorithm = new SecureRandom().getAlgorithm();
+    System.setProperty("test.solr.allowed.securerandom", algorithm);
+  }
+
   protected List<float[]> floatsList(int numFloats) {
     final List<float[]> floatsList = new ArrayList<float[]>();
     final float[] floats0 = new float[numFloats];
@ -36,7 +36,7 @@
        <groupId>org.apache.maven.plugins</groupId>
        <artifactId>maven-surefire-plugin</artifactId>
        <configuration>
-         <argLine>-Ddtype=float -Xmx8g</argLine>
+         <argLine>-Ddtype=float -Dfile.encoding=UTF-8 -Xmx8g</argLine>
          <includes>
            <!-- Default setting only runs tests that start/end with "Test" -->
            <include>*.java</include>
@ -38,6 +38,11 @@ public class KMeansTest extends BaseDL4JTest {

    private boolean[] useKMeansPlusPlus = {true, false};

+   @Override
+   public long getTimeoutMilliseconds() {
+       return 60000L;
+   }
+
    @Test
    public void testKMeans() {
        Nd4j.getRandom().setSeed(7);
@ -16,25 +16,19 @@

 package org.deeplearning4j.models;

-import org.junit.rules.Timeout;
-import org.nd4j.shade.guava.io.Files;
-import org.nd4j.shade.guava.primitives.Doubles;
 import lombok.val;
 import org.apache.commons.io.FileUtils;
 import org.apache.commons.lang.ArrayUtils;
 import org.apache.commons.lang3.RandomUtils;
 import org.deeplearning4j.BaseDL4JTest;
-import org.deeplearning4j.models.sequencevectors.SequenceVectors;
-import org.deeplearning4j.models.sequencevectors.serialization.VocabWordFactory;
-import org.junit.Rule;
-import org.junit.rules.TemporaryFolder;
-import org.nd4j.linalg.io.ClassPathResource;
 import org.deeplearning4j.models.embeddings.WeightLookupTable;
 import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
 import org.deeplearning4j.models.embeddings.loader.VectorsConfiguration;
 import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
 import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
 import org.deeplearning4j.models.paragraphvectors.ParagraphVectors;
+import org.deeplearning4j.models.sequencevectors.SequenceVectors;
+import org.deeplearning4j.models.sequencevectors.serialization.VocabWordFactory;
 import org.deeplearning4j.models.word2vec.VocabWord;
 import org.deeplearning4j.models.word2vec.Word2Vec;
 import org.deeplearning4j.models.word2vec.wordstore.VocabCache;

@ -48,11 +42,16 @@ import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFac
 import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
 import org.junit.Before;
 import org.junit.Ignore;
+import org.junit.Rule;
 import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+import org.junit.rules.Timeout;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.io.ClassPathResource;
 import org.nd4j.linalg.ops.transforms.Transforms;
 import org.nd4j.resources.Resources;
+import org.nd4j.shade.guava.primitives.Doubles;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;

@ -272,7 +271,14 @@ public class WordVectorSerializerTest extends BaseDL4JTest {

    @Test
    public void testFullModelSerialization() throws Exception {
+       String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
+       if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
+           skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X
+       }
+
        File inputFile = Resources.asFile("big/raw_sentences.txt");

        SentenceIterator iter = UimaSentenceIterator.createWithPath(inputFile.getAbsolutePath());
        // Split on white spaces in the line to get words
        TokenizerFactory t = new DefaultTokenizerFactory();

@ -892,5 +898,4 @@ public class WordVectorSerializerTest extends BaseDL4JTest {
            fail(e.getMessage());
        }
    }
-
 }
@ -159,6 +159,11 @@ public class Word2VecTests extends BaseDL4JTest {

    @Test
    public void testWord2VecCBOW() throws Exception {
+       String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
+       if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
+           skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X
+       }
+
        SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());

        TokenizerFactory t = new DefaultTokenizerFactory();

@ -188,6 +193,11 @@ public class Word2VecTests extends BaseDL4JTest {

    @Test
    public void testWord2VecMultiEpoch() throws Exception {
+       String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
+       if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
+           skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X
+       }
+
        SentenceIterator iter;
        if(isIntegrationTests()){
            iter = new BasicLineIterator(inputFile.getAbsolutePath());

@ -220,6 +230,11 @@ public class Word2VecTests extends BaseDL4JTest {

    @Test
    public void reproducibleResults_ForMultipleRuns() throws Exception {
+       String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
+       if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
+           skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X
+       }
+
        log.info("reproducibleResults_ForMultipleRuns");
        val shakespear = new ClassPathResource("big/rnj.txt");
        val basic = new ClassPathResource("big/rnj.txt");

@ -274,6 +289,11 @@ public class Word2VecTests extends BaseDL4JTest {

    @Test
    public void testRunWord2Vec() throws Exception {
+       String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
+       if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
+           skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X
+       }
+
        // Strip white space before and after for each line
        /*val shakespear = new ClassPathResource("big/rnj.txt");
        SentenceIterator iter = new BasicLineIterator(shakespear.getFile());*/

@ -363,6 +383,11 @@ public class Word2VecTests extends BaseDL4JTest {

    @Test
    public void testLoadingWordVectors() throws Exception {
+       String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
+       if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
+           skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X
+       }
+
        File modelFile = new File(pathToWriteto);
        if (!modelFile.exists()) {
            testRunWord2Vec();

@ -396,6 +421,11 @@ public class Word2VecTests extends BaseDL4JTest {

    @Test
    public void testW2VnegativeOnRestore() throws Exception {
+       String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
+       if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
+           skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X
+       }
+
        // Strip white space before and after for each line
        SentenceIterator iter;
        if(isIntegrationTests()){

@ -453,6 +483,11 @@ public class Word2VecTests extends BaseDL4JTest {

    @Test
    public void testUnknown1() throws Exception {
+       String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
+       if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
+           skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X
+       }
+
        // Strip white space before and after for each line
        SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
        // Split on white spaces in the line to get words

@ -688,6 +723,10 @@ public class Word2VecTests extends BaseDL4JTest {

    @Test
    public void testWordVectorsPartiallyAbsentLabels() throws Exception {
+       String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
+       if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
+           skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X
+       }

        SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
        // Split on white spaces in the line to get words

@ -720,6 +759,10 @@ public class Word2VecTests extends BaseDL4JTest {

    @Test
    public void testWordVectorsAbsentLabels() throws Exception {
+       String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
+       if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
+           skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X
+       }

        SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
        // Split on white spaces in the line to get words

@ -745,6 +788,10 @@ public class Word2VecTests extends BaseDL4JTest {

    @Test
    public void testWordVectorsAbsentLabels_WithUnknown() throws Exception {
+       String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
+       if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
+           skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X
+       }

        SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
        // Split on white spaces in the line to get words

@ -814,6 +861,10 @@ public class Word2VecTests extends BaseDL4JTest {

    @Test
    public void weightsNotUpdated_WhenLocked_CBOW() throws Exception {
+       String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
+       if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
+           skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X
+       }

        SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());

@ -851,6 +902,11 @@ public class Word2VecTests extends BaseDL4JTest {

    @Test
    public void testWordsNearestSum() throws IOException {
+       String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
+       if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
+           skipUnlessIntegrationTests(); //AB 2020/02/06 Skip CUDA except for integration tests due to very slow test speed - > 5 minutes on Titan X
+       }
+
        log.info("Load & Vectorize Sentences....");
        SentenceIterator iter = new BasicLineIterator(inputFile);
        TokenizerFactory t = new DefaultTokenizerFactory();
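The same five-line CUDA guard now opens a dozen tests in this class. If the pattern keeps spreading, a shared helper on the test base class would keep the intent in one place; a sketch, assuming isIntegrationTests() and skipUnlessIntegrationTests() stay accessible to subclasses of BaseDL4JTest (the helper name is hypothetical):

    // Hypothetical helper -- not part of this change.
    protected void skipCudaUnlessIntegrationTests() {
        String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
        if (!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
            // these word-vector tests are only worth CUDA's cost in full integration runs
            skipUnlessIntegrationTests();
        }
    }
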
@ -48,12 +48,22 @@ public class TsneTest extends BaseDL4JTest {

    @Override
    public long getTimeoutMilliseconds() {
-       return 60000L;
+       return 180000L;
    }

    @Rule
    public TemporaryFolder testDir = new TemporaryFolder();

+   @Override
+   public DataType getDataType() {
+       return DataType.FLOAT;
+   }
+
+   @Override
+   public DataType getDefaultFPDataType() {
+       return DataType.FLOAT;
+   }
+
    @Test
    public void testSimple() throws Exception {
        //Simple sanity check
@ -32,6 +32,7 @@ import org.deeplearning4j.models.sequencevectors.transformers.impl.iterables.Par
 import org.deeplearning4j.text.sentenceiterator.*;
 import org.junit.Rule;
 import org.junit.rules.TemporaryFolder;
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.io.ClassPathResource;
 import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
 import org.deeplearning4j.models.embeddings.learning.impl.elements.SkipGram;

@ -80,12 +81,21 @@ public class ParagraphVectorsTest extends BaseDL4JTest {

    @Override
    public long getTimeoutMilliseconds() {
-       return 240000;
+       return isIntegrationTests() ? 600_000 : 240_000;
    }

    @Rule
    public TemporaryFolder testDir = new TemporaryFolder();

+   @Override
+   public DataType getDataType() {
+       return DataType.FLOAT;
+   }
+
+   @Override
+   public DataType getDefaultFPDataType() {
+       return DataType.FLOAT;
+   }
+
    /*
    @Test

@ -359,8 +369,13 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
    }


-   @Test(timeout = 300000)
+   @Test
    public void testParagraphVectorsDM() throws Exception {
+       String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
+       if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
+           skipUnlessIntegrationTests(); //Skip CUDA except for integration tests due to very slow test speed
+       }
+
        File file = Resources.asFile("/big/raw_sentences.txt");
        SentenceIterator iter = new BasicLineIterator(file);

@ -404,7 +419,9 @@ public class ParagraphVectorsTest extends BaseDL4JTest {

        double similarityX = vec.similarity("DOC_3720", "DOC_9852");
        log.info("3720/9852 similarity: " + similarityX);
+       if(isIntegrationTests()) {
            assertTrue(similarityX < 0.5d);
+       }


        // testing DM inference now

@ -418,7 +435,6 @@ public class ParagraphVectorsTest extends BaseDL4JTest {

        log.info("Cos O/A: {}", cosAO1);
        log.info("Cos A/B: {}", cosAB1);
-
    }


@ -501,6 +517,11 @@ public class ParagraphVectorsTest extends BaseDL4JTest {

    @Test(timeout = 300000)
    public void testParagraphVectorsWithWordVectorsModelling1() throws Exception {
+       String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
+       if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
+           skipUnlessIntegrationTests(); //Skip CUDA except for integration tests due to very slow test speed
+       }
+
        File file = Resources.asFile("/big/raw_sentences.txt");
        SentenceIterator iter = new BasicLineIterator(file);

@ -705,8 +726,12 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
        In this test we'll build w2v model, and will use it's vocab and weights for ParagraphVectors.
        there's no need in this test within travis, use it manually only for problems detection
    */
-   @Test(timeout = 300000)
+   @Test
    public void testParagraphVectorsOverExistingWordVectorsModel() throws Exception {
+       String backend = Nd4j.getExecutioner().getEnvironmentInformation().getProperty("backend");
+       if(!isIntegrationTests() && "CUDA".equalsIgnoreCase(backend)) {
+           skipUnlessIntegrationTests(); //Skip CUDA except for integration tests due to very slow test speed
+       }
+
        // we build w2v from multiple sources, to cover everything
        File resource_sentences = Resources.asFile("/big/raw_sentences.txt");

@ -997,14 +1022,18 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
        log.info("SimilarityB: {}", simB);
    }

-   @Test(timeout = 300000)
+   @Test
+   @Ignore //AB 2020/02/06 - https://github.com/eclipse/deeplearning4j/issues/8677
    public void testDirectInference() throws Exception {
-       File resource_sentences = Resources.asFile("/big/raw_sentences.txt");
+       boolean isIntegration = isIntegrationTests();
+       File resource = Resources.asFile("/big/raw_sentences.txt");
+       SentenceIterator sentencesIter = getIterator(isIntegration, resource);
+
        ClassPathResource resource_mixed = new ClassPathResource("paravec/");
        File local_resource_mixed = testDir.newFolder();
        resource_mixed.copyDirectory(local_resource_mixed);
        SentenceIterator iter = new AggregatingSentenceIterator.Builder()
-               .addSentenceIterator(new BasicLineIterator(resource_sentences))
+               .addSentenceIterator(sentencesIter)
                .addSentenceIterator(new FileSentenceIterator(local_resource_mixed)).build();

        TokenizerFactory t = new DefaultTokenizerFactory();

@ -1154,24 +1183,7 @@ public class ParagraphVectorsTest extends BaseDL4JTest {
    public void testDoubleFit() throws Exception {
        boolean isIntegration = isIntegrationTests();
        File resource = Resources.asFile("/big/raw_sentences.txt");
-       SentenceIterator iter;
-       if(isIntegration){
-           iter = new BasicLineIterator(resource);
-       } else {
-           List<String> lines = new ArrayList<>();
-           try(InputStream is = new BufferedInputStream(new FileInputStream(resource))){
-               LineIterator lineIter = IOUtils.lineIterator(is, StandardCharsets.UTF_8);
-               try{
-                   for( int i=0; i<500 && lineIter.hasNext(); i++ ){
-                       lines.add(lineIter.next());
-                   }
-               } finally {
-                   lineIter.close();
-               }
-           }
-
-           iter = new CollectionSentenceIterator(lines);
-       }
+       SentenceIterator iter = getIterator(isIntegration, resource);

        TokenizerFactory t = new DefaultTokenizerFactory();

@ -1197,6 +1209,30 @@ public class ParagraphVectorsTest extends BaseDL4JTest {

        assertEquals(num1, num2);
    }

+   public static SentenceIterator getIterator(boolean isIntegration, File file) throws IOException {
+       return getIterator(isIntegration, file, 500);
+   }
+
+   public static SentenceIterator getIterator(boolean isIntegration, File file, int linesForUnitTest) throws IOException {
|
||||||
|
if(isIntegration){
|
||||||
|
return new BasicLineIterator(file);
|
||||||
|
} else {
|
||||||
|
List<String> lines = new ArrayList<>();
|
||||||
|
try(InputStream is = new BufferedInputStream(new FileInputStream(file))){
|
||||||
|
LineIterator lineIter = IOUtils.lineIterator(is, StandardCharsets.UTF_8);
|
||||||
|
try{
|
||||||
|
for( int i=0; i<linesForUnitTest && lineIter.hasNext(); i++ ){
|
||||||
|
lines.add(lineIter.next());
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
lineIter.close();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return new CollectionSentenceIterator(lines);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
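The getIterator helper added above is the backbone of this commit's test-speedup strategy: integration runs stream the full corpus through a BasicLineIterator, while ordinary unit runs read only the first 500 lines into memory. A minimal usage sketch, as the following files adopt it (the surrounding test method is illustrative, not part of this commit):

    File inputFile = Resources.asFile("big/raw_sentences.txt");
    // Full corpus in integration runs; a CollectionSentenceIterator over the
    // first 500 lines otherwise (see getIterator above)
    SentenceIterator iter = ParagraphVectorsTest.getIterator(isIntegrationTests(), inputFile);
    // ... build the Word2Vec/ParagraphVectors model with iter exactly as before ...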
@@ -19,6 +19,7 @@ package org.deeplearning4j.models.word2vec;
 import lombok.extern.slf4j.Slf4j;
 import lombok.val;
 import org.deeplearning4j.BaseDL4JTest;
+import org.deeplearning4j.models.paragraphvectors.ParagraphVectorsTest;
 import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
 import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
 import org.deeplearning4j.nn.conf.layers.DenseLayer;

@@ -56,6 +57,11 @@ import static org.junit.Assert.assertEquals;
 public class Word2VecTestsSmall extends BaseDL4JTest {
     WordVectors word2vec;

+    @Override
+    public long getTimeoutMilliseconds() {
+        return isIntegrationTests() ? 240000 : 60000;
+    }
+
     @Before
     public void setUp() throws Exception {
         word2vec = WordVectorSerializer.readWord2VecModel(new ClassPathResource("vec.bin").getFile());

@@ -85,8 +91,8 @@ public class Word2VecTestsSmall extends BaseDL4JTest {
     @Test(timeout = 300000)
     public void testUnkSerialization_1() throws Exception {
         val inputFile = Resources.asFile("big/raw_sentences.txt");
-        val iter = new BasicLineIterator(inputFile);
+//        val iter = new BasicLineIterator(inputFile);
+        val iter = ParagraphVectorsTest.getIterator(isIntegrationTests(), inputFile);
         val t = new DefaultTokenizerFactory();
         t.setTokenPreProcessor(new CommonPreprocessor());

@@ -147,8 +153,8 @@ public class Word2VecTestsSmall extends BaseDL4JTest {
         Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT);

         val inputFile = Resources.asFile("big/raw_sentences.txt");
-        val iter = new BasicLineIterator(inputFile);
+        val iter = ParagraphVectorsTest.getIterator(isIntegrationTests(), inputFile);
+//        val iter = new BasicLineIterator(inputFile);
         val t = new DefaultTokenizerFactory();
         t.setTokenPreProcessor(new CommonPreprocessor());
@@ -17,6 +17,7 @@
 package org.deeplearning4j.models.word2vec.iterator;

 import org.deeplearning4j.BaseDL4JTest;
+import org.deeplearning4j.models.paragraphvectors.ParagraphVectorsTest;
 import org.nd4j.linalg.io.ClassPathResource;
 import org.deeplearning4j.models.embeddings.learning.impl.elements.CBOW;
 import org.deeplearning4j.models.embeddings.reader.impl.BasicModelUtils;

@@ -59,7 +60,8 @@ public class Word2VecDataSetIteratorTest extends BaseDL4JTest {
     public void testIterator1() throws Exception {

         File inputFile = Resources.asFile("big/raw_sentences.txt");
-        SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());
+        SentenceIterator iter = ParagraphVectorsTest.getIterator(isIntegrationTests(), inputFile);
+//        SentenceIterator iter = new BasicLineIterator(inputFile.getAbsolutePath());

         TokenizerFactory t = new DefaultTokenizerFactory();
         t.setTokenPreProcessor(new CommonPreprocessor());
@@ -147,8 +147,7 @@ public class MKLDNNBatchNormHelper implements BatchNormalizationHelper {
         }

         //Note: batchnorm op expects rank 1 inputs for mean/var etc, not rank 2 shape [1,x]
-        context.getInputArrays().clear();
-        context.getOutputArrays().clear();
+        context.purge();
         context.setInputArray(0, x);
         context.setInputArray(1, m);
         context.setInputArray(2, v);

@@ -89,8 +89,7 @@ public class MKLDNNConvHelper implements ConvolutionHelper {

         INDArray[] inputsArr = biasGradView == null ? new INDArray[]{input, weightsPermute, delta} : new INDArray[]{input, weightsPermute, bias, delta};
         INDArray[] outputArr = biasGradView == null ? new INDArray[]{gradAtInput, weightGradViewPermute} : new INDArray[]{gradAtInput, weightGradViewPermute, biasGradView};
-        contextBwd.getInputArrays().clear();
-        contextBwd.getOutputArrays().clear();
+        contextBwd.purge();
         for( int i=0; i<inputsArr.length; i++ ){
             contextBwd.setInputArray(i, inputsArr[i]);
         }

@@ -100,8 +99,6 @@ public class MKLDNNConvHelper implements ConvolutionHelper {

         Conv2DDerivative op = new Conv2DDerivative();
         Nd4j.exec(op, contextBwd);
-        contextBwd.getInputArrays().clear();
-        contextBwd.getOutputArrays().clear();

         Gradient g = new DefaultGradient();
         if(biasGradView != null) {

@@ -145,16 +142,14 @@ public class MKLDNNConvHelper implements ConvolutionHelper {
             weights = weights.permute(2,3,1,0);

         INDArray[] inputsArr = bias == null ? new INDArray[]{input, weights} : new INDArray[]{input, weights, bias};
-        context.getInputArrays().clear();
+        context.purge();
         for( int i=0; i<inputsArr.length; i++ ){
             context.setInputArray(i, inputsArr[i]);
         }
-        context.getOutputArrays().clear();
         context.setOutputArray(0, out);
         Conv2D op = new Conv2D();
         Nd4j.exec(op, context);
-        context.getInputArrays().clear();
-        context.getOutputArrays().clear();

         return out;
     }

@@ -59,7 +59,8 @@ public class MKLDNNLocalResponseNormalizationHelper extends BaseMKLDNNHelper imp
             context = Nd4j.getExecutioner().buildContext();
             context.setTArguments(k, alpha, beta);
             context.setIArguments((int)n);
-        }
+        } else
+            context.purge();

         LocalResponseNormalization op = new LocalResponseNormalization();

@@ -80,7 +81,8 @@ public class MKLDNNLocalResponseNormalizationHelper extends BaseMKLDNNHelper imp
             context = Nd4j.getExecutioner().buildContext();
             context.setTArguments(k, alpha, beta);
             context.setIArguments((int)n);
-        }
+        } else
+            context.purge();

         context.setInputArray(0, x);
         context.setOutputArray(0, out);

@@ -132,13 +132,12 @@ public class MKLDNNSubsamplingHelper implements SubsamplingHelper {
                 return null;
         }

-        context.getInputArrays().clear();
-        context.getOutputArrays().clear();
+        context.purge();

         context.setInputArray(0, input);
         context.setOutputArray(0, output);

         Nd4j.exec(op, context);

         return output;
     }
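All of the helpers above now follow the same reuse pattern: a cached OpContext keeps its T/I/B/D arguments between invocations, and the new purge() call releases only the previous invocation's input/output array handles, replacing the old getInputArrays().clear() / getOutputArrays().clear() pair. A hedged sketch of that lifecycle (the loop and variable names are illustrative, not code from the repository):

    OpContext context = Nd4j.getExecutioner().buildContext();
    context.setTArguments(k, alpha, beta);      // configured once; survives purge()

    for (INDArray[] io : batches) {             // hypothetical repeated execution
        context.purge();                        // drop stale array handles only
        context.setInputArray(0, io[0]);
        context.setOutputArray(0, io[1]);
        Nd4j.exec(op, context);
    }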
@@ -282,6 +282,12 @@ public class SimpleRnn extends BaseRecurrentLayer<org.deeplearning4j.nn.conf.lay

             a.getActivation(currOut, training);

+            if( maskArray != null){
+                //If mask array is present: Also need to zero out errors to avoid sending anything but 0s to layer below for masked steps
+                INDArray maskCol = maskArray.getColumn(i, true).castTo(dataType);
+                currOut.muliColumnVector(maskCol);
+            }
+
             prevStepOut = currOut;
         }
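The masking step above multiplies each timestep's activations by the corresponding mask column, so padded positions contribute exact zeros to the layer below. A self-contained illustration of the same operation on dummy shapes (all values illustrative):

    // [miniBatch=3, nOut=4] activations at timestep i; third example is padding
    INDArray currOut = Nd4j.rand(DataType.FLOAT, 3, 4);
    INDArray maskArray = Nd4j.createFromArray(new float[][]{{1}, {1}, {0}});
    INDArray maskCol = maskArray.getColumn(0, true).castTo(DataType.FLOAT);
    currOut.muliColumnVector(maskCol);   // row 3 of currOut is now all zeros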
@@ -58,6 +58,7 @@ import org.nd4j.evaluation.classification.Evaluation;
 import org.nd4j.evaluation.classification.ROC;
 import org.nd4j.evaluation.classification.ROCMultiClass;
 import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.dataset.DataSet;
 import org.nd4j.linalg.dataset.MultiDataSet;

@@ -93,7 +94,23 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
     @Rule
     public TemporaryFolder testDir = new TemporaryFolder();

-    @Test(timeout = 120000L)
+    @Override
+    public long getTimeoutMilliseconds() {
+        return 120000L;
+    }
+
+    @Override
+    public DataType getDefaultFPDataType() {
+        return DataType.FLOAT;
+    }
+
+    @Override
+    public DataType getDataType() {
+        return DataType.FLOAT;
+    }
+
+    @Test
     public void testFromSvmLightBackprop() throws Exception {
         JavaRDD<LabeledPoint> data = MLUtils
                         .loadLibSVMFile(sc.sc(),

@@ -125,7 +142,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
     }


-    @Test(timeout = 120000L)
+    @Test
     public void testFromSvmLight() throws Exception {
         JavaRDD<LabeledPoint> data = MLUtils
                         .loadLibSVMFile(sc.sc(),

@@ -155,7 +172,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
         master.fitLabeledPoint(data);
     }

-    @Test(timeout = 120000L)
+    @Test
     public void testRunIteration() {

         DataSet dataSet = new IrisDataSetIterator(5, 5).next();

@@ -175,7 +192,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
         assertEquals(expectedParams.size(1), actualParams.size(1));
     }

-    @Test(timeout = 120000L)
+    @Test
     public void testUpdaters() {
         SparkDl4jMultiLayer sparkNet = getBasicNetwork();
         MultiLayerNetwork netCopy = sparkNet.getNetwork().clone();

@@ -197,7 +214,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
     }


-    @Test(timeout = 120000L)
+    @Test
     public void testEvaluation() {

         SparkDl4jMultiLayer sparkNet = getBasicNetwork();

@@ -228,7 +245,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
         }
     }

-    @Test(timeout = 120000L)
+    @Test
     public void testSmallAmountOfData() {
         //Idea: Test spark training where some executors don't get any data
         //in this case: by having fewer examples (2 DataSets) than executors (local[*])

@@ -255,7 +272,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {

     }

-    @Test(timeout = 120000L)
+    @Test
     public void testDistributedScoring() {

         MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().l1(0.1).l2(0.1)

@@ -333,7 +350,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {



-    @Test(timeout = 120000L)
+    @Test
     public void testParameterAveragingMultipleExamplesPerDataSet() throws Exception {
         int dataSetObjSize = 5;
         int batchSizePerExecutor = 25;

@@ -382,7 +399,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
     }


-    @Test(timeout = 120000L)
+    @Test
     public void testFitViaStringPaths() throws Exception {

         Path tempDir = testDir.newFolder("DL4J-testFitViaStringPaths").toPath();

@@ -445,7 +462,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
         sparkNet.getTrainingMaster().deleteTempFiles(sc);
     }

-    @Test(timeout = 120000L)
+    @Test
     public void testFitViaStringPathsSize1() throws Exception {

         Path tempDir = testDir.newFolder("DL4J-testFitViaStringPathsSize1").toPath();

@@ -525,7 +542,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
     }


-    @Test(timeout = 120000L)
+    @Test
     public void testFitViaStringPathsCompGraph() throws Exception {

         Path tempDir = testDir.newFolder("DL4J-testFitViaStringPathsCG").toPath();

@@ -618,7 +635,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
     }


-    @Test(timeout = 120000L)
+    @Test
     @Ignore("AB 2019/05/23 - Failing on CI only - passing locally. Possible precision or threading issue")
     public void testSeedRepeatability() throws Exception {


@@ -691,7 +708,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
     }


-    @Test(timeout = 120000L)
+    @Test
     public void testIterationCounts() throws Exception {
         int dataSetObjSize = 5;
         int batchSizePerExecutor = 25;

@@ -737,7 +754,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
         }
     }

-    @Test(timeout = 120000L)
+    @Test
     public void testIterationCountsGraph() throws Exception {
         int dataSetObjSize = 5;
         int batchSizePerExecutor = 25;

@@ -783,7 +800,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
     }


-    @Test(timeout = 120000L) @Ignore //Ignored 2019/04/09 - low priority: https://github.com/deeplearning4j/deeplearning4j/issues/6656
+    @Test
+    @Ignore //Ignored 2019/04/09 - low priority: https://github.com/deeplearning4j/deeplearning4j/issues/6656
     public void testVaePretrainSimple() {
         //Simple sanity check on pretraining
         int nIn = 8;

@@ -818,7 +836,8 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
         sparkNet.fit(data);
     }

-    @Test(timeout = 120000L) @Ignore //Ignored 2019/04/09 - low priority: https://github.com/deeplearning4j/deeplearning4j/issues/6656
+    @Test
+    @Ignore //Ignored 2019/04/09 - low priority: https://github.com/deeplearning4j/deeplearning4j/issues/6656
     public void testVaePretrainSimpleCG() {
         //Simple sanity check on pretraining
         int nIn = 8;

@@ -854,7 +873,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
     }


-    @Test(timeout = 120000L)
+    @Test
     public void testROC() {

         int nArrays = 100;

@@ -909,7 +928,7 @@ public class TestSparkMultiLayerParameterAveraging extends BaseSparkTest {
     }


-    @Test(timeout = 120000L)
+    @Test
     public void testROCMultiClass() {

         int nArrays = 100;
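The long run of @Test(timeout = 120000L) to @Test changes above all rely on the single getTimeoutMilliseconds() override introduced at the top of the class. Condensed, the pattern is (class and test names illustrative):

    public class SomeSparkTest extends BaseSparkTest {
        @Override
        public long getTimeoutMilliseconds() {
            return 120000L;             // one class-wide timeout, as in BaseDL4JTest
        }

        @Test                           // no per-method timeout needed any more
        public void testSomething() { /* ... */ }
    }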
@@ -34,7 +34,7 @@ public class MiscTests extends BaseDL4JTest {

     @Override
     public long getTimeoutMilliseconds() {
-        return 120000L;
+        return 240000L;
     }

     @Test
@@ -380,7 +380,7 @@
                             -->
                             <useSystemClassLoader>true</useSystemClassLoader>
                             <useManifestOnlyJar>false</useManifestOnlyJar>
-                            <argLine>-Ddtype=float -Xmx8g</argLine>
+                            <argLine>-Ddtype=float -Dfile.encoding=UTF-8 -Xmx8g</argLine>
                             <includes>
                                 <!-- Default setting only runs tests that start/end with "Test" -->
                                 <include>*.java</include>
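Pinning -Dfile.encoding=UTF-8 in the surefire argLine removes a platform dependency: any String-to-bytes conversion that does not name a charset uses the JVM default, so byte-level assertions can pass on Linux (UTF-8) and fail on Windows (e.g. Cp1252). A standalone demonstration in plain Java:

    import java.nio.charset.Charset;
    import java.nio.charset.StandardCharsets;
    import java.util.Arrays;

    public class EncodingDemo {
        public static void main(String[] args) {
            String s = "naïve";
            byte[] platformBytes = s.getBytes();                    // depends on -Dfile.encoding
            byte[] utf8Bytes = s.getBytes(StandardCharsets.UTF_8);  // always the same 6 bytes
            System.out.println("default charset: " + Charset.defaultCharset());
            System.out.println("same bytes? " + Arrays.equals(platformBytes, utf8Bytes));
        }
    }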
@@ -1601,6 +1601,7 @@ ND4J_EXPORT OpaqueRandomGenerator* getGraphContextRandomGenerator(OpaqueContext*
 ND4J_EXPORT void ctxAllowHelpers(OpaqueContext* ptr, bool reallyAllow);
 ND4J_EXPORT void ctxShapeFunctionOverride(OpaqueContext* ptr, bool reallyOverride);
 ND4J_EXPORT void ctxSetExecutionMode(OpaqueContext* ptr, int execMode);
+ND4J_EXPORT void ctxPurge(OpaqueContext* ptr);
 ND4J_EXPORT void markGraphContextInplace(OpaqueContext* ptr, bool reallyInplace);
 ND4J_EXPORT void setGraphContextCudaContext(OpaqueContext* ptr, void *stream, void *reductionPointer, void *allocationPointer);
 ND4J_EXPORT void setGraphContextInputArray(OpaqueContext* ptr, int index, void *buffer, void *shapeInfo, void *specialBuffer, void *specialShapeInfo);
@@ -2815,6 +2815,10 @@ void ctxSetExecutionMode(OpaqueContext* ptr, int execMode) {
     ptr->setExecutionMode((samediff::ExecutionMode) execMode);
 }

+void ctxPurge(OpaqueContext* ptr) {
+    ptr->clearFastPath();
+}
+
 nd4j::graph::RandomGenerator* createRandomGenerator(Nd4jLong rootSeed, Nd4jLong nodeSeed) {
     return new nd4j::graph::RandomGenerator(rootSeed, nodeSeed);
 }
@@ -3771,6 +3771,10 @@ void ctxShapeFunctionOverride(OpaqueContext* ptr, bool reallyOverride) {
     ptr->setShapeFunctionOverride(reallyOverride);
 }

+void ctxPurge(OpaqueContext* ptr) {
+    ptr->clearFastPath();
+}
+
 int binaryLevel() {
     return 0;
 }
@@ -305,12 +305,17 @@ namespace nd4j {
         if (_primaryBuffer != nullptr && _isOwnerPrimary) {
             deletePrimary();
         }

         _primaryBuffer = buffer;
         _isOwnerPrimary = false;
         _lenInBytes = length * DataTypeUtils::sizeOf(_dataType);
     }

     void DataBuffer::setSpecialBuffer(void *buffer, size_t length) {
+        if (_specialBuffer != nullptr && _isOwnerSpecial) {
+            deleteSpecial();
+        }
+
         this->setSpecial(buffer, false);
         _lenInBytes = length * DataTypeUtils::sizeOf(_dataType);
     }
@@ -204,6 +204,13 @@ namespace nd4j {
             void setBArguments(const std::vector<bool> &tArgs);
             void setDArguments(const std::vector<nd4j::DataType> &dArgs);

+            /**
+             * This method purges fastpath in/out contents and releases all the handles.
+             *
+             * PLEASE NOTE: I/T/B/D args will stay intact
+             */
+            void clearFastPath();
+
             void setCudaContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer allocationPointer);

             void allowHelpers(bool reallyAllow);
@@ -563,6 +563,16 @@ namespace nd4j {
         for (auto d:dArgs)
             _dArgs.emplace_back(d);
     }

+    void Context::clearFastPath() {
+        _fastpath_in.clear();
+        _fastpath_out.clear();
+
+        for (auto v:_handles)
+            delete v;
+
+        _handles.clear();
+    }
 }
 }
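Taken together with the ctxPurge entry points added above, this completes the native half of the purge path behind the Java-side calls earlier in this commit. The binding itself is not shown here, so the chain is inferred rather than quoted: OpContext.purge() on the Java side presumably reaches ctxPurge(OpaqueContext*), which delegates to Context::clearFastPath(); the fast-path input/output lists are cleared and owned handles deleted, while the I/T/B/D arguments stay intact, which is why the helpers only reconfigure arrays between executions.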
@@ -92,7 +92,7 @@ namespace nd4j {
     }

    void* ConstantHelper::replicatePointer(void *src, size_t numBytes, memory::Workspace *workspace) {
-        _mutex.lock();
+        std::lock_guard<std::mutex> lock(_mutex);

        auto deviceId = getCurrentDevice();
        Nd4jPointer constantPtr = nullptr;

@@ -116,7 +116,6 @@ namespace nd4j {
            if (res != 0)
                throw cuda_exception::build("cudaMemcpy failed", res);

-            _mutex.unlock();
            return ptr;
        } else {
            auto originalBytes = numBytes;

@@ -130,7 +129,6 @@ namespace nd4j {
            if (res != 0)
                throw cuda_exception::build("cudaMemcpyToSymbol failed", res);

-            _mutex.unlock();
            return reinterpret_cast<int8_t *>(constantPtr) + constantOffset;
        }
    }

@@ -152,7 +150,7 @@ namespace nd4j {
        ConstantDataBuffer* result;

        // access to this holder instance is synchronous
-        holder->mutex()->lock();
+        std::lock_guard<std::mutex> lock(*holder->mutex());

        if (holder->hasBuffer(dataType)) {
            result = holder->getConstantDataBuffer(dataType);

@@ -175,8 +173,6 @@ namespace nd4j {
            holder->addBuffer(dataBuffer, dataType);
            result = holder->getConstantDataBuffer(dataType);
        }
-        // release holder lock
-        holder->mutex()->unlock();

        return result;
    }
@@ -57,7 +57,7 @@ namespace nd4j {
    ConstantDataBuffer ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) {
        int deviceId = AffinityManager::currentDeviceId();

-        _mutex.lock();
+        std::lock_guard<std::mutex> lock(_mutex);

        if (_cache[deviceId].count(descriptor) == 0) {
            auto hPtr = descriptor.toShapeInfo();

@@ -65,15 +65,9 @@ namespace nd4j {
            ConstantDataBuffer buffer(hPtr, dPtr, shape::shapeInfoLength(hPtr) * sizeof(Nd4jLong), DataType::INT64);
            ShapeDescriptor descriptor1(descriptor);
            _cache[deviceId][descriptor1] = buffer;
-            auto r = _cache[deviceId][descriptor1];
-            _mutex.unlock();
-
-            return r;
+            return _cache[deviceId][descriptor1];
        } else {
-            ConstantDataBuffer r = _cache[deviceId].at(descriptor);
-            _mutex.unlock();
-
-            return r;
+            return _cache[deviceId].at(descriptor);
        }
    }

@@ -83,18 +77,10 @@ namespace nd4j {
    }

    bool ConstantShapeHelper::checkBufferExistenceForShapeInfo(ShapeDescriptor &descriptor) {
-        bool result;
        auto deviceId = AffinityManager::currentDeviceId();
-        _mutex.lock();
+        std::lock_guard<std::mutex> lock(_mutex);

-        if (_cache[deviceId].count(descriptor) == 0)
-            result = false;
-        else
-            result = true;
-
-        _mutex.unlock();
-
-        return result;
+        return _cache[deviceId].count(descriptor) != 0;
    }

    Nd4jLong* ConstantShapeHelper::createShapeInfo(const nd4j::DataType dataType, const char order, const int rank, const Nd4jLong* shape) {
@@ -64,7 +64,7 @@ namespace nd4j {
    TadPack ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) {
        const int deviceId = AffinityManager::currentDeviceId();

-        _mutex.lock();
+        std::lock_guard<std::mutex> lock(_mutex);

        if (_cache[deviceId].count(descriptor) == 0) {
            const auto shapeInfo = descriptor.originalShape().toShapeInfo();

@@ -97,14 +97,12 @@ namespace nd4j {
            _cache[deviceId][descriptor] = t;

            TadPack r = _cache[deviceId][descriptor];
-            _mutex.unlock();

            delete[] shapeInfo;

            return r;
        } else {
            TadPack r = _cache[deviceId][descriptor];
-            _mutex.unlock();

            return r;
        }
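The three cache helpers above all receive the same fix: the manual lock()/unlock() pairs leaked the mutex on every early return, and on the cuda_exception throws, leaving it held forever; std::lock_guard releases it on scope exit no matter how the function ends. For readers following along from the Java half of this commit, the equivalent discipline there is the try/finally idiom (a hedged illustration, not code from this repository):

    import java.util.concurrent.locks.ReentrantLock;

    public class CachedValue {
        private final ReentrantLock lock = new ReentrantLock();
        private Object cached;

        public Object get() {
            lock.lock();
            try {
                if (cached == null) {
                    cached = new Object();   // may throw in real code
                }
                return cached;               // early return is safe
            } finally {
                lock.unlock();               // always runs, like lock_guard's destructor
            }
        }
    }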
@@ -169,8 +169,8 @@ CUSTOM_OP_IMPL(maxpool3dnew_bp, 2, 1, false, 0, 14) {
    // int extraParam0 = INT_ARG(13);                                   // unnecessary for max case, required only for avg and pnorm cases
    int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1;       // 1-NDHWC, 0-NCDHW

-    REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3D_BP op: input should have rank of 5, but got %i instead", input->rankOf());
-    REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
+    REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3DNEW_BP op: input should have rank of 5, but got %i instead", input->rankOf());
+    REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW_BP op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);

    int bS, iC, iD, iH, iW, oC, oD, oH, oW;                  // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
    int indIOioC, indIOioD, indWoC, indWiC, indWkD;          // corresponding indexes

@@ -178,8 +178,8 @@ CUSTOM_OP_IMPL(maxpool3dnew_bp, 2, 1, false, 0, 14) {

    std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}));
    std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,iD,iH,iW,  0,indIOioC,indIOioD,indIOioD+1,indIOioD+2}));
-    REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "MAXPOOL3D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
-    REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "MAXPOOL3D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
+    REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "MAXPOOL3DNEW_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
+    REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "MAXPOOL3DNEW_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());

    if(!isNCDHW) {
        input = new NDArray(input->permute({0, 4, 1, 2, 3}));                   // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
@@ -58,30 +58,31 @@ namespace nd4j {
        int outRank = shape::rank(in) + 1;
        auto input = INPUT_VARIABLE(0);
        auto dtype = DataType::BOOL;
-        Nd4jLong maxInd = input->argMax();
-        Nd4jLong max = input->e<Nd4jLong>(maxInd);
+        auto argMaxInd = input->argMax();
+        Nd4jLong max = input->e<Nd4jLong>(argMaxInd);
+        Nd4jLong maxInd = max;

-        if (block.getIArguments()->size() > 0) {
-            if (block.width() < 2) {
-                maxInd = INT_ARG(0);
-                if (maxInd < max)
-                    maxInd = static_cast<Nd4jLong>(max);
-                if (block.getIArguments()->size() > 1)
-                    dtype = (DataType)INT_ARG(1);
-            }
-            else {
-                dtype = (DataType)INT_ARG(0);
-            }
-        }
+        if (block.numD() > 0)
+            dtype = D_ARG(0);

        if (block.width() > 1) {
            auto maxlen = INPUT_VARIABLE(1);
            Nd4jLong tmaxlen = maxlen->e<Nd4jLong>(0);
            if (tmaxlen > max)
                maxInd = static_cast<Nd4jLong>(tmaxlen);
+            if (block.numI() > 0) {
+                dtype = (DataType) INT_ARG(0);
+            }
+        }
+        else {
+            if (block.numI() > 0) {
+                maxInd = INT_ARG(0);
+            }
+            if (maxInd < max)
+                maxInd = max;
+            if (block.numI() > 1)
+                dtype = (DataType)INT_ARG(1); // to work with legacy code
        }
-        else
-            maxInd = static_cast<Nd4jLong>(max);

        int lastDimension = maxInd;
        ALLOCATE(outShapeInfo, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong);
@@ -38,10 +38,10 @@ namespace helpers {
    }

    void sequenceMask(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int maxIndex) {
-        BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sequenceMask_, (input, output, maxIndex), INTEGER_TYPES, BOOL_TYPES);
+        BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sequenceMask_, (input, output, maxIndex), INTEGER_TYPES, LIBND4J_TYPES_EXTENDED);
    }

-    BUILD_DOUBLE_TEMPLATE(template void sequenceMask_, (NDArray* input, NDArray* output, int maxIndex), INTEGER_TYPES, BOOL_TYPES);
+    BUILD_DOUBLE_TEMPLATE(template void sequenceMask_, (NDArray* input, NDArray* output, int maxIndex), INTEGER_TYPES, LIBND4J_TYPES_EXTENDED);
 }
 }
 }
@@ -36,10 +36,12 @@ namespace helpers {
    static void adjointMatrix_(nd4j::LaunchContext* context, NDArray const* input, NDArray* output) {
        auto inputPart = input->allTensorsAlongDimension({-2, -1});
        auto outputPart = output->allTensorsAlongDimension({-2, -1});
+        auto rows = input->sizeAt(-2);
        output->assign(input);

        auto batchLoop = PRAGMA_THREADS_FOR {
            for (auto batch = start; batch < stop; batch += increment) {
-                for (auto r = 0; r < input->rows(); r++) {
+                for (auto r = 0; r < rows; r++) {
                    for (auto c = 0; c < r; c++) {
                        math::nd4j_swap(outputPart[batch]->t<T>(r, c) , outputPart[batch]->t<T>(c, r));
                    }
@@ -108,17 +108,20 @@ namespace helpers {
    static void adjointTriangularMatrix_(nd4j::LaunchContext* context, NDArray const* input, bool const lower, NDArray* output) {
        auto inputPart = input->allTensorsAlongDimension({-2, -1});
        auto outputPart = output->allTensorsAlongDimension({-2, -1});
+        auto cols = input->sizeAt(-1);
+        auto rows = input->sizeAt(-2);

        auto batchLoop = PRAGMA_THREADS_FOR {
            for (auto batch = start; batch < stop; batch += increment) {
                if (!lower) {
-                    for (auto r = 0; r < input->rows(); r++) {
+                    for (auto r = 0; r < rows; r++) {
                        for (auto c = 0; c <= r; c++) {
                            outputPart[batch]->t<T>(r, c) = inputPart[batch]->t<T>(c, r);
                        }
                    }
                } else {
-                    for (auto r = 0; r < input->rows(); r++) {
-                        for (auto c = r; c < input->columns(); c++) {
+                    for (auto r = 0; r < rows; r++) {
+                        for (auto c = r; c < cols; c++) {
                            outputPart[batch]->t<T>(r, c) = inputPart[batch]->t<T>(c, r);
                        }
                    }
@@ -55,10 +55,10 @@ namespace helpers {
    }

    void sequenceMask(nd4j::LaunchContext * context, NDArray* input, NDArray* output, int maxIndex) {
-        BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sequenceMask_, (context, input, output, maxIndex), INTEGER_TYPES, BOOL_TYPES);
+        BUILD_DOUBLE_SELECTOR(input->dataType(), output->dataType(), sequenceMask_, (context, input, output, maxIndex), INTEGER_TYPES, LIBND4J_TYPES_EXTENDED);
    }

-    BUILD_DOUBLE_TEMPLATE(template void sequenceMask_, (nd4j::LaunchContext* context, NDArray* input, NDArray* output, int maxIndex), INTEGER_TYPES, BOOL_TYPES);
+    BUILD_DOUBLE_TEMPLATE(template void sequenceMask_, (nd4j::LaunchContext* context, NDArray* input, NDArray* output, int maxIndex), INTEGER_TYPES, LIBND4J_TYPES_EXTENDED);
 }
 }
 }
@@ -250,7 +250,7 @@ void pooling3dCUDNN(const LaunchContext* context,
    auto handle = reinterpret_cast<cudnnHandle_t *>(context->getCuDnnHandle());
    cudnnStatus_t err = cudnnSetStream(*handle, *context->getCudaStream());
    if (err != 0) throw nd4j::cuda_exception::build("pooling3dCUDNN: can't set stream for cuDNN", err);
-    printf("fffffffffff\n");
+
    const int numDims = 5;

    int bS, iC, iD, iH, iW, oC, oD, oH, oW;                  // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
@ -17,6 +17,7 @@
|
||||||
//
|
//
|
||||||
// @author saudet
|
// @author saudet
|
||||||
// @author raver119@gmail.com
|
// @author raver119@gmail.com
|
||||||
|
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <ops/declarable/PlatformHelper.h>
|
#include <ops/declarable/PlatformHelper.h>
|
||||||
|
@ -36,103 +37,44 @@ namespace platforms {
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
PLATFORM_IMPL(avgpool2d, ENGINE_CPU) {
|
PLATFORM_IMPL(avgpool2d, ENGINE_CPU) {
|
||||||
auto input = INPUT_VARIABLE(0);
|
|
||||||
|
|
||||||
REQUIRE_TRUE(input->rankOf() == 4, 0, "Input should have rank of 4, but got %i instead",
|
auto input = INPUT_VARIABLE(0);
|
||||||
input->rankOf());
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
|
// 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
|
||||||
auto argI = *(block.getIArguments());
|
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
|
||||||
|
|
||||||
const auto kH = INT_ARG(0);
|
const auto kH = INT_ARG(0);
|
||||||
const auto kW = INT_ARG(1);
|
const auto kW = INT_ARG(1);
|
||||||
const auto sH = INT_ARG(2);
|
const auto sH = INT_ARG(2);
|
||||||
const auto sW = INT_ARG(3);
|
const auto sW = INT_ARG(3);
|
||||||
int pH = INT_ARG(4);
|
auto pH = INT_ARG(4);
|
||||||
int pW = INT_ARG(5);
|
auto pW = INT_ARG(5);
|
||||||
const auto dH = INT_ARG(6);
|
const auto dH = INT_ARG(6);
|
||||||
const auto dW = INT_ARG(7);
|
const auto dW = INT_ARG(7);
|
||||||
const auto isSameMode = static_cast<bool>(INT_ARG(8));
|
const auto paddingMode = INT_ARG(8);
|
||||||
const auto extraParam0 = INT_ARG(9);
|
const auto extraParam0 = INT_ARG(9);
|
||||||
|
const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
|
||||||
|
|
||||||
REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}",
|
REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D MKLDNN op: input should have rank of 4, but got %i instead", input->rankOf());
|
||||||
dH, dW);
|
REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D MKLDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW);
|
||||||
|
|
||||||
int oH = 0;
|
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
||||||
int oW = 0;
|
int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes
|
||||||
|
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
|
||||||
|
|
||||||
int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC
|
if (paddingMode)
|
||||||
|
|
||||||
const int iH = static_cast<int>(isNCHW ? input->sizeAt(2) : input->sizeAt(1));
|
|
||||||
const int iW = static_cast<int>(isNCHW ? input->sizeAt(3) : input->sizeAt(2));
|
|
||||||
|
|
||||||
if (!isNCHW) {
|
|
||||||
input = new NDArray(
|
|
||||||
input->permute({0, 3, 1, 2})); // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
|
|
||||||
output = new NDArray(
|
|
||||||
output->permute({0, 3, 1, 2})); // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
|
|
||||||
}
|
|
||||||
|
|
||||||
ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
|
|
||||||
|
|
||||||
if (isSameMode)
|
|
||||||
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
||||||
|
|
||||||
const int bS = input->sizeAt(0);
|
auto mode = (extraParam0 == 0) ? algorithm::pooling_avg_exclude_padding : algorithm::pooling_avg_include_padding;
|
||||||
const int iC = input->sizeAt(1);
|
|
||||||
const int oC = output->sizeAt(1);
|
|
||||||
|
|
||||||
auto poolingMode = PoolingType::AVG_POOL;
|
mkldnnUtils::poolingMKLDNN(input, output, 0,kH,kW, 0,sH,sW, 0,pH,pW, isNCHW, mode);
|
||||||
|
|
||||||
dnnl_memory_desc_t empty;
|
|
||||||
dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty);
|
|
||||||
dnnl::memory::desc user_src_md(empty), user_dst_md(empty);
|
|
||||||
dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
|
|
||||||
dnnl::algorithm algorithm;
|
|
||||||
mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
|
|
||||||
true,
|
|
||||||
bS, iC, iH, iW, oC, oH, oW, input, nullptr, output,
|
|
||||||
algorithm,
|
|
||||||
&pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr,
|
|
||||||
&user_dst_md,
|
|
||||||
pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
|
||||||
auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md,
|
|
||||||
pool_dst_md,
|
|
||||||
pool_strides, pool_kernel, pool_padding, pool_padding_r);
|
|
||||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
|
||||||
auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
|
|
||||||
auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
|
|
||||||
auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
|
|
||||||
auto pool_src_memory = user_src_memory;
|
|
||||||
dnnl::stream stream(engine);
|
|
||||||
if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
|
|
||||||
pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
|
|
||||||
reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
|
|
||||||
}
|
|
||||||
auto pool_dst_memory = user_dst_memory;
|
|
||||||
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
|
||||||
pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
|
|
||||||
}
|
|
||||||
pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
|
|
||||||
{DNNL_ARG_DST, pool_dst_memory}});
|
|
||||||
if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
|
|
||||||
reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);
|
|
||||||
}
|
|
||||||
stream.wait();
|
|
||||||
|
|
||||||
//streams[0].submitAndWait();
|
|
||||||
|
|
||||||
if (!isNCHW) {
|
|
||||||
delete input;
|
|
||||||
delete output;
|
|
||||||
}
|
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
PLATFORM_CHECK(avgpool2d, ENGINE_CPU) {
|
PLATFORM_CHECK(avgpool2d, ENGINE_CPU) {
|
||||||
|
|
||||||
auto input = INPUT_VARIABLE(0);
|
auto input = INPUT_VARIABLE(0);
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
@ -141,12 +83,10 @@ PLATFORM_CHECK(avgpool2d, ENGINE_CPU) {
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(avgpool2d_bp, ENGINE_CPU) {

        auto input = INPUT_VARIABLE(0);        // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
        auto gradO = INPUT_VARIABLE(1);        // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
        auto gradI = OUTPUT_VARIABLE(0);       // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon

        int kH = INT_ARG(0);                   // filter(kernel) height
        int kW = INT_ARG(1);                   // filter(kernel) width

@ -156,92 +96,26 @@ PLATFORM_IMPL(avgpool2d_bp, ENGINE_CPU) {

        int pW = INT_ARG(5);                   // paddings width
        int dH = INT_ARG(6);                   // dilations height
        int dW = INT_ARG(7);                   // dilations width
        int paddingMode = INT_ARG(8);          // 0-VALID, 1-SAME
        int extraParam0 = INT_ARG(9);
        int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1;  // INT_ARG(10): 0-NCHW, 1-NHWC

        REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D_BP MKLDNN op: input should have rank of 4, but got %i instead", input->rankOf());
        REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D_BP MKLDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW);

        int bS, iC, iH, iW, oC, oH, oW;                         // batch size, input channels, input height/width, output channels, output height/width
        int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;   // corresponding indexes
        ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

        std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oH,oW, 0,indIOioC,indIiH,indIiH+1});
        REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL2D_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());

        if(paddingMode)  // SAME
            ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);

        auto mode = (extraParam0 == 0) ? algorithm::pooling_avg_exclude_padding : algorithm::pooling_avg_include_padding;

        mkldnnUtils::poolingBpMKLDNN(input, gradO, gradI, 0,kH,kW, 0,sH,sW, 0,pH,pW, isNCHW, mode);

        return Status::OK();
}
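// The backprop path above now delegates to mkldnnUtils::poolingBpMKLDNN, whose
// body is not part of this hunk. A minimal sketch of the forward/backward
// primitive pairing such a helper builds on (DNNL 1.x API); the 1x1x4x4 shape,
// 2x2 kernel/stride and zero padding are illustrative assumptions, not the
// shipped helper:
#include <dnnl.hpp>

void poolingBackwardSketch() {
    using namespace dnnl;
    engine eng(engine::kind::cpu, 0);
    stream strm(eng);

    memory::desc src_md({1, 1, 4, 4}, memory::data_type::f32, memory::format_tag::nchw);
    memory::desc dst_md({1, 1, 2, 2}, memory::data_type::f32, memory::format_tag::nchw);

    // backward pooling requires a forward primitive descriptor as a hint
    auto fwd_d  = pooling_forward::desc(prop_kind::forward, algorithm::pooling_avg_exclude_padding,
                                        src_md, dst_md, {2, 2}, {2, 2}, {0, 0}, {0, 0});
    auto fwd_pd = pooling_forward::primitive_desc(fwd_d, eng);

    auto bwd_d  = pooling_backward::desc(algorithm::pooling_avg_exclude_padding,
                                         src_md, dst_md, {2, 2}, {2, 2}, {0, 0}, {0, 0});
    auto bwd_pd = pooling_backward::primitive_desc(bwd_d, eng, fwd_pd);

    // diff buffers allocated in whatever layout the primitive selected
    memory diff_dst(bwd_pd.diff_dst_desc(), eng);
    memory diff_src(bwd_pd.diff_src_desc(), eng);

    pooling_backward(bwd_pd).execute(strm, {{DNNL_ARG_DIFF_DST, diff_dst},
                                            {DNNL_ARG_DIFF_SRC, diff_src}});
    strm.wait();
}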
@ -17,6 +17,7 @@
//
// @author saudet
// @author raver119@gmail.com
// @author Yurii Shyrma (iuriish@yahoo.com)
//

#include <ops/declarable/PlatformHelper.h>

@ -32,11 +33,12 @@ using namespace dnnl;

namespace nd4j {
namespace ops {
namespace platforms {

//////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(avgpool3dnew, ENGINE_CPU) {

        auto input = INPUT_VARIABLE(0);        // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
        auto output = OUTPUT_VARIABLE(0);      // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW)

        int kD = INT_ARG(0);                   // filter(kernel) depth
        int kH = INT_ARG(1);                   // filter(kernel) height

@ -50,92 +52,88 @@ namespace nd4j {

        int dD = INT_ARG(9);                   // dilations depth
        int dH = INT_ARG(10);                  // dilations height
        int dW = INT_ARG(11);                  // dilations width
        int paddingMode = INT_ARG(12);         // 1-SAME, 0-VALID
        int extraParam0 = INT_ARG(13);
        int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1;  // INT_ARG(14): 0-NCDHW, 1-NDHWC

        REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
        REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW MKLDNN OP: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);

        int bS, iC, iD, iH, iW, oC, oD, oH, oW;           // batch size, input channels, input depth/height/width, output channels, output depth/height/width
        int indIOioC, indIOioD, indWoC, indWiC, indWkD;   // corresponding indexes
        ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

        if(paddingMode)  // SAME
            ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);

        auto mode = (extraParam0 == 0) ? algorithm::pooling_avg_exclude_padding : algorithm::pooling_avg_include_padding;

        mkldnnUtils::poolingMKLDNN(input, output, kD,kH,kW, sD,sH,sW, pD,pH,pW, isNCDHW, mode);

        return Status::OK();
}
//////////////////////////////////////////////////////////////////////
PLATFORM_CHECK(avgpool3dnew, ENGINE_CPU) {

        auto input = INPUT_VARIABLE(0);
        auto output = OUTPUT_VARIABLE(0);

        return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
}
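// extraParam0 selects between the two oneDNN averaging flavours used above:
// pooling_avg_exclude_padding divides each window sum by the number of real
// input elements inside the window, while pooling_avg_include_padding always
// divides by the full kernel volume. A worked border case (1D window of size
// 2 with pad 1 at the left edge; the values are illustrative):
#include <cstdio>

void avgDivisorSketch() {
    const float x0 = 6.0f;       // the only real element under the window
    const int kernel = 2;        // window size
    const int validElems = 1;    // window elements that lie inside the input

    float avgExclude = x0 / validElems;  // exclude_padding -> 6.0
    float avgInclude = x0 / kernel;      // include_padding -> 3.0

    std::printf("exclude=%.1f include=%.1f\n", avgExclude, avgInclude);
}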
//////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(avgpool3dnew_bp, ENGINE_CPU) {

        auto input = INPUT_VARIABLE(0);        // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
        auto gradO = INPUT_VARIABLE(1);        // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
        auto gradI = OUTPUT_VARIABLE(0);       // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon

        const int kD = INT_ARG(0);             // filter(kernel) depth
        const int kH = INT_ARG(1);             // filter(kernel) height
        const int kW = INT_ARG(2);             // filter(kernel) width
        const int sD = INT_ARG(3);             // strides depth
        const int sH = INT_ARG(4);             // strides height
        const int sW = INT_ARG(5);             // strides width
        int pD = INT_ARG(6);                   // paddings depth
        int pH = INT_ARG(7);                   // paddings height
        int pW = INT_ARG(8);                   // paddings width
        const int dD = INT_ARG(9);             // dilations depth
        const int dH = INT_ARG(10);            // dilations height
        const int dW = INT_ARG(11);            // dilations width
        const int paddingMode = INT_ARG(12);   // 1-SAME, 0-VALID
        const int extraParam0 = INT_ARG(13);   // defines which divisor to use while averaging
        const int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1;  // INT_ARG(14): 0-NCDHW, 1-NDHWC

        REQUIRE_TRUE(input->rankOf() == 5, 0, "AVGPOOL3DNEW_BP MKLDNN op: input should have rank of 5, but got %i instead", input->rankOf());
        REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "AVGPOOL3DNEW_BP MKLDNN op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);

        int bS, iC, iD, iH, iW, oC, oD, oH, oW;           // batch size, input channels, input depth/height/width, output channels, output depth/height/width
        int indIOioC, indIOioD, indWoC, indWiC, indWkD;   // corresponding indexes
        ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

        std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
        REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "AVGPOOL3DNEW_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());

        if(paddingMode)  // SAME
            ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);

        auto mode = (extraParam0 == 0) ? algorithm::pooling_avg_exclude_padding : algorithm::pooling_avg_include_padding;

        mkldnnUtils::poolingBpMKLDNN(input, gradO, gradI, kD,kH,kW, sD,sH,sW, pD,pH,pW, isNCDHW, mode);

        return Status::OK();
}

//////////////////////////////////////////////////////////////////////
PLATFORM_CHECK(avgpool3dnew_bp, ENGINE_CPU) {

        auto input = INPUT_VARIABLE(0);
        auto output = OUTPUT_VARIABLE(0);

        return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
}

}
}
}
@ -1,154 +0,0 @@
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <platform_boilerplate.h>

#include <helpers/MKLDNNStream.h>
#include "mkldnnUtils.h"
#include <ops/declarable/helpers/convolutions.h>

using namespace dnnl;

namespace nd4j {
namespace ops {
namespace platforms {

PLATFORM_IMPL(avgpool3dnew_bp, ENGINE_CPU) {
        auto input = INPUT_VARIABLE(0);        // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
        auto gradO = INPUT_VARIABLE(1);        // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
        auto gradI = OUTPUT_VARIABLE(0);       // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon

        const int kD = INT_ARG(0);             // filter(kernel) depth
        const int kH = INT_ARG(1);             // filter(kernel) height
        const int kW = INT_ARG(2);             // filter(kernel) width
        const int sD = INT_ARG(3);             // strides depth
        const int sH = INT_ARG(4);             // strides height
        const int sW = INT_ARG(5);             // strides width
        int pD = INT_ARG(6);                   // paddings depth
        int pH = INT_ARG(7);                   // paddings height
        int pW = INT_ARG(8);                   // paddings width
        const int dD = INT_ARG(9);             // dilations depth
        const int dH = INT_ARG(10);            // dilations height
        const int dW = INT_ARG(11);            // dilations width
        const int isSameMode = INT_ARG(12);    // 1-SAME, 0-VALID
        int extraParam0 = INT_ARG(13);         // unnecessary for max case, required only for avg and pnorm cases
        int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1;  // 1-NDHWC, 0-NCDHW

        REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3D_BP op: input should have rank of 5, but got %i instead", input->rankOf());
        REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);

        int bS, iC, iD, iH, iW, oC, oD, oH, oW;           // batch size, input channels, input depth/height/width, output channels, output depth/height/width
        int indIOioC, indIOioD, indWoC, indWiC, indWkD;   // corresponding indexes
        ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

        std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oD, oH, oW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
        std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, iD, iH, iW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
        REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0, "MAXPOOL3D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
        REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0, "MAXPOOL3D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !", expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());

        if (!isNCDHW) {
            input = new NDArray(input->permute({0, 4, 1, 2, 3}));  // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
            gradI = new NDArray(gradI->permute({0, 4, 1, 2, 3}));  // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
            gradO = new NDArray(gradO->permute({0, 4, 1, 2, 3}));  // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
        }

        if (isSameMode)  // SAME
            ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);

        auto poolingMode = PoolingType::AVG_POOL;

        dnnl_memory_desc_t empty;
        dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
        dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
        dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
        dnnl::algorithm algorithm;
        mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode, extraParam0, true,
                                               bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, gradO, algorithm,
                                               &pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md, &user_diff_src_md, &user_dst_md,
                                               pool_strides, pool_kernel, pool_padding, pool_padding_r);

        if (input->buffer() == nullptr) {
            pool_src_md = pool_diff_src_md;
            user_src_md = user_diff_src_md;
        }

        auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, pool_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r);
        auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
        dnnl::stream stream(engine);
        auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
        auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r);
        auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);

        auto userB_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer());
        auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer());

        auto poolB_src_memory = userB_src_memory;
        if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
            poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine);
        }

        auto poolB_dst_memory = userB_dst_memory;
        if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
            poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine);
            reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
        }

        pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory},
                                                           {DNNL_ARG_DIFF_SRC, poolB_src_memory}});

        if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
            reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
        }

        stream.wait();

        if (!isNCDHW) {
            delete input;
            delete gradI;
            delete gradO;
        }

        return Status::OK();
}

PLATFORM_CHECK(avgpool3dnew_bp, ENGINE_CPU) {
        auto input = INPUT_VARIABLE(0);
        auto output = OUTPUT_VARIABLE(0);

        return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output});
}

}
}
}
@ -37,12 +37,12 @@ namespace platforms {

//////////////////////////////////////////////////////////////////////////
static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray* weights, NDArray* z,
                            const float epsilon, const bool isNCHW) {

    // unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any) for x

    // x        -> 2D:nc, 4D:nchw/nhwc, 5D:ncdhw/ndhwc
    // mean     -> 1D [c]
    // variance -> 1D [c]
    // weights  -> 2D [2, c], weights({0,1, 0,0}) contains gamma and weights({1,2, 0,0}) contains beta

@ -50,8 +50,6 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray

    const int xRank = x->rankOf();

    // input type
    dnnl::memory::data_type type = dnnl::memory::data_type::f32;

@ -63,17 +61,28 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray

    dnnl::memory::dims dims;
    dnnl::memory::format_tag format;

    const int indHW = isNCHW ? 2 : 1;
    const int bS = x->sizeAt(0);
    const int iC = isNCHW ? x->sizeAt(1) : x->sizeAt(-1);

    int iD, iH, iW;

    if(xRank == 2) {
        dims   = {bS, iC};
        format = dnnl::memory::format_tag::nc;
    }
    else if(xRank == 4) {
        iH     = x->sizeAt(indHW);
        iW     = x->sizeAt(indHW + 1);
        dims   = {bS, iC, iH, iW};
        format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
    }
    else {  // xRank = 5
        iD     = x->sizeAt(indHW);
        iH     = x->sizeAt(indHW + 1);
        iW     = x->sizeAt(indHW + 2);
        dims   = {bS, iC, iD, iH, iW};
        format = isNCHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
    }

    // memory descriptors for arrays

@ -81,29 +90,34 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray

    // x
    dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(dims, type, format);
    dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format);
    if(x->ews() != 1 || x->ordering() != 'c') {
        x_user_md.data.format_kind = dnnl_blocked;    // overrides format
        x_user_md.data.format_desc.blocking.strides[0] = x->strideAt(0);
        x_user_md.data.format_desc.blocking.strides[1] = x->strideAt(1);
        if(xRank > 2) {
            x_user_md.data.format_desc.blocking.strides[2] = x->strideAt(2);
            x_user_md.data.format_desc.blocking.strides[3] = x->strideAt(3);
        }
        if(xRank > 4)
            x_user_md.data.format_desc.blocking.strides[4] = x->strideAt(4);
    }

    // z, output
    dnnl::memory::desc z_mkl_md  = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc z_user_md = dnnl::memory::desc(dims, type, format);
    if(z->ews() != 1 || z->ordering() != 'c') {
        z_user_md.data.format_kind = dnnl_blocked;    // overrides format
        z_user_md.data.format_desc.blocking.strides[0] = z->strideAt(0);
        z_user_md.data.format_desc.blocking.strides[1] = z->strideAt(1);
        if(xRank > 2) {
            z_user_md.data.format_desc.blocking.strides[2] = z->strideAt(2);
            z_user_md.data.format_desc.blocking.strides[3] = z->strideAt(3);
        }
        if(xRank > 4)
            z_user_md.data.format_desc.blocking.strides[4] = z->strideAt(4);
    }

    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());

    // batchnorm forward description
    dnnl::batch_normalization_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, x_mkl_md, epsilon, flags);
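// Note the asymmetry above: x keeps a concrete nchw/nhwc-style format (per the
// comment, format_tag::any is not supported for x here), while z_mkl_md uses
// format_tag::any so the primitive may pick its preferred layout; that chosen
// layout is then queried from the primitive descriptor and compared against
// the user's, which is what drives the later reorders. A small sketch of the
// query (shape, epsilon and flags are illustrative assumptions):
#include <dnnl.hpp>
#include <cstdio>

void batchnormLayoutQuerySketch() {
    using namespace dnnl;
    engine eng(engine::kind::cpu, 0);

    memory::desc x_md({2, 8, 4, 4}, memory::data_type::f32, memory::format_tag::nchw);

    batch_normalization_forward::desc d(prop_kind::forward_inference, x_md, 1e-5f,
                                        normalization_flags::use_global_stats);
    batch_normalization_forward::primitive_desc pd(d, eng);

    // dst layout chosen by the primitive may differ from the user's nchw
    const bool needReorder = pd.dst_desc() != x_md;
    std::printf("reorder needed: %d\n", needReorder);
}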
@ -162,12 +176,11 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray

//////////////////////////////////////////////////////////////////////////
static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray* dLdO, const NDArray* weights,
                                    NDArray* dLdI, NDArray* dLdW, const float epsilon, const bool isNCHW) {

    // unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any) for x

    // x        -> 2D:nc, 4D:nchw/nhwc, 5D:ncdhw/ndhwc
    // mean     -> 1D [c]
    // variance -> 1D [c]
    // dLdO     -> same shape as x

@ -177,8 +190,6 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const

    const int xRank = x->rankOf();

    // input type
    dnnl::memory::data_type type = dnnl::memory::data_type::f32;

@ -190,17 +201,28 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const

    dnnl::memory::dims dims;
    dnnl::memory::format_tag format;

    const int indHW = isNCHW ? 2 : 1;
    const int bS = x->sizeAt(0);
    const int iC = isNCHW ? x->sizeAt(1) : x->sizeAt(-1);

    int iD, iH, iW;

    if(xRank == 2) {
        dims   = {bS, iC};
        format = dnnl::memory::format_tag::nc;
    }
    else if(xRank == 4) {
        iH     = x->sizeAt(indHW);
        iW     = x->sizeAt(indHW + 1);
        dims   = {bS, iC, iH, iW};
        format = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
    }
    else {  // xRank = 5
        iD     = x->sizeAt(indHW);
        iH     = x->sizeAt(indHW + 1);
        iW     = x->sizeAt(indHW + 2);
        dims   = {bS, iC, iD, iH, iW};
        format = isNCHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
    }

    // memory descriptors for arrays

@ -208,41 +230,49 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const

    // x
    dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(dims, type, format);
    dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format);
    if(x->ews() != 1 || x->ordering() != 'c') {
        x_user_md.data.format_kind = dnnl_blocked;    // overrides format
        x_user_md.data.format_desc.blocking.strides[0] = x->strideAt(0);
        x_user_md.data.format_desc.blocking.strides[1] = x->strideAt(1);
        if(xRank > 2) {
            x_user_md.data.format_desc.blocking.strides[2] = x->strideAt(2);
            x_user_md.data.format_desc.blocking.strides[3] = x->strideAt(3);
        }
        if(xRank > 4)
            x_user_md.data.format_desc.blocking.strides[4] = x->strideAt(4);
    }

    // dLdO
    dnnl::memory::desc dLdO_mkl_md  = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc dLdO_user_md = dnnl::memory::desc(dims, type, format);
    if(dLdO->ews() != 1 || dLdO->ordering() != 'c') {
        dLdO_user_md.data.format_kind = dnnl_blocked;    // overrides format
        dLdO_user_md.data.format_desc.blocking.strides[0] = dLdO->strideAt(0);
        dLdO_user_md.data.format_desc.blocking.strides[1] = dLdO->strideAt(1);
        if(xRank > 2) {
            dLdO_user_md.data.format_desc.blocking.strides[2] = dLdO->strideAt(2);
            dLdO_user_md.data.format_desc.blocking.strides[3] = dLdO->strideAt(3);
        }
        if(xRank > 4)
            dLdO_user_md.data.format_desc.blocking.strides[4] = dLdO->strideAt(4);
    }

    // dLdI
    dnnl::memory::desc dLdI_mkl_md  = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc dLdI_user_md = dnnl::memory::desc(dims, type, format);
    if(dLdI->ews() != 1 || dLdI->ordering() != 'c') {
        dLdI_user_md.data.format_kind = dnnl_blocked;    // overrides format
        dLdI_user_md.data.format_desc.blocking.strides[0] = dLdI->strideAt(0);
        dLdI_user_md.data.format_desc.blocking.strides[1] = dLdI->strideAt(1);
        if(xRank > 2) {
            dLdI_user_md.data.format_desc.blocking.strides[2] = dLdI->strideAt(2);
            dLdI_user_md.data.format_desc.blocking.strides[3] = dLdI->strideAt(3);
        }
        if(xRank > 4)
            dLdI_user_md.data.format_desc.blocking.strides[4] = dLdI->strideAt(4);
    }

    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());

    // batchnorm forward description
    dnnl::batch_normalization_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, x_mkl_md, epsilon, flags);

@ -331,7 +361,7 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const

    // dLdI = dfdm / N + (2/N) * dfdv * (dvdm/2 + (x - m))
    // dLdI = gamma * ( stdInv * -g_sum/N + (2/N) * dfdv * (dvdm/2 + (x - m)) )

    std::vector<int> axes = isNCHW ? std::vector<int>{1} : std::vector<int>{xRank - 1};
    const auto excludedAxes = ShapeUtils::evalDimsToExclude(x->rankOf(), axes);

    // inversed batch size 1 / N
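// For reference, the dfdm/dfdv comments above implement the standard batchnorm
// backward pass; in LaTeX, with per-channel mean \mu and variance \sigma^2
// taken over the N elements of a channel (this is a restatement of the
// formulas already in the comments, not an addition to them):
//
//   \hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}, \qquad
//   y_i = \gamma \hat{x}_i + \beta
//
//   \frac{\partial L}{\partial x_i} =
//     \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}}
//     \Big( g_i - \frac{1}{N}\sum_j g_j
//           - \frac{\hat{x}_i}{N}\sum_j g_j \hat{x}_j \Big), \qquad
//   g_i = \frac{\partial L}{\partial y_i}
//
// The axes vector just computed selects the channel dimension (1 for NCHW,
// rank-1 for NHWC) over which these per-channel sums run.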
@ -377,7 +407,7 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const

PLATFORM_IMPL(batchnorm, ENGINE_CPU) {

    auto input    = INPUT_VARIABLE(0);   // 2D:nc, 4D:nchw/nhwc, 5D:ncdhw/ndhwc
    auto mean     = INPUT_VARIABLE(1);   // [c]
    auto variance = INPUT_VARIABLE(2);   // [c]
    NDArray* gamma = nullptr;            // [c]

@ -436,31 +466,19 @@ PLATFORM_IMPL(batchnorm, ENGINE_CPU) {

        (*weights)({1,2, 0,0}).assign(0);
    }

    const bool isNCHW = !(axes[0] == inRank - 1 && inRank > 2);

    batchnormMKLDNN(input, mean, variance, weights, output, epsilon, isNCHW);

    delete weights;

    return Status::OK();
}
//////////////////////////////////////////////////////////////////////////
PLATFORM_CHECK(batchnorm, ENGINE_CPU) {

    auto input    = INPUT_VARIABLE(0);   // 2D:nc, 4D:nchw/nhwc, 5D:ncdhw/ndhwc
    auto mean     = INPUT_VARIABLE(1);   // [c]
    auto variance = INPUT_VARIABLE(2);   // [c]
    NDArray* gamma = nullptr;            // [c]
@ -634,7 +652,7 @@ PLATFORM_CHECK(batchnorm, ENGINE_CPU) {

//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(batchnorm_bp, ENGINE_CPU) {

    NDArray* input    = INPUT_VARIABLE(0);   // 2D:nc, 4D:nchw/nhwc, 5D:ncdhw/ndhwc
    NDArray* mean     = INPUT_VARIABLE(1);   // [c]
    NDArray* variance = INPUT_VARIABLE(2);   // [c]
    NDArray* gamma    = nullptr;             // [c]

@ -702,15 +720,9 @@ PLATFORM_IMPL(batchnorm_bp, ENGINE_CPU) {

        (*weights)({1,2, 0,0}).assign(0);
    }

    const bool isNCHW = !(axes[0] == inRank - 1 && inRank > 2);

    batchnormBackPropMKLDNN(input, mean, variance, dLdO, weights, dLdI, dLdW, epsilon, isNCHW);

    *dLdM = 0;
    *dLdV = 0;

@ -725,17 +737,12 @@ PLATFORM_IMPL(batchnorm_bp, ENGINE_CPU) {

        delete dLdW;
    }

    return Status::OK();
}
//////////////////////////////////////////////////////////////////////////
PLATFORM_CHECK(batchnorm_bp, ENGINE_CPU) {

    NDArray* input    = INPUT_VARIABLE(0);   // 2D:nc, 4D:nchw, 5D:ncdhw
    NDArray* mean     = INPUT_VARIABLE(1);   // [c]
    NDArray* variance = INPUT_VARIABLE(2);   // [c]
@ -17,6 +17,7 @@
//
// @author saudet
// @author raver119@gmail.com
// @author Yurii Shyrma (iuriish@yahoo.com)
//

#include <ops/declarable/PlatformHelper.h>

@ -33,6 +34,298 @@ namespace nd4j {

namespace ops {
namespace platforms {
//////////////////////////////////////////////////////////////////////
static void conv2dMKLDNN(const NDArray *input, const NDArray *weights, const NDArray *bias, NDArray *output,
                         const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
                         const int paddingMode, const int isNCHW) {

    // weights [kH, kW, iC, oC], we'll perform permutation since mkl supports [oC, iC, kH, kW]

    int bS, iC, iH, iW, oC, oH, oW;                        // batch size, input channels, input height/width, output channels, output height/width
    int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;  // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

    const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW;  // dH == 1 for causal mode in conv1d

    dnnl::memory::dims strides   = { sH, sW };
    dnnl::memory::dims padding   = { pH, pW };
    dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame };
    dnnl::memory::dims dilation  = { dH-1, dW-1 };

    auto xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
    dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw;

    dnnl::memory::dims xDims = {bS, iC, iH, iW};
    dnnl::memory::dims wDims = {oC, iC, kH, kW};
    dnnl::memory::dims zDims = {bS, oC, oH, oW};

    auto type = dnnl::memory::data_type::f32;

    // memory descriptors for arrays

    // input
    dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
    if(input->ews() != 1 || input->ordering() != 'c') {
        x_user_md.data.format_kind = dnnl_blocked;    // overrides format
        x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
        x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
        x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
        x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
    }

    // weights
    dnnl::memory::desc w_mkl_md  = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormat);
    w_user_md.data.format_kind = dnnl_blocked;    // overrides format
    w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3);  // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW]
    w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(2);
    w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0);
    w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1);

    // bias
    dnnl::memory::desc b_mkl_md;
    if(bias != nullptr)
        b_mkl_md = dnnl::memory::desc({oC}, type, dnnl::memory::format_tag::x);

    // output
    dnnl::memory::desc z_mkl_md  = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFrmat);
    if(output->ews() != 1 || output->ordering() != 'c') {
        z_user_md.data.format_kind = dnnl_blocked;    // overrides format
        z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0);
        z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(1);
        z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(2);
        z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(3);
    }

    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());

    // operation primitive description
    dnnl::convolution_forward::desc op_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, padding_r);
    dnnl::convolution_forward::primitive_desc op_prim_desc(op_desc, engine);

    // arguments (memory buffers) necessary for calculations
    std::unordered_map<int, dnnl::memory> args;

    dnnl::stream stream(engine);

    // provide memory buffers and check whether reorder is required

    // input
    auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer());
    const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
    auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
    if (xReorder)
        dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
    args[DNNL_ARG_SRC] = x_mkl_mem;

    // weights
    auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer());
    const bool wReorder = op_prim_desc.weights_desc() != w_user_mem.get_desc();
    auto w_mkl_mem = wReorder ? dnnl::memory(op_prim_desc.weights_desc(), engine) : w_user_mem;
    if (wReorder)
        dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
    args[DNNL_ARG_WEIGHTS] = w_mkl_mem;

    // bias
    if(bias != nullptr) {
        auto b_mkl_mem = dnnl::memory(b_mkl_md, engine, bias->getBuffer());
        args[DNNL_ARG_BIAS] = b_mkl_mem;
    }

    // output
    auto z_user_mem = dnnl::memory(z_user_md, engine, output->getBuffer());
    const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
    auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
    args[DNNL_ARG_DST] = z_mkl_mem;

    // run calculations
    dnnl::convolution_forward(op_prim_desc).execute(stream, args);

    // reorder outputs if necessary
    if (zReorder)
        dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);

    stream.wait();
    // shape::printArray(z_mkl_mem.map_data<float>(),8);
}
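// w_user_md above never copies the weights: switching the descriptor to
// dnnl_blocked and writing the strides of the [kH, kW, iC, oC] nd4j array in
// [oC, iC, kH, kW] order expresses the permutation purely through strides.
// The same idea with plain index arithmetic (the sizes are illustrative):
#include <cstdio>

void permutedViewSketch() {
    const int kH = 2, kW = 2, iC = 3, oC = 4;
    const int s[4] = {kW * iC * oC, iC * oC, oC, 1};  // c-order strides of [kH,kW,iC,oC]

    // strides of the logical [oC, iC, kH, kW] view, exactly as w_user_md sets them
    const int v[4] = {s[3], s[2], s[0], s[1]};

    // element (o, i, h, w) of the view is element (h, w, i, o) of the buffer
    const int o = 1, i = 2, h = 0, w = 1;
    const int offView = o * v[0] + i * v[1] + h * v[2] + w * v[3];
    const int offBuf  = h * s[0] + w * s[1] + i * s[2] + o * s[3];
    std::printf("%d == %d\n", offView, offBuf);  // both 21
}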
//////////////////////////////////////////////////////////////////////
static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const NDArray *bias, const NDArray *gradO,
                           NDArray *gradI, NDArray *gradW, NDArray *gradB,
                           const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
                           const int paddingMode, const int isNCHW) {

    // weights/gradW [kH, kW, iC, oC], we'll perform permutation since mkl supports [oC, iC, kH, kW]

    int bS, iC, iH, iW, oC, oH, oW;                        // batch size, input channels, input height/width, output channels, output height/width
    int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;  // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

    const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW;  // dH == 1 for causal mode in conv1d

    dnnl::memory::dims strides   = { sH, sW };
    dnnl::memory::dims padding   = { pH, pW };
    dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame };
    dnnl::memory::dims dilation  = { dH-1, dW-1 };

    auto xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
    dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw;

    dnnl::memory::dims xDims = {bS, iC, iH, iW};
    dnnl::memory::dims wDims = {oC, iC, kH, kW};
    dnnl::memory::dims zDims = {bS, oC, oH, oW};

    auto type = dnnl::memory::data_type::f32;

    // memory descriptors for arrays

    // input
    dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
    if(input->ews() != 1 || input->ordering() != 'c') {
        x_user_md.data.format_kind = dnnl_blocked;    // overrides format
        x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
        x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
        x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
        x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
    }

    // weights
    dnnl::memory::desc w_mkl_md  = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormat);
    w_user_md.data.format_kind = dnnl_blocked;    // overrides format
    w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3);  // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW]
    w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(2);
    w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0);
    w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1);

    // gradO
    dnnl::memory::desc gradO_mkl_md  = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFrmat);
    if(gradO->ews() != 1 || gradO->ordering() != 'c') {
        gradO_user_md.data.format_kind = dnnl_blocked;    // overrides format
        gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0);
        gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(1);
        gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(2);
        gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(3);
    }

    // gradI
    dnnl::memory::desc gradI_mkl_md  = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
    if(gradI->ews() != 1 || gradI->ordering() != 'c') {
        gradI_user_md.data.format_kind = dnnl_blocked;    // overrides format
        gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0);
        gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(1);
        gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(2);
        gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(3);
    }

    // gradW
    dnnl::memory::desc gradW_mkl_md  = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, type, wFormat);
    gradW_user_md.data.format_kind = dnnl_blocked;    // overrides format
    gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(3);  // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW]
    gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(2);
    gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(0);
    gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(1);

    // gradB
    dnnl::memory::desc gradB_mkl_md;
    if(gradB != nullptr)
        gradB_mkl_md = dnnl::memory::desc({oC}, type, dnnl::memory::format_tag::x);

    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());

    // forward primitive description
    dnnl::convolution_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
    dnnl::convolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);

    // backward data primitive description
    dnnl::convolution_backward_data::desc op_data_bp_desc(dnnl::algorithm::convolution_auto, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
|
||||||
|
dnnl::convolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc);
|
||||||
|
|
||||||
|
// backward weights primitive description
|
||||||
|
dnnl::convolution_backward_weights::desc op_weights_bp_desc(dnnl::algorithm::convolution_auto, x_mkl_md, gradW_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
|
||||||
|
dnnl::convolution_backward_weights::primitive_desc op_weights_bp_prim_desc(op_weights_bp_desc, engine, op_ff_prim_desc);
|
||||||
|
|
||||||
|
// arguments (memory buffers) necessary for calculations
|
||||||
|
std::unordered_map<int, dnnl::memory> args;
|
||||||
|
|
||||||
|
dnnl::stream stream(engine);
|
||||||
|
|
||||||
|
// provide memory buffers and check whether reorder is required
|
||||||
|
|
||||||
|
// input
|
||||||
|
auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer());
|
||||||
|
const bool xReorder = op_weights_bp_prim_desc.src_desc() != x_user_mem.get_desc();
|
||||||
|
auto x_mkl_mem = xReorder ? dnnl::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem;
|
||||||
|
if (xReorder)
|
||||||
|
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
|
||||||
|
args[DNNL_ARG_SRC] = x_mkl_mem;
|
||||||
|
|
||||||
|
// weights
|
||||||
|
auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer());
|
||||||
|
const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc();
|
||||||
|
auto w_mkl_mem = wReorder ? dnnl::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem;
|
||||||
|
if (wReorder)
|
||||||
|
dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
|
||||||
|
args[DNNL_ARG_WEIGHTS] = w_mkl_mem;
|
||||||
|
|
||||||
|
// gradO
|
||||||
|
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
|
||||||
|
const bool gradOReorderW = op_weights_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
|
||||||
|
const bool gradOReorderD = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
|
||||||
|
auto gradO_mkl_memW = gradOReorderW ? dnnl::memory(op_weights_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
|
||||||
|
auto gradO_mkl_memD = gradOReorderD ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
|
||||||
|
if (gradOReorderW)
|
||||||
|
dnnl::reorder(gradO_user_mem, gradO_mkl_memW).execute(stream, gradO_user_mem, gradO_mkl_memW);
|
||||||
|
if (gradOReorderD)
|
||||||
|
dnnl::reorder(gradO_user_mem, gradO_mkl_memD).execute(stream, gradO_user_mem, gradO_mkl_memD);
|
||||||
|
args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD;
|
||||||
|
|
||||||
|
// gradI
|
||||||
|
auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer());
|
||||||
|
const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
|
||||||
|
auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
|
||||||
|
args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem;
|
||||||
|
|
||||||
|
// gradW
|
||||||
|
auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->getBuffer());
|
||||||
|
const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc();
|
||||||
|
auto gradW_mkl_mem = gradWReorder ? dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem;
|
||||||
|
args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem;
|
||||||
|
|
||||||
|
// gradB
|
||||||
|
if(gradB != nullptr) {
|
||||||
|
auto gradB_mkl_mem = dnnl::memory(gradB_mkl_md, engine, gradB->getBuffer());
|
||||||
|
args[DNNL_ARG_DIFF_BIAS] = gradB_mkl_mem;
|
||||||
|
}
|
||||||
|
|
||||||
|
// run backward data calculations
|
||||||
|
dnnl::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args);
|
||||||
|
|
||||||
|
if(gradOReorderW || gradOReorderD)
|
||||||
|
args[DNNL_ARG_DIFF_DST] = gradO_mkl_memW;
|
||||||
|
|
||||||
|
// run backward weights calculations
|
||||||
|
dnnl::convolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);
|
||||||
|
|
||||||
|
// reorder gradI if necessary
|
||||||
|
if (gradIReorder)
|
||||||
|
dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
|
||||||
|
if (gradWReorder)
|
||||||
|
dnnl::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem);
|
||||||
|
|
||||||
|
stream.wait();
|
||||||
|
|
||||||
|
// shape::printArray(z_mkl_mem.map_data<float>(),8);
|
||||||
|
}
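A condensed sketch of the gradO handling just above (hypothetical helper, not part of this commit): the backward-data and backward-weights primitive descriptors may each prefer a different opaque layout for the same user buffer, so one buffer can need up to two reorders, and the matching memory is picked per primitive.

#include <dnnl.hpp>

// Reuse the user buffer when layouts match; otherwise reorder into scratch
// memory laid out the way the primitive prefers.
static dnnl::memory pickDiffDst(const dnnl::memory::desc& preferred, dnnl::memory user,
                                const dnnl::engine& eng, dnnl::stream& strm) {
    if (preferred == user.get_desc())
        return user;                      // layouts already match, no copy needed
    dnnl::memory tmp(preferred, eng);     // scratch memory in the primitive's layout
    dnnl::reorder(user, tmp).execute(strm, user, tmp);
    return tmp;
}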

/*
//////////////////////////////////////////////////////////////////////
static void conv2dMKLDNN(nd4j::graph::Context &block, const NDArray *input, const NDArray *weights,
                         const NDArray *bias, NDArray *output, const int kH, const int kW, const int sH,

@@ -46,37 +339,37 @@ static void conv2dMKLDNN(nd4j::graph::Context &block, const NDArray *input, cons
    ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);

    dnnl_memory_desc_t empty;
    dnnl::memory::desc x_mkl_md(empty), w_mkl_md(empty), b_mkl_md(empty), z_mkl_md(empty);
    dnnl::memory::desc x_user_md(empty), w_user_md(empty), b_user_md(empty), z_user_md(empty);

    dnnl::memory::dims strides, padding, padding_r, dilation;

    mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW,
                                           bS, iC, iH, iW, oC, oH, oW, input, nullptr, weights, nullptr,
                                           bias, output,
                                           &x_mkl_md, nullptr, &w_mkl_md, nullptr,
                                           &b_mkl_md, &z_mkl_md,
                                           &x_user_md, nullptr, &w_user_md, nullptr,
                                           &b_user_md, &z_user_md,
                                           strides, padding, padding_r, dilation);

    auto conv_desc = bias != nullptr ? convolution_forward::desc(prop_kind::forward,
                                                                 algorithm::convolution_auto, x_mkl_md,
                                                                 w_mkl_md, b_mkl_md,
                                                                 z_mkl_md, strides, dilation, padding,
                                                                 padding_r)
                                     : convolution_forward::desc(prop_kind::forward,
                                                                 algorithm::convolution_auto, x_mkl_md,
                                                                 w_mkl_md,
                                                                 z_mkl_md, strides, dilation, padding,
                                                                 padding_r);
    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
    dnnl::stream stream(engine);
    auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc, engine);
    auto user_src_memory = dnnl::memory(x_user_md, engine, const_cast<NDArray *>(input)->buffer());
    auto user_weights_memory = dnnl::memory(w_user_md, engine,
                                            const_cast<NDArray *>(weights)->buffer());
    auto user_dst_memory = dnnl::memory(z_user_md, engine, output->buffer());
    auto conv_src_memory = user_src_memory;
    if (conv_prim_desc.src_desc() != user_src_memory.get_desc()) {
        conv_src_memory = dnnl::memory(conv_prim_desc.src_desc(), engine);

@@ -239,8 +532,11 @@ static void conv2dBpMKLDNN(nd4j::graph::Context &block,
    }
}

*/

//////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(conv2d, ENGINE_CPU) {

    auto input = INPUT_VARIABLE(0);                                    // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
    auto weights = INPUT_VARIABLE(1);                                  // [kH, kW, iC, oC] always
    auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;       // [oC]

@@ -259,16 +555,24 @@ PLATFORM_IMPL(conv2d, ENGINE_CPU) {
    int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));    // filter(kernel) height
    int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));    // filter(kernel) width

    int bS, iC, iH, iW, oC, oH, oW;                          // batch size, input channels, input height/width, output channels, output height/width;
    int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;    // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

    ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);

    std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC};
    REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CONV2D MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
    if (bias)
        REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CONV2D MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());

    conv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW);

    return Status::OK();
}
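The weight descriptors above never copy data: nd4j stores conv2d weights as [kH, kW, iC, oC] while oneDNN expects [oC, iC, kH, kW], so the user memory descriptor is switched to a blocked layout whose per-dimension strides are simply re-ordered. A hypothetical standalone version of that mapping (mirroring the inline assignments above):

// Zero-copy permutation: point each oneDNN dimension at the stride of the
// corresponding nd4j dimension instead of physically transposing the weights.
static void mapConv2dWeightStrides(dnnl::memory::desc& md, const Nd4jLong* s /* strides of [kH,kW,iC,oC] */) {
    md.data.format_kind = dnnl_blocked;                 // overrides the plain format tag
    md.data.format_desc.blocking.strides[0] = s[3];     // oC
    md.data.format_desc.blocking.strides[1] = s[2];     // iC
    md.data.format_desc.blocking.strides[2] = s[0];     // kH
    md.data.format_desc.blocking.strides[3] = s[1];     // kW
}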

PLATFORM_CHECK(conv2d, ENGINE_CPU) {
    auto input = INPUT_VARIABLE(0);
    auto weights = INPUT_VARIABLE(1);

@@ -300,16 +604,30 @@ PLATFORM_IMPL(conv2d_bp, ENGINE_CPU) {
    int paddingMode = INT_ARG(8);                                         // 0-VALID, 1-SAME
    int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;     // INT_ARG(9): 0-NCHW, 1-NHWC

    int bS, iC, iH, iW, oC, oH, oW;                          // batch size, input channels, input height/width, output channels, output height/width;
    int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;    // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

    int trueoH, trueoW;    // true output height, width
    ConvolutionUtils::calcOutSizePool2D(trueoH, trueoW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, paddingMode);

    if(paddingMode)    // SAME
        ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);

    std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW, 0,indIOioC,indOoH,indOoH+1});
    std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC};
    REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CONV2D_BP MKLDNN OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
    REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CONV2D_BP MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
    if(bias)
        REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CONV2D_BP MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());

    conv2dBpMKLDNN(input, weights, bias, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW);

    return Status::OK();
}

PLATFORM_CHECK(conv2d_bp, ENGINE_CPU) {

    auto input = INPUT_VARIABLE(0);                                    // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
    auto weights = INPUT_VARIABLE(1);                                  // [kH, kW, iC, oC] always
    auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;       // [oC]
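The guard `ews() != 1 || ordering() != 'c'` that recurs around every blocked-stride override in these kernels has one job: only a c-ordered array with element-wise stride 1 matches the dense row-major layout that a plain format tag implies. A hypothetical predicate making that decision explicit (nd4j-style names assumed):

// Any strided view, f-ordered array, or otherwise non-contiguous buffer must
// describe its memory with explicit strides rather than a plain format tag.
static bool needsExplicitStrides(const nd4j::NDArray& arr) {
    return arr.ews() != 1 || arr.ordering() != 'c';
}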
@@ -33,6 +33,314 @@ namespace nd4j {
namespace ops {
namespace platforms {

//////////////////////////////////////////////////////////////////////
static void conv3dMKLDNN(const NDArray *input, const NDArray *weights,
                         const NDArray *bias, NDArray *output,
                         const int kD, const int kH, const int kW,
                         const int sD, const int sH, const int sW,
                         const int pD, const int pH, const int pW,
                         const int dD, const int dH, const int dW,
                         const int paddingMode, const int isNCDHW) {

    // weights [kD, kH, kW, iC, oC], we'll perform permutation since mkl supports [oC, iC, kD, kH, kW]

    int bS, iC, iD, iH, iW, oC, oD, oH, oW;            // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
    int indIOioC, indIOioD, indWoC, indWiC, indWkD;    // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

    // const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW;    // dH == 1 for causal mode in conv1d

    dnnl::memory::dims strides   = {sD, sH, sW};
    dnnl::memory::dims padding   = {pD, pH, pW};
    // dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame };
    dnnl::memory::dims padding_r = {(oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW};
    dnnl::memory::dims dilation  = {dD-1, dH-1, dW-1};

    auto xzFrmat = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
    dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oidhw;

    dnnl::memory::dims xDims = {bS, iC, iD, iH, iW};
    dnnl::memory::dims wDims = {oC, iC, kD, kH, kW};
    dnnl::memory::dims zDims = {bS, oC, oD, oH, oW};

    auto type = dnnl::memory::data_type::f32;

    // memory descriptors for arrays

    // input
    dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
    if(input->ews() != 1 || input->ordering() != 'c') {
        x_user_md.data.format_kind = dnnl_blocked;    // overrides format
        x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
        x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
        x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
        x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
        x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(4);
    }

    // weights
    dnnl::memory::desc w_mkl_md  = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormat);
    w_user_md.data.format_kind = dnnl_blocked;    // overrides format
    w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(4);   // permute [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW]
    w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(3);
    w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0);
    w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1);
    w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(2);

    // bias
    dnnl::memory::desc b_mkl_md;
    if(bias != nullptr)
        b_mkl_md = dnnl::memory::desc({oC}, type, dnnl::memory::format_tag::x);

    // output
    dnnl::memory::desc z_mkl_md  = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFrmat);
    if(output->ews() != 1 || output->ordering() != 'c') {
        z_user_md.data.format_kind = dnnl_blocked;    // overrides format
        z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0);
        z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(1);
        z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(2);
        z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(3);
        z_user_md.data.format_desc.blocking.strides[4] = output->strideAt(4);
    }

    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());

    // operation primitive description
    dnnl::convolution_forward::desc op_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, x_mkl_md, w_mkl_md, b_mkl_md, z_mkl_md, strides, dilation, padding, padding_r);
    dnnl::convolution_forward::primitive_desc op_prim_desc(op_desc, engine);

    // arguments (memory buffers) necessary for calculations
    std::unordered_map<int, dnnl::memory> args;

    dnnl::stream stream(engine);

    // provide memory buffers and check whether reorder is required

    // input
    auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer());
    const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
    auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
    if (xReorder)
        dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
    args[DNNL_ARG_SRC] = x_mkl_mem;

    // weights
    auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer());
    const bool wReorder = op_prim_desc.weights_desc() != w_user_mem.get_desc();
    auto w_mkl_mem = wReorder ? dnnl::memory(op_prim_desc.weights_desc(), engine) : w_user_mem;
    if (wReorder)
        dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
    args[DNNL_ARG_WEIGHTS] = w_mkl_mem;

    // bias
    if(bias != nullptr) {
        auto b_mkl_mem = dnnl::memory(b_mkl_md, engine, bias->getBuffer());
        args[DNNL_ARG_BIAS] = b_mkl_mem;
    }

    // output
    auto z_user_mem = dnnl::memory(z_user_md, engine, output->getBuffer());
    const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
    auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
    args[DNNL_ARG_DST] = z_mkl_mem;

    // run calculations
    dnnl::convolution_forward(op_prim_desc).execute(stream, args);

    // reorder outputs if necessary
    if (zReorder)
        dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);

    stream.wait();
}
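A worked example (assumed numbers, just for illustration) of the padding_r expression above: with iD = 5, kD = 3, sD = 1, pD = 1 and SAME-style output oD = 5, the right padding is (oD-1)*sD - iD + kD - pD = 4 - 5 + 3 - 1 = 1, i.e. symmetric with pD here; it only diverges from pD when stride and size leave a remainder.

// Compile-time check of the arithmetic in the example above.
static_assert((5 - 1) * 1 - 5 + 3 - 1 == 1, "padding_r example for conv3dMKLDNN");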

//////////////////////////////////////////////////////////////////////
static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const NDArray *bias, const NDArray *gradO,
                           NDArray *gradI, NDArray *gradW, NDArray *gradB,
                           const int kD, const int kH, const int kW,
                           const int sD, const int sH, const int sW,
                           const int pD, const int pH, const int pW,
                           const int dD, const int dH, const int dW,
                           const int paddingMode, const int isNCDHW) {

    // weights/gradW [kD, kH, kW, iC, oC], we'll perform permutation since mkl supports [oC, iC, kD, kH, kW]

    int bS, iC, iD, iH, iW, oC, oD, oH, oW;            // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
    int indIOioC, indIOioD, indWoC, indWiC, indWkD;    // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

    // const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW;    // dH == 1 for causal mode in conv1d

    dnnl::memory::dims strides   = {sD, sH, sW};
    dnnl::memory::dims padding   = {pD, pH, pW};
    // dnnl::memory::dims padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame };
    dnnl::memory::dims padding_r = {(oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW};
    dnnl::memory::dims dilation  = {dD-1, dH-1, dW-1};

    auto xzFrmat = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
    dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oidhw;

    dnnl::memory::dims xDims = {bS, iC, iD, iH, iW};
    dnnl::memory::dims wDims = {oC, iC, kD, kH, kW};
    dnnl::memory::dims zDims = {bS, oC, oD, oH, oW};

    auto type = dnnl::memory::data_type::f32;

    // memory descriptors for arrays

    // input
    dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
    if(input->ews() != 1 || input->ordering() != 'c') {
        x_user_md.data.format_kind = dnnl_blocked;    // overrides format
        x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
        x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
        x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
        x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
        x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(4);
    }

    // weights
    dnnl::memory::desc w_mkl_md  = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormat);
    w_user_md.data.format_kind = dnnl_blocked;    // overrides format
    w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(4);   // permute [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW]
    w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(3);
    w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0);
    w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1);
    w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(2);

    // gradO
    dnnl::memory::desc gradO_mkl_md  = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFrmat);
    if(gradO->ews() != 1 || gradO->ordering() != 'c') {
        gradO_user_md.data.format_kind = dnnl_blocked;    // overrides format
        gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0);
        gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(1);
        gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(2);
        gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(3);
        gradO_user_md.data.format_desc.blocking.strides[4] = gradO->strideAt(4);
    }

    // gradI
    dnnl::memory::desc gradI_mkl_md  = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
    if(gradI->ews() != 1 || gradI->ordering() != 'c') {
        gradI_user_md.data.format_kind = dnnl_blocked;    // overrides format
        gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0);
        gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(1);
        gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(2);
        gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(3);
        gradI_user_md.data.format_desc.blocking.strides[4] = gradI->strideAt(4);
    }

    // gradW
    dnnl::memory::desc gradW_mkl_md  = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, type, wFormat);
    gradW_user_md.data.format_kind = dnnl_blocked;    // overrides format
    gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(4);   // permute [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW]
    gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(3);
    gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(0);
    gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(1);
    gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(2);

    // gradB
    dnnl::memory::desc gradB_mkl_md;
    if(gradB != nullptr)
        gradB_mkl_md = dnnl::memory::desc({oC}, type, dnnl::memory::format_tag::x);

    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());

    // forward primitive description
    dnnl::convolution_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::convolution_auto, x_mkl_md, w_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
    dnnl::convolution_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);

    // backward data primitive description
    dnnl::convolution_backward_data::desc op_data_bp_desc(dnnl::algorithm::convolution_auto, gradI_mkl_md, w_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
    dnnl::convolution_backward_data::primitive_desc op_data_bp_prim_desc(op_data_bp_desc, engine, op_ff_prim_desc);

    // backward weights primitive description
    dnnl::convolution_backward_weights::desc op_weights_bp_desc(dnnl::algorithm::convolution_auto, x_mkl_md, gradW_mkl_md, gradB_mkl_md, gradO_mkl_md, strides, dilation, padding, padding_r);
    dnnl::convolution_backward_weights::primitive_desc op_weights_bp_prim_desc(op_weights_bp_desc, engine, op_ff_prim_desc);

    // arguments (memory buffers) necessary for calculations
    std::unordered_map<int, dnnl::memory> args;

    dnnl::stream stream(engine);

    // provide memory buffers and check whether reorder is required

    // input
    auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer());
    const bool xReorder = op_weights_bp_prim_desc.src_desc() != x_user_mem.get_desc();
    auto x_mkl_mem = xReorder ? dnnl::memory(op_weights_bp_prim_desc.src_desc(), engine) : x_user_mem;
    if (xReorder)
        dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
    args[DNNL_ARG_SRC] = x_mkl_mem;

    // weights
    auto w_user_mem = dnnl::memory(w_user_md, engine, weights->getBuffer());
    const bool wReorder = op_data_bp_prim_desc.weights_desc() != w_user_mem.get_desc();
    auto w_mkl_mem = wReorder ? dnnl::memory(op_data_bp_prim_desc.weights_desc(), engine) : w_user_mem;
    if (wReorder)
        dnnl::reorder(w_user_mem, w_mkl_mem).execute(stream, w_user_mem, w_mkl_mem);
    args[DNNL_ARG_WEIGHTS] = w_mkl_mem;

    // gradO
    auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
    const bool gradOReorderW = op_weights_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
    const bool gradOReorderD = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
    auto gradO_mkl_memW = gradOReorderW ? dnnl::memory(op_weights_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
    auto gradO_mkl_memD = gradOReorderD ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
    if (gradOReorderW)
        dnnl::reorder(gradO_user_mem, gradO_mkl_memW).execute(stream, gradO_user_mem, gradO_mkl_memW);
    if (gradOReorderD)
        dnnl::reorder(gradO_user_mem, gradO_mkl_memD).execute(stream, gradO_user_mem, gradO_mkl_memD);
    args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD;

    // gradI
    auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer());
    const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
    auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
    args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem;

    // gradW
    auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->getBuffer());
    const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc();
    auto gradW_mkl_mem = gradWReorder ? dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem;
    args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem;

    // gradB
    if(gradB != nullptr) {
        auto gradB_mkl_mem = dnnl::memory(gradB_mkl_md, engine, gradB->getBuffer());
        args[DNNL_ARG_DIFF_BIAS] = gradB_mkl_mem;
    }

    // run backward data calculations
    dnnl::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args);

    if(gradOReorderW || gradOReorderD)
        args[DNNL_ARG_DIFF_DST] = gradO_mkl_memW;

    // run backward weights calculations
    dnnl::convolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);

    // reorder gradI if necessary
    if (gradIReorder)
        dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
    if (gradWReorder)
        dnnl::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem);

    stream.wait();

    // shape::printArray(z_mkl_mem.map_data<float>(),8);
}
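A condensed sketch (hypothetical helper, not part of this commit) of the two-phase backward execution used by both conv2dBpMKLDNN and conv3dBpMKLDNN: a single args map feeds both primitives, because each primitive only reads the argument keys it declares, and only DNNL_ARG_DIFF_DST is re-pointed between the two runs, since each primitive may prefer its own gradO layout.

#include <dnnl.hpp>
#include <unordered_map>

static void runConvBackward(dnnl::stream& strm,
                            const dnnl::convolution_backward_data::primitive_desc& dataPd,
                            const dnnl::convolution_backward_weights::primitive_desc& weightsPd,
                            std::unordered_map<int, dnnl::memory> args,
                            dnnl::memory gradO_dataLayout, dnnl::memory gradO_weightsLayout) {
    args[DNNL_ARG_DIFF_DST] = gradO_dataLayout;
    dnnl::convolution_backward_data(dataPd).execute(strm, args);          // writes DIFF_SRC
    args[DNNL_ARG_DIFF_DST] = gradO_weightsLayout;                        // may be the very same memory
    dnnl::convolution_backward_weights(weightsPd).execute(strm, args);    // writes DIFF_WEIGHTS / DIFF_BIAS
}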

/*
//////////////////////////////////////////////////////////////////////
static void conv3dMKLDNN(nd4j::graph::Context &block,
                         const NDArray *input, const NDArray *weights, const NDArray *bias,

@@ -225,6 +533,7 @@ static void conv3dBpMKLDNN(nd4j::graph::Context &block,
        reorder(convI_src_memory, userI_src_memory).execute(stream, convI_src_memory, userI_src_memory);
    }
}
*/

//////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(conv3dnew, ENGINE_CPU) {

@@ -256,24 +565,20 @@ PLATFORM_IMPL(conv3dnew, ENGINE_CPU) {
    int indIOioC, indIOioD, indWoC, indWiC, indWkD;    // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

    std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, iC, oC};
    REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
    if (bias)
        REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D MKLDNN OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());

    if (paddingMode)    // SAME
        ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);

    conv3dMKLDNN(input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW);

    return Status::OK();
}

PLATFORM_CHECK(conv3dnew, ENGINE_CPU) {
    auto input = INPUT_VARIABLE(0);                                    // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
    auto weights = INPUT_VARIABLE(1);                                  // [kD, kH, kW, iC, oC] always
    auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;       // [oC]

@@ -284,6 +589,7 @@ PLATFORM_CHECK(conv3dnew, ENGINE_CPU) {

//////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) {

    auto input = INPUT_VARIABLE(0);                                    // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
    auto weights = INPUT_VARIABLE(1);                                  // [kD, kH, kW, iC, oC] always
    auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;       // [oC]

@@ -322,20 +628,19 @@ PLATFORM_IMPL(conv3dnew_bp, ENGINE_CPU) {
    int trueoD, trueoH, trueoW;    // true output depth/height/width
    ConvolutionUtils::calcOutSizePool3D(trueoD, trueoH, trueoW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, paddingMode);

    std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx( {bS, oC, trueoD, trueoH, trueoW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2});
    std::vector<Nd4jLong> expectedWeightsShape = {kD, kH, kW, iC, oC};
    REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of output gradients (next epsilon) array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
    REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM CONV3D_BP OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
    if (bias)
        REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM CONV3D_BP OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());

    conv3dBpMKLDNN(input, weights, bias, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, paddingMode, isNCDHW);

    return Status::OK();
}

PLATFORM_CHECK(conv3dnew_bp, ENGINE_CPU) {

    auto input = INPUT_VARIABLE(0);                                    // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
    auto weights = INPUT_VARIABLE(1);                                  // [kD, kH, kW, iC, oC] always
    auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;       // [oC]
@@ -34,17 +34,13 @@ namespace platforms {
//////////////////////////////////////////////////////////////////////////
static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output,
                           const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
                           const int paddingMode, const bool isNCHW) {

    // weights [oC, iC, kH, kW] always, mkl doesn't support [kH, kW, oC, iC], so we'll perform permutation

    int bS, iC, iH, iW, oC, oH, oW;                          // batch size, input channels, input height/width, output channels, output height/width;
    int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;    // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);

    dnnl::memory::dims strides = { sH, sW };
    dnnl::memory::dims padding = { pH, pW };

@@ -80,8 +76,7 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
    else
        zType = dnnl::memory::data_type::s32;

    dnnl::memory::format_tag xFormat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
    dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw;

    dnnl::memory::dims xDims = {bS, iC, iH, iW};

@@ -93,20 +88,22 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
    // input
    dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
    dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat);
    if(input->ews() != 1 || input->ordering() != 'c') {
        x_user_md.data.format_kind = dnnl_blocked;    // overrides format
        x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
        x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
        x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
        x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
    }

    // weights
    dnnl::memory::desc w_mkl_md  = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
    dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat);
    w_user_md.data.format_kind = dnnl_blocked;    // overrides format
    w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(2);   // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
    w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(3);
    w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0);
    w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1);

    // bias
    dnnl::memory::desc b_mkl_md;

@@ -116,11 +113,13 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
    // output
    dnnl::memory::desc z_mkl_md  = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any);
    dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormat);
    if(output->ews() != 1 || output->ordering() != 'c') {
        z_user_md.data.format_kind = dnnl_blocked;    // overrides format
        z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0);
        z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(1);
        z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(2);
        z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(3);
    }

    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());

@@ -179,21 +178,19 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
//////////////////////////////////////////////////////////////////////////
static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB,
                             const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
                             const int paddingMode, const bool isNCHW) {

    // weights and gradW [oC, iC, kH, kW] always, mkl doesn't support [kH, kW, oC, iC], so we'll perform permutation

    int bS, iC, iH, iW, oC, oH, oW;                          // batch size, input channels, input height/width, output channels, output height/width;
    int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;    // corresponding indexes
    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH);

    dnnl::memory::dims strides   = { sH, sW };
    dnnl::memory::dims padding   = { pH, pW };
    dnnl::memory::dims padding_r = { (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
    dnnl::memory::dims dilation  = { dH-1, dW-1 };
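    // note: for deconvolution the input/output roles swap relative to the conv
    // expressions earlier in this file - the forward relation (for dH == 1) is
    // oH = (iH - 1)*sH + kH - pH - padding_r, and solving for padding_r gives
    // exactly the (iH - 1)*sH - oH + kH - pH written above.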
|
||||||
|
|
||||||
// input type
|
// input type
|
||||||
dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
|
dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
|
||||||
// weights type
|
// weights type
|
||||||
|
@ -207,7 +204,7 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const
     // gradB type
     dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32;

-    dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::nchw;   // isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
+    dnnl::memory::format_tag xFormat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
     dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw;

     dnnl::memory::dims xDims = {bS, iC, iH, iW};
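
Across the patch the hardcoded nchw/ncdhw tags become conditional on the layout flag, and explicit strides are only installed when the array is not dense C-order. For reference, oneDNN also exposes a stride-based memory::desc constructor that expresses the same "user layout" without poking format_desc fields directly; a hedged sketch, assuming oneDNN v1.x headers (this alternative is not what the patch uses):

#include <dnnl.hpp>

// same effect as desc(dims, type, tag) followed by the dnnl_blocked/strides
// override seen throughout these hunks: oneDNN accepts the strides directly
static dnnl::memory::desc makeUserDesc(const dnnl::memory::dims& dims,
                                       dnnl::memory::data_type type,
                                       const dnnl::memory::dims& strides) {
    return dnnl::memory::desc(dims, type, strides);
}
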
@ -219,54 +216,59 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const
     // input
     dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
     dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat);
+    if(input->ews() != 1 || input->ordering() != 'c') {
         x_user_md.data.format_kind = dnnl_blocked;    // overrides format
-        x_user_md.data.format_desc.blocking.strides[0] = input->stridesOf()[0];
-        x_user_md.data.format_desc.blocking.strides[1] = input->stridesOf()[1];
-        x_user_md.data.format_desc.blocking.strides[2] = input->stridesOf()[2];
-        x_user_md.data.format_desc.blocking.strides[3] = input->stridesOf()[3];
+        x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
+        x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
+        x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
+        x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
+    }

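The new ews()/ordering() guard means the stride override only fires for views: in libnd4j, ews() == 1 together with 'c' ordering marks a dense row-major buffer, which the plain format tag already describes. A standalone sketch of the equivalent dense-layout test (illustrative only, not the library's implementation; the ews() semantics are an assumption stated here, not shown in the diff):

#include <cassert>
#include <vector>

// true when explicit strides must be handed to oneDNN, i.e. the buffer is
// not dense row-major ('c' ordering with ews() == 1 in libnd4j terms)
static bool needsExplicitStrides(const std::vector<long long>& shape,
                                 const std::vector<long long>& strides) {
    long long expected = 1;
    for (int d = static_cast<int>(shape.size()) - 1; d >= 0; --d) {
        if (strides[d] != expected)
            return true;           // view/permutation: pass the real strides
        expected *= shape[d];
    }
    return false;                  // dense C-order: the format tag is enough
}

int main() {
    assert(!needsExplicitStrides({2, 3, 4, 5}, {60, 20, 5, 1}));   // dense buffer
    assert(needsExplicitStrides({2, 3, 4, 5}, {60, 1, 15, 3}));    // permuted view
    return 0;
}
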
     // weights
     dnnl::memory::desc w_mkl_md  = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
     dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat);
     w_user_md.data.format_kind = dnnl_blocked;    // overrides format
-    w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0];
-    w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1];
-    w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2];
-    w_user_md.data.format_desc.blocking.strides[3] = weights->stridesOf()[3];
+    w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(2);    // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
+    w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(3);
+    w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0);
+    w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1);

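The reassigned strides above express the [kH, kW, oC, iC] -> [oC, iC, kH, kW] permutation as a zero-copy view: dimension d of the oihw descriptor reads the source buffer through the stride of the axis it came from, which is what lets the patch drop the weights->permute(...) copy later on. A standalone demonstration of the indexing rule with made-up sizes (not part of the commit):

#include <cassert>

int main() {
    const long long kH = 2, kW = 3, oC = 4, iC = 5;
    // dense C-order strides of a [kH, kW, oC, iC] buffer
    const long long src[4] = {kW * oC * iC, oC * iC, iC, 1};
    const int perm[4]      = {2, 3, 0, 1};     // [kH,kW,oC,iC] -> [oC,iC,kH,kW]
    long long view[4];
    for (int d = 0; d < 4; ++d)
        view[d] = src[perm[d]];                // same rule as strideAt(2),(3),(0),(1)
    // element (o,i,h,w) of the view is element (h,w,o,i) of the buffer
    const long long o = 1, i = 2, h = 0, w = 1;
    assert(o * view[0] + i * view[1] + h * view[2] + w * view[3]
        == h * src[0]  + w * src[1]  + o * src[2]  + i * src[3]);
    return 0;
}
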
     // gradO
     dnnl::memory::desc gradO_mkl_md  = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
     dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat);
+    if(gradO->ews() != 1 || gradO->ordering() != 'c') {
         gradO_user_md.data.format_kind = dnnl_blocked;    // overrides format
-        gradO_user_md.data.format_desc.blocking.strides[0] = gradO->stridesOf()[0];
-        gradO_user_md.data.format_desc.blocking.strides[1] = gradO->stridesOf()[1];
-        gradO_user_md.data.format_desc.blocking.strides[2] = gradO->stridesOf()[2];
-        gradO_user_md.data.format_desc.blocking.strides[3] = gradO->stridesOf()[3];
+        gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0);
+        gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(1);
+        gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(2);
+        gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(3);
+    }

     // gradI
     dnnl::memory::desc gradI_mkl_md  = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
     dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat);
+    if(gradI->ews() != 1 || gradI->ordering() != 'c') {
         gradI_user_md.data.format_kind = dnnl_blocked;    // overrides format
-        gradI_user_md.data.format_desc.blocking.strides[0] = gradI->stridesOf()[0];
-        gradI_user_md.data.format_desc.blocking.strides[1] = gradI->stridesOf()[1];
-        gradI_user_md.data.format_desc.blocking.strides[2] = gradI->stridesOf()[2];
-        gradI_user_md.data.format_desc.blocking.strides[3] = gradI->stridesOf()[3];
+        gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0);
+        gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(1);
+        gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(2);
+        gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(3);
+    }

     // gradW
     dnnl::memory::desc gradW_mkl_md  = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any);
     dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormat);
     gradW_user_md.data.format_kind = dnnl_blocked;    // overrides format
-    gradW_user_md.data.format_desc.blocking.strides[0] = gradW->stridesOf()[0];
-    gradW_user_md.data.format_desc.blocking.strides[1] = gradW->stridesOf()[1];
-    gradW_user_md.data.format_desc.blocking.strides[2] = gradW->stridesOf()[2];
-    gradW_user_md.data.format_desc.blocking.strides[3] = gradW->stridesOf()[3];
+    gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(2);    // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
+    gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(3);
+    gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(0);
+    gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(1);

     // gradB
     dnnl::memory::desc gradB_mkl_md;
     if(gradB != nullptr)
         gradB_mkl_md = dnnl::memory::desc({oC}, gradBType, dnnl::memory::format_tag::x);

     auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());

     // forward primitive description
@ -306,11 +308,15 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const
     // gradO
     auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
-    const bool gradOReorder = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
-    auto gradO_mkl_mem = gradOReorder ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
-    if (gradOReorder)
-        dnnl::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem);
-    args[DNNL_ARG_DIFF_DST] = gradO_mkl_mem;
+    const bool gradOReorderW = op_weights_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
+    const bool gradOReorderD = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
+    auto gradO_mkl_memW = gradOReorderW ? dnnl::memory(op_weights_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
+    auto gradO_mkl_memD = gradOReorderD ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
+    if (gradOReorderW)
+        dnnl::reorder(gradO_user_mem, gradO_mkl_memW).execute(stream, gradO_user_mem, gradO_mkl_memW);
+    if (gradOReorderD)
+        dnnl::reorder(gradO_user_mem, gradO_mkl_memD).execute(stream, gradO_user_mem, gradO_mkl_memD);
+    args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD;

     // gradI
     auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer());

 @ -333,6 +339,9 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const

     // run backward data calculations
     dnnl::deconvolution_backward_data(op_data_bp_prim_desc).execute(stream, args);

+    if(gradOReorderW || gradOReorderD)
+        args[DNNL_ARG_DIFF_DST] = gradO_mkl_memW;

     // run backward weights calculations
     dnnl::deconvolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);

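
The reorder logic is split here because the backward-data and backward-weights primitive descriptors are built independently and may each prefer a different blocked layout for DNNL_ARG_DIFF_DST; the old single gradO_mkl_mem silently fed the data-bp layout to the weights-bp primitive. A hedged sketch of the general pattern the patch adopts, with illustrative names (matchLayout is not a function from the diff):

#include <dnnl.hpp>

// returns a memory whose layout matches `wanted`, reordering out of the
// user-supplied memory only when the two descriptors actually differ
static dnnl::memory matchLayout(const dnnl::memory::desc& wanted,
                                dnnl::memory& user,
                                const dnnl::engine& engine,
                                dnnl::stream& stream) {
    if (wanted == user.get_desc())
        return user;                       // layouts agree: share the buffer
    dnnl::memory tmp(wanted, engine);      // otherwise allocate and convert
    dnnl::reorder(user, tmp).execute(stream, user, tmp);
    return tmp;
}

Calling this once per primitive descriptor yields exactly the gradO_mkl_memW / gradO_mkl_memD pair above, and swapping DNNL_ARG_DIFF_DST between the two executes keeps each primitive reading its own preferred layout.
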
@ -385,32 +394,12 @@ PLATFORM_IMPL(deconv2d, ENGINE_CPU) {
         ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW);
     }

-    // mkl supports only [oC, iC, kH, kW] format for weights
-    weights = new NDArray(weights->permute({2,3,0,1}));                 // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
-
-    // mkl supports only NCHW
-    if(!isNCHW) {
-        input  = new NDArray(input->permute({0,3,1,2}));                // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
-        output = new NDArray(output->permute({0,3,1,2}));               // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
-    }
-
-    deconv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode);
-
-    delete weights;
-
-    if(!isNCHW) {
-        delete input;
-        delete output;
-    }
+    deconv2dMKLDNN(input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW);

     return Status::OK();
 }

 PLATFORM_CHECK(deconv2d, ENGINE_CPU) {
-    // we don't want to use mkldnn if cpu doesn't support avx/avx2
-    // if (::optimalLevel() < 2)
-    //     return false;

     auto input = INPUT_VARIABLE(0);
     auto weights = INPUT_VARIABLE(1);
     auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;
@ -481,27 +470,7 @@ PLATFORM_IMPL(deconv2d_bp, ENGINE_CPU) {
         ConvolutionUtils::calcPadding2D(pH, pW, iH, iW, oH, oW, kH, kW, sH, sW, dH, dW);
     }

-    // mkl supports only [oC, iC, kH, kW] for weights
-    weights = new NDArray(weights->permute({2,3,0,1}));                 // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
-    gradW   = new NDArray(gradW->permute({2,3,0,1}));                   // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
-
-    // mkl supports NCHW format only
-    if(!isNCHW) {
-        input = new NDArray(input->permute({0,3,1,2}));                 // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
-        gradI = new NDArray(gradI->permute({0,3,1,2}));                 // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
-        gradO = new NDArray(gradO->permute({0,3,1,2}));                 // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
-    }
-
-    deconv2dBpMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode);
-
-    delete weights;
-    delete gradW;
-
-    if(!isNCHW) {
-        delete input;
-        delete gradI;
-        delete gradO;
-    }
+    deconv2dBpMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW);

     return Status::OK();
 }

@ -33,7 +33,8 @@ namespace platforms {
 //////////////////////////////////////////////////////////////////////////
 static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* gradO, NDArray* gradI,
                                      const int bS, const int iC, const int iH, const int iW, const int oC, const int oH, const int oW,
-                                     const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW) {
+                                     const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
+                                     const bool isNCHW) {

     // gradI [bS, iH, iW, iC], mkl doesn't support ndhwc format
     // weights [oC, iC, kH, kW] always, mkl doesn't support weights format [kH, kW, iC, oC]

 @ -51,7 +52,7 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad

     // gradI type
     dnnl::memory::data_type gradIType = gradI->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;

-    dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::nchw;   // isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
+    dnnl::memory::format_tag xFormat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
     dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oihw;

     dnnl::memory::dims xDims = {bS, iC, iH, iW};
@ -67,29 +68,32 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad
     dnnl::memory::desc w_mkl_md  = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
     dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat);
     w_user_md.data.format_kind = dnnl_blocked;    // overrides format
-    w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0];
-    w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1];
-    w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2];
-    w_user_md.data.format_desc.blocking.strides[3] = weights->stridesOf()[3];
+    w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3);    // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW]
+    w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(2);
+    w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0);
+    w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1);

     // gradO
     dnnl::memory::desc gradO_mkl_md  = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
     dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat);
+    if(gradO->ews() != 1 || gradO->ordering() != 'c') {
         gradO_user_md.data.format_kind = dnnl_blocked;    // overrides format
-        gradO_user_md.data.format_desc.blocking.strides[0] = gradO->stridesOf()[0];
-        gradO_user_md.data.format_desc.blocking.strides[1] = gradO->stridesOf()[1];
-        gradO_user_md.data.format_desc.blocking.strides[2] = gradO->stridesOf()[2];
-        gradO_user_md.data.format_desc.blocking.strides[3] = gradO->stridesOf()[3];
+        gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0);
+        gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(1);
+        gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(2);
+        gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(3);
+    }

     // gradI
     dnnl::memory::desc gradI_mkl_md  = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
     dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat);
+    if(gradI->ews() != 1 || gradI->ordering() != 'c') {
         gradI_user_md.data.format_kind = dnnl_blocked;    // overrides format
-        gradI_user_md.data.format_desc.blocking.strides[0] = gradI->stridesOf()[0];
-        gradI_user_md.data.format_desc.blocking.strides[1] = gradI->stridesOf()[1];
-        gradI_user_md.data.format_desc.blocking.strides[2] = gradI->stridesOf()[2];
-        gradI_user_md.data.format_desc.blocking.strides[3] = gradI->stridesOf()[3];
+        gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0);
+        gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(1);
+        gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(2);
+        gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(3);
+    }

     auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());

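
Note that deconv2d_tf keeps its weights as [kH, kW, iC, oC] while deconv2d uses [kH, kW, oC, iC], so the two files map onto the same oihw descriptor with different permutations; the underlying rule is the one sketched earlier, view_stride[d] = source_stride[perm[d]]. Illustrative constants only (these arrays do not appear in the commit):

// the two fixed weight layouts in this commit, mapped onto oihw
const int permDeconv2d[4]   = {2, 3, 0, 1};   // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
const int permDeconv2dTF[4] = {3, 2, 0, 1};   // [kH, kW, iC, oC] -> [oC, iC, kH, kW]
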
@ -166,9 +170,9 @@ PLATFORM_IMPL(deconv2d_tf, ENGINE_CPU) {
     const int rank = gradO->rankOf();

-    REQUIRE_TRUE(weights->rankOf() == rank, 0, "CUSTOM DECONV2D_TF OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf());
-    REQUIRE_TRUE(gradIShape->rankOf() == 1, 0, "CUSTOM DECONV2D_TF OP: rank of array with output shape must be equal to 1, but got %i instead !", gradIShape->rankOf());
-    REQUIRE_TRUE(gradIShape->lengthOf() == rank, 0, "CUSTOM DECONV2D_TF OP: length of array with output shape must be equal to 4, but got %i instead !", gradIShape->lengthOf());
+    REQUIRE_TRUE(weights->rankOf() == rank, 0, "CUSTOM DECONV2D_TF MKLDNN OP: rank of weights array must be equal to 4, but got %i instead !", weights->rankOf());
+    REQUIRE_TRUE(gradIShape->rankOf() == 1, 0, "CUSTOM DECONV2D_TF MKLDNN OP: rank of array with output shape must be equal to 1, but got %i instead !", gradIShape->rankOf());
+    REQUIRE_TRUE(gradIShape->lengthOf() == rank, 0, "CUSTOM DECONV2D_TF MKLDNN OP: length of array with output shape must be equal to 4, but got %i instead !", gradIShape->lengthOf());

     int indIOioC, indIiH, indWoC(3), indOoH;
     if(!isNCHW) {
@ -193,29 +197,29 @@ PLATFORM_IMPL(deconv2d_tf, ENGINE_CPU) {
     std::vector<Nd4jLong> expectedGradOShape   = ShapeUtils::composeShapeUsingDimsAndIdx({bS,oC,trueoH,trueoW,  0,indIOioC,indOoH,indOoH+1});
     std::vector<Nd4jLong> expectedWeightsShape = {kH, kW, iC, oC};
-    REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV2D_TF OP: wrong shape of input array, basing on array with output shape expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
-    REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_TF OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());
+    REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "CUSTOM DECONV2D_TF MKLDNN OP: wrong shape of input array, basing on array with output shape expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());
+    REQUIRE_TRUE(weights->isSameShape(expectedWeightsShape), 0, "CUSTOM DECONV2D_TF MKLDNN OP: wrong shape of weights array, expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedWeightsShape).c_str(), ShapeUtils::shapeAsString(weights).c_str());

     if(isSameMode)                       // SAME
         ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);

-    // mkl supports only [oC, iC, kH, kW] for weights
-    weights = new NDArray(weights->permute({3,2,0,1}));                 // [kH, kW, iC, oC] -> [oC, iC, kH, kW]
+    // // mkl supports only [oC, iC, kH, kW] for weights
+    // weights = new NDArray(weights->permute({3,2,0,1}));              // [kH, kW, iC, oC] -> [oC, iC, kH, kW]

-    // mkl supports NCHW format only
-    if(!isNCHW) {
-        gradI = new NDArray(gradI->permute({0,3,1,2}));                 // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
-        gradO = new NDArray(gradO->permute({0,3,1,2}));                 // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
-    }
+    // // mkl supports NCHW format only
+    // if(!isNCHW) {
+    //     gradI = new NDArray(gradI->permute({0,3,1,2}));              // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
+    //     gradO = new NDArray(gradO->permute({0,3,1,2}));              // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
+    // }

-    deconv2TFdBackPropMKLDNN(weights, gradO, gradI, bS, iC, iH, iW, oC, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW);
+    deconv2TFdBackPropMKLDNN(weights, gradO, gradI, bS, iC, iH, iW, oC, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW);

-    delete weights;
+    // delete weights;

-    if(!isNCHW) {
-        delete gradI;
-        delete gradO;
-    }
+    // if(!isNCHW) {
+    //     delete gradI;
+    //     delete gradO;
+    // }

     // ConvolutionUtils::conv2dBP(block, &input, weights, nullptr, gradO, gradI, nullptr, nullptr, kH,kW,sH,sW,pH,pW,dH,dW,isSameMode,isNCHW);

@ -34,17 +34,14 @@ namespace platforms {
 //////////////////////////////////////////////////////////////////////////
 static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* bias, NDArray* output,
                            const int kD, const int kH, const int kW, const int sD, const int sH, const int sW,
-                           const int pD, const int pH, const int pW, const int dD, const int dH, const int dW) {
+                           const int pD, const int pH, const int pW, const int dD, const int dH, const int dW,
+                           const bool isNCDHW) {

-    // input [bS, iD, iH, iW, iC] ncdhw, mkl doesn't support format ndhwc
-    // weights [oC, iC, kD, kH, kW] always, mkl doesn't support weights format [kD, kH, kW, oC, iC]
-    // bias [oC], may be nullptr
-    // output [bS, oD, oH, oW, oC] ncdhw, mkl doesn't support format ndhwc
+    // weights [oC, iC, kD, kH, kW] always, mkl doesn't support [kD, kH, kW, oC, iC], so we'll perform permutation

     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                  // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
     int indIOioC, indIOioD, indWoC, indWiC, indWkD;          // corresponding indexes
-    ConvolutionUtils::getSizesAndIndexesConv3d(true, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);
+    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);

     dnnl::memory::dims strides = { sD, sH, sW };
     dnnl::memory::dims padding = { pD, pH, pW };
@ -80,8 +77,7 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
     else
         zType = dnnl::memory::data_type::s32;

-    dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::ncdhw;
+    dnnl::memory::format_tag xFormat = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
     dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oidhw;

     dnnl::memory::dims xDims = {bS, iC, iD, iH, iW};
@ -93,22 +89,24 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
     // input
     dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
     dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat);
+    if(input->ews() != 1 || input->ordering() != 'c') {
         x_user_md.data.format_kind = dnnl_blocked;    // overrides format
-        x_user_md.data.format_desc.blocking.strides[0] = input->stridesOf()[0];
-        x_user_md.data.format_desc.blocking.strides[1] = input->stridesOf()[1];
-        x_user_md.data.format_desc.blocking.strides[2] = input->stridesOf()[2];
-        x_user_md.data.format_desc.blocking.strides[3] = input->stridesOf()[3];
-        x_user_md.data.format_desc.blocking.strides[4] = input->stridesOf()[4];
+        x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
+        x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
+        x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
+        x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
+        x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(4);
+    }

     // weights
     dnnl::memory::desc w_mkl_md  = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
     dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat);
     w_user_md.data.format_kind = dnnl_blocked;    // overrides format
-    w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0];
-    w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1];
-    w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2];
-    w_user_md.data.format_desc.blocking.strides[3] = weights->stridesOf()[3];
-    w_user_md.data.format_desc.blocking.strides[4] = weights->stridesOf()[4];
+    w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3);    // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
+    w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(4);
+    w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0);
+    w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1);
+    w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(2);

     // bias
     dnnl::memory::desc b_mkl_md;
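
deconv3d applies the same zero-copy trick in five dimensions, with permutation {3,4,0,1,2}. An illustrative generalisation of the strideAt(...) shuffles, not a helper from the patch:

#include <cstddef>
#include <vector>

// dimension d of the permuted view reads the source axis perm[d]
static std::vector<long long> permuteStrides(const std::vector<long long>& strides,
                                             const std::vector<int>& perm) {
    std::vector<long long> out(perm.size());
    for (std::size_t d = 0; d < perm.size(); ++d)
        out[d] = strides[perm[d]];
    return out;
}

// e.g. permuteStrides(wStrides, {3, 4, 0, 1, 2}) views a [kD, kH, kW, oC, iC]
// buffer through the oidhw descriptor, matching strideAt(3),(4),(0),(1),(2) above
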
@ -118,12 +116,14 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
     // output
     dnnl::memory::desc z_mkl_md  = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any);
     dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormat);
+    if(output->ews() != 1 || output->ordering() != 'c') {
         z_user_md.data.format_kind = dnnl_blocked;    // overrides format
-        z_user_md.data.format_desc.blocking.strides[0] = output->stridesOf()[0];
-        z_user_md.data.format_desc.blocking.strides[1] = output->stridesOf()[1];
-        z_user_md.data.format_desc.blocking.strides[2] = output->stridesOf()[2];
-        z_user_md.data.format_desc.blocking.strides[3] = output->stridesOf()[3];
-        z_user_md.data.format_desc.blocking.strides[4] = output->stridesOf()[4];
+        z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0);
+        z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(1);
+        z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(2);
+        z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(3);
+        z_user_md.data.format_desc.blocking.strides[4] = output->strideAt(4);
+    }

     auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());

@ -184,16 +184,14 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
                                    const int kD, const int kH, const int kW,
                                    const int sD, const int sH, const int sW,
                                    const int pD, const int pH, const int pW,
-                                   const int dD, const int dH, const int dW) {
+                                   const int dD, const int dH, const int dW,
+                                   const bool isNCDHW) {

-    // input and gradI [bS, iD, iH, iW, iC], mkl doesn't support ndhwc format
-    // weights and gradW [oC, iC, kD, kH, kW] always, mkl doesn't support weights format [kD, kH, kW, oC, iC]
-    // gradB [oC], may be nullptr
-    // gradO [bS, oD, oH, oW, oC]
+    // weights and gradW [oC, iC, kD, kH, kW] always, mkl doesn't support [kD, kH, kW, oC, iC], so we'll perform permutation

     int bS, iC, iD, iH, iW, oC, oD, oH, oW;                  // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
     int indIOioC, indIOioD, indWoC, indWiC, indWkD;          // corresponding indexes
-    ConvolutionUtils::getSizesAndIndexesConv3d(true, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);
+    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWoC, indWiC, indWkD);

     dnnl::memory::dims strides = { sD, sH, sW };
     dnnl::memory::dims padding = { pD, pH, pW };

 @ -213,7 +211,7 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,

     // gradB type
     dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32;

-    dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::ncdhw;   // isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
+    dnnl::memory::format_tag xFormat = isNCDHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
     dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::oidhw;

     dnnl::memory::dims xDims = {bS, iC, iD, iH, iW};
@ -225,52 +223,58 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
     // input
     dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
     dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat);
+    if(input->ews() != 1 || input->ordering() != 'c') {
         x_user_md.data.format_kind = dnnl_blocked;    // overrides format
-        x_user_md.data.format_desc.blocking.strides[0] = input->stridesOf()[0];
-        x_user_md.data.format_desc.blocking.strides[1] = input->stridesOf()[1];
-        x_user_md.data.format_desc.blocking.strides[2] = input->stridesOf()[2];
-        x_user_md.data.format_desc.blocking.strides[3] = input->stridesOf()[3];
-        x_user_md.data.format_desc.blocking.strides[4] = input->stridesOf()[4];
+        x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
+        x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
+        x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
+        x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
+        x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(4);
+    }

     // weights
     dnnl::memory::desc w_mkl_md  = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
     dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormat);
     w_user_md.data.format_kind = dnnl_blocked;    // overrides format
-    w_user_md.data.format_desc.blocking.strides[0] = weights->stridesOf()[0];
-    w_user_md.data.format_desc.blocking.strides[1] = weights->stridesOf()[1];
-    w_user_md.data.format_desc.blocking.strides[2] = weights->stridesOf()[2];
-    w_user_md.data.format_desc.blocking.strides[3] = weights->stridesOf()[3];
-    w_user_md.data.format_desc.blocking.strides[4] = weights->stridesOf()[4];
+    w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3);    // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
+    w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(4);
+    w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0);
+    w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1);
+    w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(2);

     // gradO
     dnnl::memory::desc gradO_mkl_md  = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
     dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat);
+    if(gradO->ews() != 1 || gradO->ordering() != 'c') {
         gradO_user_md.data.format_kind = dnnl_blocked;    // overrides format
-        gradO_user_md.data.format_desc.blocking.strides[0] = gradO->stridesOf()[0];
-        gradO_user_md.data.format_desc.blocking.strides[1] = gradO->stridesOf()[1];
-        gradO_user_md.data.format_desc.blocking.strides[2] = gradO->stridesOf()[2];
-        gradO_user_md.data.format_desc.blocking.strides[3] = gradO->stridesOf()[3];
-        gradO_user_md.data.format_desc.blocking.strides[4] = gradO->stridesOf()[4];
+        gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0);
+        gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(1);
+        gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(2);
+        gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(3);
+        gradO_user_md.data.format_desc.blocking.strides[4] = gradO->strideAt(4);
+    }

     // gradI
     dnnl::memory::desc gradI_mkl_md  = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
     dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat);
+    if(gradI->ews() != 1 || gradI->ordering() != 'c') {
         gradI_user_md.data.format_kind = dnnl_blocked;    // overrides format
-        gradI_user_md.data.format_desc.blocking.strides[0] = gradI->stridesOf()[0];
-        gradI_user_md.data.format_desc.blocking.strides[1] = gradI->stridesOf()[1];
-        gradI_user_md.data.format_desc.blocking.strides[2] = gradI->stridesOf()[2];
-        gradI_user_md.data.format_desc.blocking.strides[3] = gradI->stridesOf()[3];
-        gradI_user_md.data.format_desc.blocking.strides[4] = gradI->stridesOf()[4];
+        gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0);
+        gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(1);
+        gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(2);
+        gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(3);
+        gradI_user_md.data.format_desc.blocking.strides[4] = gradI->strideAt(4);
+    }

     // gradW
     dnnl::memory::desc gradW_mkl_md  = dnnl::memory::desc(wDims, gradWType, wFormat);
     dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormat);
     gradW_user_md.data.format_kind = dnnl_blocked;    // overrides format
-    gradW_user_md.data.format_desc.blocking.strides[0] = gradW->stridesOf()[0];
-    gradW_user_md.data.format_desc.blocking.strides[1] = gradW->stridesOf()[1];
-    gradW_user_md.data.format_desc.blocking.strides[2] = gradW->stridesOf()[2];
-    gradW_user_md.data.format_desc.blocking.strides[3] = gradW->stridesOf()[3];
-    gradW_user_md.data.format_desc.blocking.strides[4] = gradW->stridesOf()[4];
+    gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(3);    // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
+    gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(4);
+    gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(0);
+    gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(1);
+    gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(2);

     // gradB
     dnnl::memory::desc gradB_mkl_md;
@ -317,11 +321,15 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
     // gradO
     auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
-    const bool gradOReorder = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
-    auto gradO_mkl_mem = gradOReorder ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
-    if (gradOReorder)
-        dnnl::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem);
-    args[DNNL_ARG_DIFF_DST] = gradO_mkl_mem;
+    const bool gradOReorderW = op_weights_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
+    const bool gradOReorderD = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
+    auto gradO_mkl_memW = gradOReorderW ? dnnl::memory(op_weights_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
+    auto gradO_mkl_memD = gradOReorderD ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
+    if (gradOReorderW)
+        dnnl::reorder(gradO_user_mem, gradO_mkl_memW).execute(stream, gradO_user_mem, gradO_mkl_memW);
+    if (gradOReorderD)
+        dnnl::reorder(gradO_user_mem, gradO_mkl_memD).execute(stream, gradO_user_mem, gradO_mkl_memD);
+    args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD;

     // gradI
     auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer());

 @ -344,6 +352,9 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,

     // run backward data calculations
     dnnl::deconvolution_backward_data(op_data_bp_prim_desc).execute(stream, args);

+    if(gradOReorderW || gradOReorderD)
+        args[DNNL_ARG_DIFF_DST] = gradO_mkl_memW;

     // run backward weights calculations
     dnnl::deconvolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);

@ -400,32 +411,12 @@ PLATFORM_IMPL(deconv3d, ENGINE_CPU) {
         ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
     }

-    // mkl supports only [oC, iC, kD, kH, kW] format for weights
-    weights = new NDArray(weights->permute({3,4,0,1,2}));               // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
-
-    // mkl supports only NCDHW
-    if(!isNCDHW) {
-        input  = new NDArray(input->permute({0,4,1,2,3}));              // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
-        output = new NDArray(output->permute({0,4,1,2,3}));             // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW]
-    }
-
-    deconv3dMKLDNN(input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW);
-
-    delete weights;
-
-    if(!isNCDHW) {
-        delete input;
-        delete output;
-    }
+    deconv3dMKLDNN(input, weights, bias, output, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW);

     return Status::OK();
 }

 PLATFORM_CHECK(deconv3d, ENGINE_CPU) {
-    // we don't want to use mkldnn if cpu doesn't support avx/avx2
-    // if (::optimalLevel() < 2)
-    //     return false;

     auto input = INPUT_VARIABLE(0);
     auto weights = INPUT_VARIABLE(1);
     auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;
@ -499,27 +490,7 @@ PLATFORM_IMPL(deconv3d_bp, ENGINE_CPU) {
     if(isSameMode)                       // Note: we're intentionally swapping iH and oH, to calculate the padding for a "normal" conv (not deconv) forward pass
         ConvolutionUtils::calcPadding3D(pD, pH, pW, iD, iH, iW, oD, oH, oW, kD, kH, kW, sD, sH, sW, dD, dH, dW);

-    // mkl supports only [oC, iC, kD, kH, kW] for weights
-    weights = new NDArray(weights->permute({3,4,0,1,2}));               // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
-    gradW   = new NDArray(gradW->permute({3,4,0,1,2}));                 // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
-
-    // mkl supports NCDHW format only
-    if(!isNCDHW) {
-        input = new NDArray(input->permute({0,4,1,2,3}));               // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
-        gradI = new NDArray(gradI->permute({0,4,1,2,3}));               // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
-        gradO = new NDArray(gradO->permute({0,4,1,2,3}));               // [bS, oD, oH, oW, oC] -> [bS, oC, oD, oH, oW]
-    }
-
-    deconv3dBackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW);
-
-    delete weights;
-    delete gradW;
-
-    if(!isNCDHW) {
-        delete input;
-        delete gradI;
-        delete gradO;
-    }
+    deconv3dBackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, isNCDHW);

     return Status::OK();
 }

@ -86,7 +86,7 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights,
     else
         zType = dnnl::memory::data_type::s32;

-    dnnl::memory::format_tag xzFrmat = dnnl::memory::format_tag::nchw;
+    dnnl::memory::format_tag xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
     dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::goihw;

     dnnl::memory::dims xDims = {bS, iC, iH, iW};
@ -98,11 +98,13 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights,
     // input
     dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
     dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFrmat);
-    x_user_md.data.format_kind = dnnl_blocked;    // overrides format NHWC -> NCHW
-    x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
-    x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(isNCHW ? 1 : 3);
-    x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(isNCHW ? 2 : 1);
-    x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(isNCHW ? 3 : 2);
+    if(input->ews() != 1 || input->ordering() != 'c') {
+        x_user_md.data.format_kind = dnnl_blocked;    // overrides format
+        x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
+        x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);   // do permutation NHWC -> NCHW
+        x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
+        x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
+    }

     // weights, make permute [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW];
     dnnl::memory::desc w_mkl_md  = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
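
The comment above describes the depthwise weight mapping: oneDNN's grouped goihw layout treats each input channel as its own group with mC outputs and a single input channel, hence the literal 1. An illustrative sketch of the dims computation (depthwiseWeightDims is not a function from the diff):

#include <dnnl.hpp>

// goihw dims for depthwise weights stored as [kH, kW, iC, mC]:
// {groups, oC per group, iC per group, kH, kW}
static dnnl::memory::dims depthwiseWeightDims(long long kH, long long kW,
                                              long long iC, long long mC) {
    return {iC, mC, 1, kH, kW};
}
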
@ -122,11 +124,13 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights,
     // output
     dnnl::memory::desc z_mkl_md  = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any);
     dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xzFrmat);
+    if(output->ews() != 1 || output->ordering() != 'c') {
         z_user_md.data.format_kind = dnnl_blocked;    // overrides format
         z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0);
-        z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(isNCHW ? 1 : 3);
-        z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(isNCHW ? 2 : 1);
-        z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(isNCHW ? 3 : 2);
+        z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(1);   // do permutation NHWC -> NCHW
+        z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(2);
+        z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(3);
+    }

     auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());

@ -219,7 +223,7 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w
     // gradB type
     dnnl::memory::data_type gradBType = gradB != nullptr ? (gradB->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16) : dnnl::memory::data_type::f32;

-    dnnl::memory::format_tag xFormat = dnnl::memory::format_tag::nchw;   // isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
+    dnnl::memory::format_tag xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
     dnnl::memory::format_tag wFormat = dnnl::memory::format_tag::goihw;

     dnnl::memory::dims xDims = {bS, iC, iH, iW};
@ -230,12 +234,14 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w
     // input
     dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
-    dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormat);
+    dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFrmat);
+    if(input->ews() != 1 || input->ordering() != 'c') {
         x_user_md.data.format_kind = dnnl_blocked;    // overrides format
         x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
-        x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(isNCHW ? 1 : 3);
-        x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(isNCHW ? 2 : 1);
-        x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(isNCHW ? 3 : 2);
+        x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(1);
+        x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(2);
+        x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(3);
+    }

     // weights, make permute [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW];
     dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
@@ -249,21 +255,25 @@ static void depthwiseConv2dBackPropMKLDNN(const NDArray* input, const NDArray* w
     // gradO
     dnnl::memory::desc gradO_mkl_md  = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
-    dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormat);
+    dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xzFrmat);
+    if(gradO->ews() != 1 || gradO->ordering() != 'c') {
         gradO_user_md.data.format_kind = dnnl_blocked;    // overrides format
         gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0);
-        gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(isNCHW ? 1 : 3);
-        gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(isNCHW ? 2 : 1);
-        gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(isNCHW ? 3 : 2);
+        gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(1);
+        gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(2);
+        gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(3);
+    }

     // gradI
     dnnl::memory::desc gradI_mkl_md  = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
-    dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormat);
+    dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xzFrmat);
+    if(gradI->ews() != 1 || gradI->ordering() != 'c') {
         gradI_user_md.data.format_kind = dnnl_blocked;    // overrides format
         gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0);
-        gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(isNCHW ? 1 : 3);
-        gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(isNCHW ? 2 : 1);
-        gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(isNCHW ? 3 : 2);
+        gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(1);
+        gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(2);
+        gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(3);
+    }

     // gradW, make permute [kH, kW, iC, mC] -> [iC, mC, 1, kH, kW];
     dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any);
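Side note (not part of the patch): the dnnl_blocked override above is only taken when the NDArray is not a plain c-order buffer; for a contiguous [bS, iC, iH, iW] array, strideAt(0..3) are simply the row-major strides. A minimal stand-alone C++ sketch of that arithmetic, with made-up shape values:

    #include <cstdio>

    int main() {
        const long bS = 2, iC = 3, iH = 4, iW = 5;
        // c-order (row-major) strides of a [bS, iC, iH, iW] array, i.e. what
        // strideAt(0..3) would return when ews() == 1 and ordering() == 'c'
        const long strides[4] = { iC * iH * iW, iH * iW, iW, 1 };
        std::printf("strides = {%ld, %ld, %ld, %ld}\n",
                    strides[0], strides[1], strides[2], strides[3]);
        return 0;
    }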
@@ -319,11 +329,15 @@ static void depthwiseConv2dBackPropMKLDNN(const NDArray* input, const NDArray* w
     // gradO
     auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
-    const bool gradOReorder = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
-    auto gradO_mkl_mem = gradOReorder ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
-    if (gradOReorder)
-        dnnl::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem);
-    args[DNNL_ARG_DIFF_DST] = gradO_mkl_mem;
+    const bool gradOReorderW = op_weights_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
+    const bool gradOReorderD = op_data_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
+    auto gradO_mkl_memW = gradOReorderW ? dnnl::memory(op_weights_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
+    auto gradO_mkl_memD = gradOReorderD ? dnnl::memory(op_data_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
+    if (gradOReorderW)
+        dnnl::reorder(gradO_user_mem, gradO_mkl_memW).execute(stream, gradO_user_mem, gradO_mkl_memW);
+    if (gradOReorderD)
+        dnnl::reorder(gradO_user_mem, gradO_mkl_memD).execute(stream, gradO_user_mem, gradO_mkl_memD);
+    args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD;

     // gradI
     auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer());
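The pattern that keeps recurring in these hunks — wrap the user buffer, compare descriptors, reorder only on mismatch — can be factored out as below. A hedged sketch, assuming oneDNN (dnnl.hpp) is on the include path; the helper name prepareMemory is ours, not nd4j's or the library's:

    #include <dnnl.hpp>

    // Return memory laid out as `want`; copy from `user` only if the layouts differ.
    static dnnl::memory prepareMemory(const dnnl::memory::desc& want, dnnl::memory& user,
                                      dnnl::engine& engine, dnnl::stream& stream) {
        if (want == user.get_desc())
            return user;                               // layouts match, no copy needed
        dnnl::memory tmp(want, engine);                // primitive's preferred layout
        dnnl::reorder(user, tmp).execute(stream, user, tmp);
        return tmp;
    }

With such a helper, gradO_mkl_memW and gradO_mkl_memD above are just prepareMemory(...) against the weights-bp and data-bp diff_dst descriptors respectively, which may legitimately differ.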
@@ -346,6 +360,9 @@ static void depthwiseConv2dBackPropMKLDNN(const NDArray* input, const NDArray* w
     // run backward data calculations
     dnnl::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args);

+    if(gradOReorderW || gradOReorderD)
+        args[DNNL_ARG_DIFF_DST] = gradO_mkl_memW;
+
     // run backward weights calculations
     dnnl::convolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);

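Both backward primitives share a single args map, so DNNL_ARG_DIFF_DST has to be re-pointed between the two execute calls whenever the primitives picked different gradO layouts. A compilable restatement of that ordering (function name and signature are ours, for illustration only):

    #include <dnnl.hpp>
    #include <unordered_map>

    static void runBackward(const dnnl::convolution_backward_data::primitive_desc& dataPd,
                            const dnnl::convolution_backward_weights::primitive_desc& weightsPd,
                            std::unordered_map<int, dnnl::memory>& args,
                            const dnnl::memory& gradOForData, const dnnl::memory& gradOForWeights,
                            dnnl::stream& stream) {
        args[DNNL_ARG_DIFF_DST] = gradOForData;        // layout the data primitive expects
        dnnl::convolution_backward_data(dataPd).execute(stream, args);
        args[DNNL_ARG_DIFF_DST] = gradOForWeights;     // re-point before the weights pass
        dnnl::convolution_backward_weights(weightsPd).execute(stream, args);
    }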
@@ -401,9 +418,6 @@ PLATFORM_IMPL(depthwise_conv2d, ENGINE_CPU) {
 //////////////////////////////////////////////////////////////////////
 PLATFORM_CHECK(depthwise_conv2d, ENGINE_CPU) {
-    // we don't want to use mkldnn if cpu doesn't support avx/avx2
-    if (::optimalLevel() < 2)
-        return false;

     auto input = INPUT_VARIABLE(0);
     auto weights = INPUT_VARIABLE(1);

@@ -17,6 +17,7 @@
 //
 // @author saudet
 // @author raver119@gmail.com
+// @author Yurii Shyrma (iuriish@yahoo.com)
 //

 #include <ops/declarable/PlatformHelper.h>
@@ -33,105 +34,38 @@ namespace nd4j {
 namespace ops {
 namespace platforms {

 //////////////////////////////////////////////////////////////////////////
 PLATFORM_IMPL(maxpool2d, ENGINE_CPU) {

     auto input = INPUT_VARIABLE(0);
-
-    REQUIRE_TRUE(input->rankOf() == 4, 0, "Input should have rank of 4, but got %i instead",
-                 input->rankOf());
-
-    // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
-    auto argI = *(block.getIArguments());
     auto output = OUTPUT_VARIABLE(0);

-    const auto kH = INT_ARG(0);
-    const auto kW = INT_ARG(1);
-    const auto sH = INT_ARG(2);
-    const auto sW = INT_ARG(3);
+    REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D MKLDNN OP: input array should have rank of 4, but got %i instead", input->rankOf());
+
+    // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
+    const int kH = INT_ARG(0);
+    const int kW = INT_ARG(1);
+    const int sH = INT_ARG(2);
+    const int sW = INT_ARG(3);
     int pH = INT_ARG(4);
     int pW = INT_ARG(5);
-    const auto dH = INT_ARG(6);
-    const auto dW = INT_ARG(7);
-    const auto isSameMode = static_cast<bool>(INT_ARG(8));
+    const int dH = INT_ARG(6);
+    const int dW = INT_ARG(7);
+    const int paddingMode = INT_ARG(8);
+    // const int extraParam0 = INT_ARG(9);
+    const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1;    // INT_ARG(10): 1-NHWC, 0-NCHW

-    REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D op: dilation must not be zero, but got instead {%i, %i}",
-                 dH, dW);
+    REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D MKLDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW);

-    int oH = 0;
-    int oW = 0;
-
-    int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1;    // INT_ARG(10): 0-NCHW, 1-NHWC
-
-    const int iH = static_cast<int>(isNCHW ? input->sizeAt(2) : input->sizeAt(1));
-    const int iW = static_cast<int>(isNCHW ? input->sizeAt(3) : input->sizeAt(2));
-
-    if (!isNCHW) {
-        input = new NDArray(
-                input->permute({0, 3, 1, 2}));     // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
-        output = new NDArray(
-                output->permute({0, 3, 1, 2}));    // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
-    }
-
-    ConvolutionUtils::calcOutSizePool2D(oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, iH, iW, isSameMode);
-
-    if (isSameMode)
+    int bS, iC, iH, iW, oC, oH, oW;                          // batch size, input channels, input height/width, output channels, output height/width;
+    int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;    // corresponding indexes
+    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
+
+    if (paddingMode)
         ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);

-    const int bS = input->sizeAt(0);
-    const int iC = input->sizeAt(1);
-    const int oC = output->sizeAt(1);
-
-    auto poolingMode = PoolingType::MAX_POOL;
-    int extraParam0 = 1;
-
-    dnnl_memory_desc_t empty;
-    dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty);
-    dnnl::memory::desc user_src_md(empty), user_dst_md(empty);
-    dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
-    dnnl::algorithm algorithm;
-
-    mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
-                                           true,
-                                           bS, iC, iH, iW, oC, oH, oW, input, nullptr, output,
-                                           algorithm,
-                                           &pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr,
-                                           &user_dst_md,
-                                           pool_strides, pool_kernel, pool_padding, pool_padding_r);
-
-    auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md,
-                                           pool_dst_md,
-                                           pool_strides, pool_kernel, pool_padding, pool_padding_r);
-
-    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
-    auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
-    auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
-    auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
-
-    auto pool_src_memory = user_src_memory;
-    dnnl::stream stream(engine);
-    if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
-        pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
-        reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
-    }
-
-    auto pool_dst_memory = user_dst_memory;
-    if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
-        pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
-    }
-
-    pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
-                                                     {DNNL_ARG_DST, pool_dst_memory}});
-
-    if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
-        reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);
-    }
-
-    stream.wait();
-
-    if (!isNCHW) {
-        delete input;
-        delete output;
-    }
+    mkldnnUtils::poolingMKLDNN(input, output, 0,kH,kW, 0,sH,sW, 0,pH,pW, isNCHW, algorithm::pooling_max);

     return Status::OK();
 }
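For reference, the SAME-mode arithmetic that calcPadding2D performs per spatial axis: the total padding is whatever makes the last window end at the padded edge. A stand-alone sketch (ours, under the usual SAME conventions; not the nd4j implementation):

    #include <cstdio>

    // SAME mode: choose total padding so that o = ceil(i / s).
    static void samePad1D(int o, int i, int k, int s, int d, int& pLeft, int& pRight) {
        const int kEff  = (k - 1) * d + 1;             // effective kernel size with dilation
        const int total = (o - 1) * s + kEff - i;      // may come out negative -> clamp to 0
        pLeft  = total > 0 ? total / 2 : 0;
        pRight = total > 0 ? total - pLeft : 0;
    }

    int main() {
        int pL, pR;
        samePad1D(/*o*/14, /*i*/28, /*k*/3, /*s*/2, /*d*/1, pL, pR);
        std::printf("pad = (%d, %d)\n", pL, pR);       // prints: pad = (0, 1)
        return 0;
    }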
@@ -159,117 +93,24 @@ PLATFORM_IMPL(maxpool2d_bp, ENGINE_CPU) {
     int pW = INT_ARG(5);                // paddings width
     int dH = INT_ARG(6);                // dilations height
     int dW = INT_ARG(7);                // dilations width
-    int isSameMode = INT_ARG(8);        // 0-VALID, 1-SAME
-    int extraParam0 = INT_ARG(9);
-    int isNCHW =
-            block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1;    // INT_ARG(10): 0-NCHW, 1-NHWC
+    int paddingMode = INT_ARG(8);       // 0-VALID, 1-SAME
+    // int extraParam0 = INT_ARG(9);
+    int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1;    // INT_ARG(10): 0-NCHW, 1-NHWC

-    REQUIRE_TRUE(input->rankOf() == 4, 0,
-                 "AVGPOOL2D_BP op: input should have rank of 4, but got %i instead", input->rankOf());
-    REQUIRE_TRUE(dH != 0 && dW != 0, 0,
-                 "AVGPOOL2D_BP op: dilation must not be zero, but got instead {%i, %i}", dH, dW);
+    REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D_BP MKLDNN op: input should have rank of 4, but got %i instead", input->rankOf());
+    REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D_BP MKLDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW);

     int bS, iC, iH, iW, oC, oH, oW;                          // batch size, input channels, input height/width, output channels, output height/width;
     int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;    // corresponding indexes
-    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC,
-                                               indIiH, indWiC, indWoC, indWkH, indOoH);
+    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);

-    std::string expectedGradOShape = ShapeUtils::shapeAsString(
-            ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1}));
-    std::string expectedGradIShape = ShapeUtils::shapeAsString(
-            ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, iH, iW, 0, indIOioC, indIiH, indIiH + 1}));
-    REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0,
-                 "AVGPOOL2D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !",
-                 expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
-    REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0,
-                 "AVGPOOL2D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !",
-                 expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
+    std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS, iC, oH, oW, 0, indIOioC, indIiH, indIiH + 1});
+    REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL2D_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());

-    if (!isNCHW) {
-        input = new NDArray(input->permute(
-                {0, 3, 1, 2}));    // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
-        gradI = new NDArray(gradI->permute(
-                {0, 3, 1, 2}));    // [bS, iH, iW, iC] -> [bS, iC, iH, iW]
-        gradO = new NDArray(gradO->permute(
-                {0, 3, 1, 2}));    // [bS, oH, oW, iC] -> [bS, iC, oH, oW]
-    }
-
-    if (isSameMode) // SAME
+    if (paddingMode) // SAME
         ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);

-    auto poolingMode = PoolingType::MAX_POOL;
-
-    dnnl_memory_desc_t empty;
-    dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
-    dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
-    dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
-    dnnl::algorithm algorithm;
-
-    mkldnnUtils::getMKLDNNMemoryDescPool2d(kH, kW, sH, sW, pH, pW, dH, dW, poolingMode, extraParam0,
-                                           true,
-                                           bS, iC, iH, iW, oC, oH, oW, input, gradI, gradO, algorithm,
-                                           &pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md,
-                                           &user_diff_src_md, &user_dst_md,
-                                           pool_strides, pool_kernel, pool_padding, pool_padding_r);
-
-    // input is sometimes null, so we can't rely on pool_src_md being valid
-    auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm,
-                                           input->buffer() != nullptr ? pool_src_md : pool_diff_src_md,
-                                           pool_dst_md, pool_strides, pool_kernel, pool_padding,
-                                           pool_padding_r);
-
-    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
-    dnnl::stream stream(engine);
-    auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
-
-    auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md,
-                                             pool_strides, pool_kernel, pool_padding, pool_padding_r);
-
-    auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
-    auto userB_src_memory = dnnl::memory(user_src_md, engine, gradI->buffer());
-    auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer());
-
-    auto poolB_src_memory = userB_src_memory;
-    if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
-        poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine);
-    }
-
-    auto poolB_dst_memory = userB_dst_memory;
-    if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
-        poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine);
-        reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
-    }
-
-    auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
-    auto pool_src_memory = user_src_memory;
-    if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
-        pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
-        reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
-    }
-
-    auto pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
-    auto pool_workspace_memory = dnnl::memory(pool_prim_desc.workspace_desc(), engine);
-
-    pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
-                                                     {DNNL_ARG_DST, pool_dst_memory},
-                                                     {DNNL_ARG_WORKSPACE, pool_workspace_memory}});
-    // probably wrong, fix that
-    pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory},
-                                                       {DNNL_ARG_WORKSPACE, pool_workspace_memory},
-                                                       {DNNL_ARG_DIFF_SRC, poolB_src_memory}});
-
-    if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
-        reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
-    }
-
-    stream.wait();
-
-    if (!isNCHW) {
-        delete input;
-        delete gradI;
-        delete gradO;
-    }
+    mkldnnUtils::poolingBpMKLDNN(input, gradO, gradI, 0,kH,kW, 0,sH,sW, 0,pH,pW, isNCHW, algorithm::pooling_max);

     return Status::OK();
 }
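The shape check above scatters dims {bS, iC, oH, oW} into positions {0, indIOioC, indIiH, indIiH + 1}, so one expression covers both NCHW and NHWC. A hedged reconstruction of that compose-by-index idea (our helper, not ShapeUtils itself):

    #include <cstdio>
    #include <vector>

    // First half of the argument holds dims, second half the target positions.
    static std::vector<long> composeShape(const std::vector<long>& dimsAndIdx) {
        const size_t n = dimsAndIdx.size() / 2;
        std::vector<long> shape(n);
        for (size_t i = 0; i < n; ++i)
            shape[dimsAndIdx[n + i]] = dimsAndIdx[i];
        return shape;
    }

    int main() {
        // NHWC case: indIOioC = 3, indIiH = 1 -> expected gradO shape [bS, oH, oW, iC]
        auto s = composeShape({2, 3, 8, 8,  0, 3, 1, 2});
        std::printf("[%ld, %ld, %ld, %ld]\n", s[0], s[1], s[2], s[3]);   // [2, 8, 8, 3]
        return 0;
    }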
@@ -16,6 +16,7 @@
 //
 // @author raver119@gmail.com
+// @author Yurii Shyrma (iuriish@yahoo.com)
 //

 #include <ops/declarable/PlatformHelper.h>

@@ -34,10 +35,9 @@ namespace platforms {

 //////////////////////////////////////////////////////////////////////////
 PLATFORM_IMPL(maxpool3dnew, ENGINE_CPU) {
-    auto input = INPUT_VARIABLE(
-            0);    // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
-    auto output = OUTPUT_VARIABLE(
-            0);    // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW)
+    auto input  = INPUT_VARIABLE(0);     // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
+    auto output = OUTPUT_VARIABLE(0);    // [bS, oD, oH, oW, iC] (NDHWC) or [bS, iC, oD, oH, oW] (NCDHW)

     int kD = INT_ARG(0);                // filter(kernel) depth
     int kH = INT_ARG(1);                // filter(kernel) height
@@ -51,95 +51,24 @@ PLATFORM_IMPL(maxpool3dnew, ENGINE_CPU) {
     int dD = INT_ARG(9);                // dilations depth
     int dH = INT_ARG(10);               // dilations height
     int dW = INT_ARG(11);               // dilations width
-    int isSameMode = INT_ARG(12);       // 1-SAME, 0-VALID
+    int paddingMode = INT_ARG(12);      // 1-SAME, 0-VALID
     // int extraParam0 = INT_ARG(13);   // unnecessary for max case, required only for avg and pnorm cases
     int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1;    // 1-NDHWC, 0-NCDHW

-    REQUIRE_TRUE(input->rankOf() == 5, 0,
-                 "MAXPOOL3DNEW OP: rank of input array must be equal to 5, but got %i instead !",
-                 input->rankOf());
-    REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0,
-                 "MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
+    REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3DNEW MKLDNN OP: rank of input array must be equal to 5, but got %i instead !", input->rankOf());
+    REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW MKLDNN op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);

     int bS, iC, iD, iH, iW, oC, oD, oH, oW;          // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
     int indIOioC, indIOioD, indWoC, indWiC, indWkD;  // corresponding indexes
-    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW,
-                                               indIOioC, indIOioD, indWiC, indWoC, indWkD);
+    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

-    std::string expectedOutputShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
-            {bS, iC, oD, oH, oW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
-    REQUIRE_TRUE(expectedOutputShape == ShapeUtils::shapeAsString(output), 0,
-                 "MAXPOOL3D op: wrong shape of output array, expected is %s, but got %s instead !",
-                 expectedOutputShape.c_str(), ShapeUtils::shapeAsString(output).c_str());
-    // REQUIRE_TRUE(iD >= kD && iH >= kH && iW >= kW, 0, "MAXPOOL3D OP: the input depth/height/width must be greater or equal to kernel(filter) depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", iD,iH,iW, kD,kH,kW);
-    // REQUIRE_TRUE(kD/2 >= pD && kH/2 >= pH && kW/2 >= pW, 0, "MAXPOOL3D OP: pad depth/height/width must not be greater than half of kernel depth/height/width, but got [%i, %i, %i] and [%i, %i, %i] correspondingly !", pD,pH,pW, kD,kH,kW);
-
-    if (!isNCDHW) {
-        input = new NDArray(
-                input->permute({0, 4, 1, 2, 3}));     // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
-        output = new NDArray(
-                output->permute({0, 4, 1, 2, 3}));    // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
-    }
-
-    if (isSameMode) // SAME
-        ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH,
-                                        dW);
-
-    auto poolingMode = PoolingType::MAX_POOL;
-    auto extraParam0 = 1;
-
-    dnnl_memory_desc_t empty;
-    dnnl::memory::desc pool_src_md(empty), pool_dst_md(empty);
-    dnnl::memory::desc user_src_md(empty), user_dst_md(empty);
-    dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
-    dnnl::algorithm algorithm;
-
-    mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode,
-                                           extraParam0, true,
-                                           bS, iC, iD, iH, iW, oC, oD, oH, oW, input, nullptr, output,
-                                           algorithm,
-                                           &pool_src_md, nullptr, &pool_dst_md, &user_src_md, nullptr,
-                                           &user_dst_md,
-                                           pool_strides, pool_kernel, pool_padding, pool_padding_r);
-
-    auto pool_desc = pooling_forward::desc(prop_kind::forward_inference, algorithm, pool_src_md,
-                                           pool_dst_md, pool_strides, pool_kernel, pool_padding,
-                                           pool_padding_r);
-
-    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
-    dnnl::stream stream(engine);
-    auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
-    auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
-    auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer());
-
-    auto pool_src_memory = user_src_memory;
-    if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
-        pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
-        reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
-    }
-
-    auto pool_dst_memory = user_dst_memory;
-    if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
-        pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
-    }
-
-    pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
-                                                     {DNNL_ARG_DST, pool_dst_memory}});
-
-    if (pool_prim_desc.dst_desc() != user_dst_memory.get_desc()) {
-        reorder(pool_dst_memory, user_dst_memory).execute(stream, pool_dst_memory, user_dst_memory);
-    }
-
-    stream.wait();
-
-    if (!isNCDHW) {
-        delete input;
-        delete output;
-    }
+    if(paddingMode) // SAME
+        ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);
+
+    mkldnnUtils::poolingMKLDNN(input, output, kD,kH,kW, sD,sH,sW, pD,pH,pW, isNCDHW, algorithm::pooling_max);

     return Status::OK();
 }
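The deleted calcOutSizePool2D/calcOutSizePool3D calls are no longer needed here because the output array's shape is already known; for the record, the per-axis output-size arithmetic they encapsulate is the usual one. A stand-alone sketch under standard VALID/SAME conventions (ours, not the ConvolutionUtils implementation):

    #include <cstdio>

    static int outSize1D(int i, int k, int s, int p, int d, bool same) {
        const int kEff = (k - 1) * d + 1;              // effective kernel size with dilation
        return same ? (i + s - 1) / s                  // SAME:  ceil(i / s)
                    : (i + 2 * p - kEff) / s + 1;      // VALID: floor((i + 2p - kEff) / s) + 1
    }

    int main() {
        std::printf("VALID: %d, SAME: %d\n",
                    outSize1D(28, 3, 2, 0, 1, false),  // -> 13
                    outSize1D(28, 3, 2, 0, 1, true));  // -> 14
        return 0;
    }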
 //////////////////////////////////////////////////////////////////////////

@@ -152,6 +81,7 @@ PLATFORM_CHECK(maxpool3dnew, ENGINE_CPU) {

 //////////////////////////////////////////////////////////////////////////
 PLATFORM_IMPL(maxpool3dnew_bp, ENGINE_CPU) {

     auto input = INPUT_VARIABLE(0);     // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW)
     auto gradO = INPUT_VARIABLE(1);     // [bS, oD, oH, oW, oC] (NDHWC) or [bS, oC, oD, oH, oW] (NCDHW), epsilon_next
     auto gradI = OUTPUT_VARIABLE(0);    // [bS, iD, iH, iW, iC] (NDHWC) or [bS, iC, iD, iH, iW] (NCDHW), epsilon
@@ -168,121 +98,24 @@ PLATFORM_IMPL(maxpool3dnew_bp, ENGINE_CPU) {
     const int dD = INT_ARG(9);          // dilations depth
     const int dH = INT_ARG(10);         // dilations height
     const int dW = INT_ARG(11);         // dilations width
-    const int isSameMode = INT_ARG(12); // 1-SAME, 0-VALID
+    const int paddngMode = INT_ARG(12); // 1-SAME, 0-VALID
     // int extraParam0 = INT_ARG(13);   // unnecessary for max case, required only for avg and pnorm cases
     int isNCDHW = block.getIArguments()->size() > 14 ? !INT_ARG(14) : 1;    // 1-NDHWC, 0-NCDHW

-    REQUIRE_TRUE(input->rankOf() == 5, 0,
-                 "MAXPOOL3D_BP op: input should have rank of 5, but got %i instead", input->rankOf());
-    REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0,
-                 "MAXPOOL3DNEW op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);
+    REQUIRE_TRUE(input->rankOf() == 5, 0, "MAXPOOL3DNEW_BP MKLDNN op: input should have rank of 5, but got %i instead", input->rankOf());
+    REQUIRE_TRUE(dD != 0 && dH != 0 && dW != 0, 0, "MAXPOOL3DNEW_BP MKLDNN op: dilation must not be zero, but got instead {%i, %i, %i}", dD, dH, dW);

     int bS, iC, iD, iH, iW, oC, oD, oH, oW;          // batch size, input channels, input depth/height/width, output channels, output depth/height/width;
     int indIOioC, indIOioD, indWoC, indWiC, indWkD;  // corresponding indexes
-    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW,
-                                               indIOioC, indIOioD, indWiC, indWoC, indWkD);
+    ConvolutionUtils::getSizesAndIndexesConv3d(isNCDHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIOioD, indWiC, indWoC, indWkD);

-    std::string expectedGradOShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
-            {bS, iC, oD, oH, oW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
-    std::string expectedGradIShape = ShapeUtils::shapeAsString(ShapeUtils::composeShapeUsingDimsAndIdx(
-            {bS, iC, iD, iH, iW, 0, indIOioC, indIOioD, indIOioD + 1, indIOioD + 2}));
-    REQUIRE_TRUE(expectedGradOShape == ShapeUtils::shapeAsString(gradO), 0,
-                 "MAXPOOL3D_BP op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !",
-                 expectedGradOShape.c_str(), ShapeUtils::shapeAsString(gradO).c_str());
-    REQUIRE_TRUE(expectedGradIShape == ShapeUtils::shapeAsString(gradI), 0,
-                 "MAXPOOL3D_BP op: wrong shape of input's gradients array (epsilon), expected is %s, but got %s instead !",
-                 expectedGradIShape.c_str(), ShapeUtils::shapeAsString(gradI).c_str());
+    std::vector<Nd4jLong> expectedGradOShape = ShapeUtils::composeShapeUsingDimsAndIdx({bS,iC,oD,oH,oW, 0,indIOioC,indIOioD,indIOioD+1,indIOioD+2});
+    REQUIRE_TRUE(gradO->isSameShape(expectedGradOShape), 0, "MAXPOOL3DNEW_BP MKLDNN op: wrong shape of output's gradients array (next epsilon), expected is %s, but got %s instead !", ShapeUtils::shapeAsString(expectedGradOShape).c_str(), ShapeUtils::shapeAsString(gradO).c_str());

-    if (!isNCDHW) {
-        input = new NDArray(input->permute(
-                {0, 4, 1, 2, 3}));    // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
-        gradI = new NDArray(gradI->permute(
-                {0, 4, 1, 2, 3}));    // [bS, iD, iH, iW, iC] -> [bS, iC, iD, iH, iW]
-        gradO = new NDArray(gradO->permute(
-                {0, 4, 1, 2, 3}));    // [bS, oD, oH, oW, iC] -> [bS, iC, oD, oH, oW]
-    }
-
-    if (isSameMode) // SAME
-        ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH,
-                                        dW);
-
-    auto poolingMode = PoolingType::MAX_POOL;
-    auto extraParam0 = 1;
+    if(paddngMode) // SAME
+        ConvolutionUtils::calcPadding3D(pD, pH, pW, oD, oH, oW, iD, iH, iW, kD, kH, kW, sD, sH, sW, dD, dH, dW);

-    dnnl_memory_desc_t empty;
-    dnnl::memory::desc pool_src_md(empty), pool_diff_src_md(empty), pool_dst_md(empty);
-    dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_dst_md(empty);
-    dnnl::memory::dims pool_strides, pool_kernel, pool_padding, pool_padding_r;
-    dnnl::algorithm algorithm;
-
-    mkldnnUtils::getMKLDNNMemoryDescPool3d(kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, poolingMode,
-                                           extraParam0, true,
-                                           bS, iC, iD, iH, iW, oC, oD, oH, oW, input, gradI, gradO,
-                                           algorithm,
-                                           &pool_src_md, &pool_diff_src_md, &pool_dst_md, &user_src_md,
-                                           &user_diff_src_md, &user_dst_md,
-                                           pool_strides, pool_kernel, pool_padding, pool_padding_r);
-
-    // input is sometimes null, so we can't rely on pool_src_md being valid
-    if (input->buffer() == nullptr) {
-        pool_src_md = pool_diff_src_md;
-        user_src_md = user_diff_src_md;
-    }
-    auto pool_desc = pooling_forward::desc(prop_kind::forward, algorithm, pool_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r);
-
-    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
-    dnnl::stream stream(engine);
-    auto pool_prim_desc = pooling_forward::primitive_desc(pool_desc, engine);
-
-    auto poolB_desc = pooling_backward::desc(algorithm, pool_diff_src_md, pool_dst_md, pool_strides, pool_kernel, pool_padding, pool_padding_r);
-
-    auto poolB_prim_desc = pooling_backward::primitive_desc(poolB_desc, engine, pool_prim_desc);
-    auto userB_src_memory = dnnl::memory(user_diff_src_md, engine, gradI->buffer());
-    auto userB_dst_memory = dnnl::memory(user_dst_md, engine, gradO->buffer());
-
-    auto poolB_src_memory = userB_src_memory;
-    if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
-        poolB_src_memory = dnnl::memory(poolB_prim_desc.diff_src_desc(), engine);
-    }
-
-    auto poolB_dst_memory = userB_dst_memory;
-    if (poolB_prim_desc.diff_dst_desc() != userB_dst_memory.get_desc()) {
-        poolB_dst_memory = dnnl::memory(poolB_prim_desc.diff_dst_desc(), engine);
-        reorder(userB_dst_memory, poolB_dst_memory).execute(stream, userB_dst_memory, poolB_dst_memory);
-    }
-
-    auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer());
-
-    auto pool_src_memory = user_src_memory;
-    if (pool_prim_desc.src_desc() != user_src_memory.get_desc()) {
-        pool_src_memory = dnnl::memory(pool_prim_desc.src_desc(), engine);
-        reorder(user_src_memory, pool_src_memory).execute(stream, user_src_memory, pool_src_memory);
-    }
-
-    auto pool_dst_memory = dnnl::memory(pool_prim_desc.dst_desc(), engine);
-    auto pool_workspace_memory = dnnl::memory(pool_prim_desc.workspace_desc(), engine);
-
-    pooling_forward(pool_prim_desc).execute(stream, {{DNNL_ARG_SRC, pool_src_memory},
-                                                     {DNNL_ARG_DST, pool_dst_memory},
-                                                     {DNNL_ARG_WORKSPACE, pool_workspace_memory}});
-    pooling_backward(poolB_prim_desc).execute(stream, {{DNNL_ARG_DIFF_DST, poolB_dst_memory},
-                                                       {DNNL_ARG_WORKSPACE, pool_workspace_memory},
-                                                       {DNNL_ARG_DIFF_SRC, poolB_src_memory}});
-
-    if (poolB_prim_desc.diff_src_desc() != userB_src_memory.get_desc()) {
-        reorder(poolB_src_memory, userB_src_memory).execute(stream, poolB_src_memory, userB_src_memory);
-    }
-
-    stream.wait();
-
-    if (!isNCDHW) {
-        delete input;
-        delete gradI;
-        delete gradO;
-    }
+    mkldnnUtils::poolingBpMKLDNN(input, gradO, gradI, kD,kH,kW, sD,sH,sW, pD,pH,pW, isNCDHW, algorithm::pooling_max);

     return Status::OK();
 }
@@ -16,9 +16,11 @@
 //
 // @author saudet
+// @author Yurii Shyrma (iuriish@yahoo.com)
 //

 #include <dnnl_types.h>
+#include <ops/declarable/helpers/convolutions.h>
 #include "mkldnnUtils.h"

 using namespace dnnl;
@@ -26,6 +28,314 @@ using namespace dnnl;
 namespace nd4j {
 namespace mkldnnUtils {

+//////////////////////////////////////////////////////////////////////
+void poolingMKLDNN(const NDArray *input, NDArray *output,
+                   const int kD, const int kH, const int kW,
+                   const int sD, const int sH, const int sW,
+                   const int pD, const int pH, const int pW,
+                   const int isNCHW, const dnnl::algorithm mode) {
+
+    // unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any) for input
+    const int rank = input->rankOf();
+
+    int bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;
+    dnnl::memory::dims strides, kernel, padding, padding_r, xDims, zDims;
+    dnnl::memory::format_tag xzFrmat;
+
+    const auto type = dnnl::memory::data_type::f32;
+
+    if(rank == 4) {        // 2d
+
+        ops::ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
+
+        strides   = { sH, sW };
+        kernel    = { kH, kW };
+        padding   = { pH, pW };
+        padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW };
+        xDims     = {bS, iC, iH, iW};
+        zDims     = {bS, oC, oH, oW};
+
+        xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
+    }
+    else {        // 3d
+
+        ops::ConvolutionUtils::getSizesAndIndexesConv3d(isNCHW, *input, *output, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH);
+
+        strides   = { sD, sH, sW };
+        kernel    = { kD, kH, kW };
+        padding   = { pD, pH, pW };
+        padding_r = { (oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW };
+        xDims     = {bS, iC, iD, iH, iW};
+        zDims     = {bS, oC, oD, oH, oW};
+
+        xzFrmat = isNCHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
+    }
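The padding_r expressions above compute the implicit right/bottom padding oneDNN needs so that the last pooling window lines up with the given output size. A quick stand-alone check of that identity (values are made up):

    #include <cassert>

    int main() {
        const int iH = 7, kH = 3, sH = 2, pH = 1, oH = 4;
        const int padding_r = (oH - 1) * sH - iH + kH - pH;            // = 1 here
        // The last window starts at (oH-1)*sH - pH in input coordinates and must
        // end exactly at the padded right edge:
        assert((oH - 1) * sH - pH + kH == iH + padding_r);
        return 0;
    }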
+
+    // memory descriptors for arrays
+
+    // input
+    dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(xDims, type, xzFrmat);
+    dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
+    if(input->ews() != 1 || input->ordering() != 'c') {
+        x_user_md.data.format_kind = dnnl_blocked;    // overrides format
+        x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
+        x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(isNCHW ? 1 : -1);
+        x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(isNCHW ? 2 : 1);
+        x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(isNCHW ? 3 : 2);
+        if(rank == 5)
+            x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(isNCHW ? 4 : 3);
+    }
+
+    // output
+    dnnl::memory::desc z_mkl_md  = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
+    dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFrmat);
+    if(output->ews() != 1 || output->ordering() != 'c') {
+        z_user_md.data.format_kind = dnnl_blocked;    // overrides format
+        z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0);
+        z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(isNCHW ? 1 : -1);
+        z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(isNCHW ? 2 : 1);
+        z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(isNCHW ? 3 : 2);
+        if(rank == 5)
+            z_user_md.data.format_desc.blocking.strides[4] = output->strideAt(isNCHW ? 4 : 3);
+    }
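Note the strideAt(isNCHW ? 1 : -1) above: the descriptor is always logically channels-first, so for an NHWC/NDHWC array the channel stride is the stride of the last physical axis (assuming strideAt accepts negative indexes that wrap from the end, as sizeAt does). A stand-alone sketch of that remapping for a contiguous NHWC buffer (ours, not nd4j code):

    #include <cstdio>

    int main() {
        // c-order NHWC shape [bS, iH, iW, iC] and its row-major strides
        const long bS = 2, iH = 5, iW = 5, iC = 3;
        const long nhwc[4] = { iH * iW * iC, iW * iC, iC, 1 };
        // logical NCHW strides handed to the dnnl_blocked descriptor:
        const long strides[4] = { nhwc[0], nhwc[3] /* strideAt(-1) */, nhwc[1], nhwc[2] };
        std::printf("N=%ld C=%ld H=%ld W=%ld\n", strides[0], strides[1], strides[2], strides[3]);
        return 0;
    }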
+
+    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
+
+    // operation primitive description
+    dnnl::pooling_forward::desc op_desc(dnnl::prop_kind::forward_inference, mode, x_mkl_md, z_mkl_md, strides, kernel, padding, padding_r);
+    dnnl::pooling_forward::primitive_desc op_prim_desc(op_desc, engine);
+
+    // arguments (memory buffers) necessary for calculations
+    std::unordered_map<int, dnnl::memory> args;
+
+    dnnl::stream stream(engine);
+
+    // provide memory buffers and check whether reorder is required
+
+    // input
+    auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer());
+    const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
+    auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
+    if (xReorder)
+        dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
+    args[DNNL_ARG_SRC] = x_mkl_mem;
+
+    // output
+    auto z_user_mem = dnnl::memory(z_user_md, engine, output->getBuffer());
+    const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
+    auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
+    args[DNNL_ARG_DST] = z_mkl_mem;
+
+    // run calculations
+    dnnl::pooling_forward(op_prim_desc).execute(stream, args);
+
+    // reorder outputs if necessary
+    if (zReorder)
+        dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
+
+    stream.wait();
+}
+
+//////////////////////////////////////////////////////////////////////
+void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI,
+                     const int kD, const int kH, const int kW,
+                     const int sD, const int sH, const int sW,
+                     const int pD, const int pH, const int pW,
+                     const int isNCHW, const dnnl::algorithm mode) {
+
+    // unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any) for input
+
+    const int rank = input->rankOf();
+
+    int bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH;
+    dnnl::memory::dims strides, kernel, padding, padding_r, xDims, zDims;
+    dnnl::memory::format_tag xzFrmat;
+
+    const auto type = dnnl::memory::data_type::f32;
+
+    if(rank == 4) {        // 2d
+
+        ops::ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
+
+        strides   = { sH, sW };
+        kernel    = { kH, kW };
+        padding   = { pH, pW };
+        padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW };
+        xDims     = {bS, iC, iH, iW};
+        zDims     = {bS, oC, oH, oW};
+
+        xzFrmat = isNCHW ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
+    }
+    else {        // 3d
+
+        ops::ConvolutionUtils::getSizesAndIndexesConv3d(isNCHW, *input, *gradO, bS, iC, iD, iH, iW, oC, oD, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH);
+
+        strides   = { sD, sH, sW };
+        kernel    = { kD, kH, kW };
+        padding   = { pD, pH, pW };
+        padding_r = { (oD - 1) * sD - iD + kD - pD, (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW };
+        xDims     = {bS, iC, iD, iH, iW};
+        zDims     = {bS, oC, oD, oH, oW};
+
+        xzFrmat = isNCHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
+    }
+
+    // memory descriptors for arrays
+
+    // input
+    dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(xDims, type, xzFrmat);
+    dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
+    if(input->ews() != 1 || input->ordering() != 'c') {
+        x_user_md.data.format_kind = dnnl_blocked;    // overrides format
+        x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
+        x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(isNCHW ? 1 : -1);
+        x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(isNCHW ? 2 : 1);
+        x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(isNCHW ? 3 : 2);
+        if(rank == 5)
+            x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(isNCHW ? 4 : 3);
+    }
+
+    // gradO
+    dnnl::memory::desc gradO_mkl_md  = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
+    dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFrmat);
+    if(gradO->ews() != 1 || gradO->ordering() != 'c') {
+        gradO_user_md.data.format_kind = dnnl_blocked;    // overrides format
+        gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0);
+        gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(isNCHW ? 1 : -1);
+        gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(isNCHW ? 2 : 1);
+        gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(isNCHW ? 3 : 2);
+        if(rank == 5)
+            gradO_user_md.data.format_desc.blocking.strides[4] = gradO->strideAt(isNCHW ? 4 : 3);
+    }
+
+    // gradI
+    dnnl::memory::desc gradI_mkl_md  = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
+    dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
+    if(gradI->ews() != 1 || gradI->ordering() != 'c') {
+        gradI_user_md.data.format_kind = dnnl_blocked;    // overrides format
+        gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0);
+        gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(isNCHW ? 1 : -1);
+        gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(isNCHW ? 2 : 1);
+        gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(isNCHW ? 3 : 2);
+        if(rank == 5)
+            gradI_user_md.data.format_desc.blocking.strides[4] = gradI->strideAt(isNCHW ? 4 : 3);
+    }
+
+    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
+    dnnl::stream stream(engine);
+
+    // forward primitive description
+    dnnl::pooling_forward::desc op_ff_desc(dnnl::prop_kind::forward, mode, x_mkl_md, gradO_mkl_md, strides, kernel, padding, padding_r);
+    dnnl::pooling_forward::primitive_desc op_ff_prim_desc(op_ff_desc, engine);
+
+    // backward primitive description
+    dnnl::pooling_backward::desc op_bp_desc(mode, gradI_mkl_md, gradO_mkl_md, strides, kernel, padding, padding_r);
+    dnnl::pooling_backward::primitive_desc op_bp_prim_desc(op_bp_desc, engine, op_ff_prim_desc);
+
+    // arguments (memory buffers) necessary for calculations
+    std::unordered_map<int, dnnl::memory> args;
+
+    // gradO
+    auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, gradO->getBuffer());
+    const bool gradOReorder = op_bp_prim_desc.diff_dst_desc() != gradO_user_mem.get_desc();
+    auto gradO_mkl_mem = gradOReorder ? dnnl::memory(op_bp_prim_desc.diff_dst_desc(), engine) : gradO_user_mem;
+    if (gradOReorder)
+        dnnl::reorder(gradO_user_mem, gradO_mkl_mem).execute(stream, gradO_user_mem, gradO_mkl_mem);
+    args[DNNL_ARG_DIFF_DST] = gradO_mkl_mem;
+
+    // gradI
+    auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->getBuffer());
+    const bool gradIReorder = op_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
+    auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
+    args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem;
+
+    if(mode == algorithm::pooling_max) {
+
+        // input
+        auto x_user_mem = dnnl::memory(x_user_md, engine, input->getBuffer());
+        const bool xReorder = op_ff_prim_desc.src_desc() != x_user_mem.get_desc();
+        auto x_mkl_mem = xReorder ? dnnl::memory(op_ff_prim_desc.src_desc(), engine) : x_user_mem;
+        if (xReorder)
+            dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
+        args[DNNL_ARG_SRC] = x_mkl_mem;
+
+        // z
+        auto z_mkl_mem = dnnl::memory(op_ff_prim_desc.dst_desc(), engine);
+        args[DNNL_ARG_DST] = z_mkl_mem;
+
+        // auxiliary memory allocation
+        auto workspace = dnnl::memory(op_ff_prim_desc.workspace_desc(), engine);
+        args[DNNL_ARG_WORKSPACE] = workspace;
+
+        // run forward calculations
+        dnnl::pooling_forward(op_ff_prim_desc).execute(stream, args);
+    }
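The extra forward pass above exists only to produce the workspace: for max pooling, backward needs the argmax position of every window, and oneDNN records those under DNNL_ARG_WORKSPACE during forward. A scalar illustration of why (ours, not library code):

    #include <cstdio>

    int main() {
        const float in[4] = { 1.f, 4.f, 3.f, 2.f };    // one pooling window
        int argmax = 0;
        for (int i = 1; i < 4; ++i)                    // forward: remember the winner
            if (in[i] > in[argmax]) argmax = i;
        float gradIn[4] = { 0.f, 0.f, 0.f, 0.f };
        const float gradOut = 1.5f;
        gradIn[argmax] = gradOut;                      // backward: route grad via stored index
        std::printf("gradIn = [%g, %g, %g, %g]\n", gradIn[0], gradIn[1], gradIn[2], gradIn[3]);
        return 0;
    }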
+
+    // run backward calculations
+    dnnl::pooling_backward(op_bp_prim_desc).execute(stream, args);
+
+    // reorder gradI if necessary
+    if (gradIReorder)
+        dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
+
+    stream.wait();
+}
+
+//////////////////////////////////////////////////////////////////////////
+void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
+                            dnnl::memory::desc* lrn_src_md, dnnl::memory::desc* lrn_diff_src_md, dnnl::memory::desc* lrn_dst_md,
+                            dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) {
+    const Nd4jLong* shape = src->getShapeInfo();
+    long rank = shape[0];
+    long dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one
+    long dim2 = axis >= 2 ? 1 : 2;
+    long dim3 = axis >= 3 ? 2 : 3;
+    dnnl::memory::dims lrn_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};
+
+    auto type = dnnl::memory::data_type::f32;
+    auto format = axis == 1 ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
+    auto supposed_to_be_any_format = format; // doesn't work with "any"
+
+    if (src != nullptr && src->getBuffer() != nullptr && lrn_src_md != nullptr) {
+        *lrn_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
+        *user_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
+        user_src_md->data.format_kind = dnnl_blocked;
+        user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0];
+        user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1];
+        user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1;
+        user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1;
+    }
+
+    if (diff_src != nullptr && diff_src->getBuffer() != nullptr && lrn_diff_src_md != nullptr) {
+        *lrn_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
+        *user_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
+        user_diff_src_md->data.format_kind = dnnl_blocked;
+        user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0];
+        user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1];
+        user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1;
+        user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1;
+    }
+
+    if (dst != nullptr && dst->getBuffer() != nullptr && lrn_dst_md != nullptr) {
+        *lrn_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
+        *user_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
+        user_dst_md->data.format_kind = dnnl_blocked;
+        user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0];
+        user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1];
+        user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1;
+        user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1;
+    }
+}
+
+//////////////////////////////////////////////////////////////////////////
+dnnl::engine& getEngine(void *ptr) {
+    auto eng = reinterpret_cast<dnnl::engine*>(ptr);
+    return *eng;
+}
+
+/*
 //////////////////////////////////////////////////////////////////////////
 void getMKLDNNMemoryDescPool2d(
         int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, int poolingMode, int extraParam0, bool isNCHW,
|
@ -307,104 +617,51 @@ void getMKLDNNMemoryDescConv3d(
|
||||||
}
|
}
|
||||||
};
|
};
|

-void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
+// void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
-                                  dnnl::memory::desc* batchnorm_src_md, dnnl::memory::desc* batchnorm_diff_src_md, dnnl::memory::desc* batchnorm_dst_md,
+//                                  dnnl::memory::desc* batchnorm_src_md, dnnl::memory::desc* batchnorm_diff_src_md, dnnl::memory::desc* batchnorm_dst_md,
+//                                  dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) {
+//     const Nd4jLong* shape = src->getShapeInfo();
+//     Nd4jLong rank = shape[0];
+//     Nd4jLong dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one
+//     Nd4jLong dim2 = axis >= 2 ? 1 : 2;
+//     Nd4jLong dim3 = axis >= 3 ? 2 : 3;
+//     dnnl::memory::dims batchnorm_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};
+
+//     auto type = dnnl::memory::data_type::f32;
+//     auto format = dnnl::memory::format_tag::nchw;
+//     auto supposed_to_be_any_format = dnnl::memory::format_tag::nChw8c; // doesn't work with "any"
+
+//     if (src != nullptr && src->getBuffer() != nullptr && batchnorm_src_md != nullptr) {
+//         *batchnorm_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
+//         *user_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
+//         user_src_md->data.format_kind = dnnl_blocked; // overrides format
+//         user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0];
+//         user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1];
+//         user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1;
+//         user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1;
+//     }
+
+//     if (diff_src != nullptr && diff_src->getBuffer() != nullptr && batchnorm_diff_src_md != nullptr) {
+//         *batchnorm_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
+//         *user_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
+//         user_diff_src_md->data.format_kind = dnnl_blocked; // overrides format
+//         user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0];
+//         user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1];
+//         user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1;
+//         user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1;
+//     }
+
+//     if (dst != nullptr && dst->getBuffer() != nullptr && batchnorm_dst_md != nullptr) {
+//         *batchnorm_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
+//         *user_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
+//         user_dst_md->data.format_kind = dnnl_blocked; // overrides format
+//         user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0];
+//         user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1];
+//         user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1;
+//         user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1;
+//     }
+// };
-//////////////////////////////////////////////////////////////////////////
-void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
-                            dnnl::memory::desc* lrn_src_md, dnnl::memory::desc* lrn_diff_src_md, dnnl::memory::desc* lrn_dst_md,
                             dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis) {
     const Nd4jLong* shape = src->getShapeInfo();
-    long rank = shape[0];
+    Nd4jLong rank = shape[0];
-    long dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one
+    Nd4jLong dim1 = axis; // MKL-DNN supports only 1 axis, which has to be the "channel" one
-    long dim2 = axis >= 2 ? 1 : 2;
+    Nd4jLong dim2 = axis >= 2 ? 1 : 2;
-    long dim3 = axis >= 3 ? 2 : 3;
+    Nd4jLong dim3 = axis >= 3 ? 2 : 3;
-    dnnl::memory::dims lrn_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};
+    dnnl::memory::dims batchnorm_src_tz = { (int)shape[1], (int)shape[dim1 + 1], rank > 2 ? (int)shape[dim2 + 1] : 1, rank > 3 ? (int)shape[dim3 + 1] : 1};

     auto type = dnnl::memory::data_type::f32;
-    auto format = axis == 1 ? dnnl::memory::format_tag::nchw : dnnl::memory::format_tag::nhwc;
+    auto format = dnnl::memory::format_tag::nchw;
-    auto supposed_to_be_any_format = format; // doesn't work with "any"
+    auto supposed_to_be_any_format = dnnl::memory::format_tag::nChw8c; // doesn't work with "any"

-    if (src != nullptr && src->getBuffer() != nullptr && lrn_src_md != nullptr) {
+    if (src != nullptr && src->getBuffer() != nullptr && batchnorm_src_md != nullptr) {
-        *lrn_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
+        *batchnorm_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
-        *user_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
+        *user_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
-        user_src_md->data.format_kind = dnnl_blocked;
+        user_src_md->data.format_kind = dnnl_blocked; // overrides format
         user_src_md->data.format_desc.blocking.strides[0] = src->stridesOf()[0];
         user_src_md->data.format_desc.blocking.strides[1] = src->stridesOf()[dim1];
         user_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? src->stridesOf()[dim2] : 1;
         user_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? src->stridesOf()[dim3] : 1;
     }

-    if (diff_src != nullptr && diff_src->getBuffer() != nullptr && lrn_diff_src_md != nullptr) {
+    if (diff_src != nullptr && diff_src->getBuffer() != nullptr && batchnorm_diff_src_md != nullptr) {
-        *lrn_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
+        *batchnorm_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
-        *user_diff_src_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
+        *user_diff_src_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
-        user_diff_src_md->data.format_kind = dnnl_blocked;
+        user_diff_src_md->data.format_kind = dnnl_blocked; // overrides format
         user_diff_src_md->data.format_desc.blocking.strides[0] = diff_src->stridesOf()[0];
         user_diff_src_md->data.format_desc.blocking.strides[1] = diff_src->stridesOf()[dim1];
         user_diff_src_md->data.format_desc.blocking.strides[2] = rank > 2 ? diff_src->stridesOf()[dim2] : 1;
         user_diff_src_md->data.format_desc.blocking.strides[3] = rank > 3 ? diff_src->stridesOf()[dim3] : 1;
     }

-    if (dst != nullptr && dst->getBuffer() != nullptr && lrn_dst_md != nullptr) {
+    if (dst != nullptr && dst->getBuffer() != nullptr && batchnorm_dst_md != nullptr) {
-        *lrn_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, supposed_to_be_any_format);
+        *batchnorm_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, supposed_to_be_any_format);
-        *user_dst_md = dnnl::memory::desc({ lrn_src_tz }, type, format);
+        *user_dst_md = dnnl::memory::desc({ batchnorm_src_tz }, type, format);
-        user_dst_md->data.format_kind = dnnl_blocked;
+        user_dst_md->data.format_kind = dnnl_blocked; // overrides format
         user_dst_md->data.format_desc.blocking.strides[0] = dst->stridesOf()[0];
         user_dst_md->data.format_desc.blocking.strides[1] = dst->stridesOf()[dim1];
         user_dst_md->data.format_desc.blocking.strides[2] = rank > 2 ? dst->stridesOf()[dim2] : 1;
         user_dst_md->data.format_desc.blocking.strides[3] = rank > 3 ? dst->stridesOf()[dim3] : 1;
     }
-}
+};
+*/

-//////////////////////////////////////////////////////////////////////////
-dnnl::engine& getEngine(void *ptr) {
-    auto eng = reinterpret_cast<dnnl::engine*>(ptr);
-    return *eng;
-}

 }
 }
@@ -16,6 +16,7 @@

 //
 // @author saudet
+// @author Yurii Shyrma (iuriish@yahoo.com)
 //

 #ifndef DEV_TESTS_MKLDNNUTILS_H
@@ -88,10 +89,20 @@ namespace nd4j{

 namespace mkldnnUtils {

+    void poolingMKLDNN(const NDArray *input, NDArray *output, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int isNCHW, const dnnl::algorithm mode);
+
+    void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI, const int kD, const int kH, const int kW, const int sD, const int sH, const int sW, const int pD, const int pH, const int pW, const int isNCHW, const dnnl::algorithm mode);
+
+    void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
+                                dnnl::memory::desc* lrn_src_md, dnnl::memory::desc* lrn_diff_src_md, dnnl::memory::desc* lrn_dst_md,
+                                dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis);
+
+    dnnl::engine& getEngine(void *ptr);
+
 /**
  * Utility methods for MKLDNN
  */
-    void getMKLDNNMemoryDescConv2d(
+/*  void getMKLDNNMemoryDescConv2d(
         int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, const int paddingMode, bool isNCHW,
         int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src,
         const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,

@@ -130,12 +141,7 @@ namespace nd4j{
     void getMKLDNNMemoryDescBatchNorm(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
                                       dnnl::memory::desc* batchnorm_src_md, dnnl::memory::desc* batchnorm_diff_src_md, dnnl::memory::desc* batchnorm_dst_md,
                                       dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis);
+*/
-    void getMKLDNNMemoryDescLrn(const NDArray* src, const NDArray* diff_src, const NDArray* dst,
-                                dnnl::memory::desc* lrn_src_md, dnnl::memory::desc* lrn_diff_src_md, dnnl::memory::desc* lrn_dst_md,
-                                dnnl::memory::desc* user_src_md, dnnl::memory::desc* user_diff_src_md, dnnl::memory::desc* user_dst_md, int axis);
-
-    dnnl::engine& getEngine(void *ptr);
 }
 }
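Editor's note: the new `poolingMKLDNN`/`poolingBpMKLDNN` declarations carry one parameter set for both 2D and 3D pooling (depth arguments `kD`, `sD`, `pD` travel alongside the spatial ones). The definitions are not part of this hunk, so treat the following as an assumption for illustration only: a 2D call can be expressed in the 3D parameter set with a unit depth kernel and stride. A minimal sketch of that flattening (the `PoolParams` struct and `make2d` helper are hypothetical names, not the library's API):

#include <cstdio>

// Hypothetical parameter pack mirroring the declaration above; only the
// 2D-to-3D flattening trick is illustrated, not the real MKL-DNN call.
struct PoolParams { int kD, kH, kW, sD, sH, sW, pD, pH, pW; };

// A 2D pooling request in the 3D parameter set: depth kernel/stride of 1
// and zero depth padding leave the depth axis untouched.
PoolParams make2d(int kH, int kW, int sH, int sW, int pH, int pW) {
    return PoolParams{1, kH, kW, 1, sH, sW, 0, pH, pW};
}

int main() {
    PoolParams p = make2d(5, 5, 1, 1, 0, 0);
    std::printf("kernel = %dx%dx%d\n", p.kD, p.kH, p.kW); // 1x5x5
    return 0;
}

Presumably this is why one helper can serve avgpool2d, maxpool2d and their 3D counterparts, switching only on the `mode` argument.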
@@ -2031,121 +2031,6 @@ TEST_F(DeclarableOpsTests1, Sum1) {
 }
 */

-//////////////////////////////////////////////////////////////////////
-TEST_F(DeclarableOpsTests1, Avgpool2d_test1) {
-
-    auto x = NDArrayFactory::create_<float>('c', {bS,iD,iH,iW});
-    auto exp = NDArrayFactory::create<float>('c',{bS,iD,oH,oW});
-    // auto z('c',{bS,iD,oH,oW});
-
-    auto variableSpace = new VariableSpace();
-    variableSpace->putVariable(-1, x);
-    // variableSpace->putVariable(1, &z);
-
-    auto block = new Context(1, variableSpace, false);
-    block->fillInputs({-1});
-    std::vector<int>* argI = block->getIArguments();
-    *argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
-
-    nd4j::ops::avgpool2d pooling;
-    Nd4jStatus status = pooling.execute(block);
-    ASSERT_EQ(ND4J_STATUS_OK, status);
-
-    auto result = variableSpace->getVariable(block->getNodeId())->getNDArray();
-    ASSERT_TRUE(exp.isSameShape(result));
-
-    delete variableSpace;
-    delete block;
-}
-
-//////////////////////////////////////////////////////////////////////
-TEST_F(DeclarableOpsTests1, Avgpool2d_test2) {
-    const int bS = 2;
-    const int iD = 1;
-    const int iH = 28;
-    const int iW = 28;
-    const int kH = 5;
-    const int kW = 5;
-    const int sH = 1;
-    const int sW = 1;
-    const int pH = 0;
-    const int pW = 0;
-    const int dH = 1;
-    const int dW = 1;
-    const int oH = (iH - kH - (kH-1)*(dH-1) + 2*pH)/sH + 1; // output height
-    const int oW = (iW - kW - (kW-1)*(dW-1) + 2*pW)/sW + 1; // output width
-
-    auto x = NDArrayFactory::create_<float>('c', {bS,iD,iH,iW});
-    auto exp = NDArrayFactory::create<float>('c',{bS,iD,oH,oW});
-    // auto z('c',{bS,iD,oH,oW});
-
-    auto variableSpace = new VariableSpace();
-    variableSpace->putVariable(-1, x);
-    // variableSpace->putVariable(1, &z);
-
-    auto block = new Context(1, variableSpace, false);
-    block->fillInputs({-1});
-    std::vector<int>* argI = block->getIArguments();
-    *argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
-
-    nd4j::ops::avgpool2d pooling;
-    Nd4jStatus status = pooling.execute(block);
-    ASSERT_EQ(ND4J_STATUS_OK, status);
-
-    auto result = variableSpace->getVariable(block->getNodeId())->getNDArray();
-    // result->printShapeInfo();
-    ASSERT_TRUE(exp.isSameShape(result));
-
-    delete variableSpace;
-    delete block;
-}
-
-//////////////////////////////////////////////////////////////////////
-TEST_F(DeclarableOpsTests1, Avgpool2d_test3) {
-    const int bS = 2;
-    const int iD = 1;
-    const int iH = 28;
-    const int iW = 28;
-    const int kH = 5;
-    const int kW = 5;
-    const int sH = 1;
-    const int sW = 1;
-    const int pH = 0;
-    const int pW = 0;
-    const int dH = 1;
-    const int dW = 1;
-    const int oH = (int) nd4j::math::nd4j_ceil<float, int>(iH * 1.f / sH);
-    const int oW = (int) nd4j::math::nd4j_ceil<float, int>(iW * 1.f / sW);
-
-    auto x = NDArrayFactory::create_<float>('c', {bS,iD,iH,iW});
-    auto exp = NDArrayFactory::create<float>('c',{bS,iD,oH,oW});
-    // auto z('c',{bS,iD,oH,oW});
-
-    auto variableSpace = new VariableSpace();
-    variableSpace->putVariable(-1, x);
-    // variableSpace->putVariable(1, &z);
-
-    auto block = new Context(1, variableSpace, false);
-    block->fillInputs({-1});
-    std::vector<int>* argI = block->getIArguments();
-    *argI = {kH,kW, sH,sW, pH,pW, dW,dH, 1, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
-
-    nd4j::ops::avgpool2d pooling;
-    Nd4jStatus status = pooling.execute(block);
-    ASSERT_EQ(ND4J_STATUS_OK, status);
-
-    auto result = variableSpace->getVariable(block->getNodeId())->getNDArray();
-    // result->printShapeInfo();
-    ASSERT_TRUE(exp.isSameShape(result));
-
-    delete variableSpace;
-    delete block;
-}

 //////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests1, Pnormpool2d1) {
@@ -1667,6 +1667,241 @@ TEST_F(DeclarableOpsTests11, Solve_Test_4) {
     ASSERT_TRUE(exp.equalsTo(z));
     delete res;
 }

+////////////////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests11, Solve_Test_4_1) {
+
+    auto a = NDArrayFactory::create<float>('c', {2, 2, 2}, {
+        0.7788f, 0.8012f, 0.7244f, 0.2309f,
+        0.7271f, 0.1804f, 0.5056f, 0.8925f
+    });
+
+    auto b = NDArrayFactory::create<float>('c', {2, 2, 2}, {
+        0.7717f, 0.9281f, 0.9846f, 0.4838f, 0.6433f, 0.6041f, 0.6501f, 0.7612f
+    });
+
+    auto exp = NDArrayFactory::create<float>('c', {2, 2, 2}, {
+        1.3357621f, 0.3399364f, -0.37077796f, 0.91573375f,
+        0.4400987f, 0.2766527f, 0.6394467f, 0.79696566f
+    });
+
+    nd4j::ops::solve op;
+
+    auto res = op.evaluate({&a, &b}, {true});
+    ASSERT_EQ(res->status(), ND4J_STATUS_OK);
+    auto z = res->at(0);
+
+    // z->printBuffer("4 Solve 4x4");
+    // exp.printBuffer("4 Expec 4x4");
+
+    ASSERT_TRUE(exp.equalsTo(z));
+    delete res;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests11, Solve_Test_4_2) {
+
+    auto a = NDArrayFactory::create<float>('c', {3, 3}, {
+        0.7788f, 0.8012f, 0.7244f,
+        0.2309f, 0.7271f, 0.1804f,
+        0.5056f, 0.8925f, 0.5461f
+    });
+
+    auto b = NDArrayFactory::create<float>('c', {3, 3}, {
+        0.7717f, 0.9281f, 0.9846f,
+        0.4838f, 0.6433f, 0.6041f,
+        0.6501f, 0.7612f, 0.7605f
+    });
+
+    auto exp = NDArrayFactory::create<float>('c', {3, 3}, {
+        0.99088347f, 1.1917052f, 1.2642528f,
+        0.35071516f, 0.50630623f, 0.42935497f,
+        -0.30013534f, -0.53690606f, -0.47959247f
+    });
+
+    nd4j::ops::triangular_solve op;
+
+    auto res = op.evaluate({&a, &b}, {true, false});
+    ASSERT_EQ(res->status(), ND4J_STATUS_OK);
+    auto z = res->at(0);
+
+    // z->printBuffer("4_2 Triangular_Solve 3x3");
+    // exp.printBuffer("4_2 Triangular_Expec 3x3");
+
+    ASSERT_TRUE(exp.equalsTo(z));
+    delete res;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests11, Solve_Test_4_3) {
+
+    auto a = NDArrayFactory::create<float>('c', {3, 3}, {
+        0.7788f, 0.8012f, 0.7244f,
+        0.2309f, 0.7271f, 0.1804f,
+        0.5056f, 0.8925f, 0.5461f
+    });
+
+    auto b = NDArrayFactory::create<float>('c', {3, 3}, {
+        0.7717f, 0.9281f, 0.9846f,
+        0.4838f, 0.6433f, 0.6041f,
+        0.6501f, 0.7612f, 0.7605f
+    });
+
+    auto exp = NDArrayFactory::create<float>('c', {3, 3}, {
+        0.45400196f, 0.53174824f, 0.62064564f,
+        -0.79585856f, -0.82621557f, -0.87855506f,
+        1.1904413f, 1.3938838f, 1.3926021f
+    });
+
+    nd4j::ops::triangular_solve op;
+
+    auto res = op.evaluate({&a, &b}, {true, true});
+    ASSERT_EQ(res->status(), ND4J_STATUS_OK);
+    auto z = res->at(0);
+
+    // z->printBuffer("4_3 Triangular_Solve 3x3");
+    // exp.printBuffer("4_3 Triangular_Expec 3x3");
+
+    ASSERT_TRUE(exp.equalsTo(z));
+    delete res;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests11, Solve_Test_4_4) {
+
+    auto a = NDArrayFactory::create<float>('c', {3, 3}, {
+        0.7788f, 0.8012f, 0.7244f,
+        0.2309f, 0.7271f, 0.1804f,
+        0.5056f, 0.8925f, 0.5461f
+    });
+
+    auto b = NDArrayFactory::create<float>('c', {3, 3}, {
+        0.7717f, 0.9281f, 0.9846f,
+        0.4838f, 0.6433f, 0.6041f,
+        0.6501f, 0.7612f, 0.7605f
+    });
+
+    auto exp = NDArrayFactory::create<float>('c', {3, 3}, {
+        0.8959121f, 1.6109066f, 1.7501404f,
+        0.49000582f, 0.66842675f, 0.5577021f,
+        -0.4398522f, -1.1899745f, -1.1392052f
+    });
+
+    nd4j::ops::solve op;
+
+    auto res = op.evaluate({&a, &b}, {false});
+    ASSERT_EQ(res->status(), ND4J_STATUS_OK);
+    auto z = res->at(0);
+
+    // z->printBuffer("4_4 Solve 3x3");
+    // exp.printBuffer("4_4 Expec 3x3");
+
+    ASSERT_TRUE(exp.equalsTo(z));
+    delete res;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests11, Solve_Test_4_5) {
+
+    auto a = NDArrayFactory::create<float>('c', {3, 3}, {
+        0.7788f, 0.8012f, 0.7244f,
+        0.2309f, 0.7271f, 0.1804f,
+        0.5056f, 0.8925f, 0.5461f
+    });
+
+    auto b = NDArrayFactory::create<float>('c', {3, 3}, {
+        0.7717f, 0.9281f, 0.9846f,
+        0.4838f, 0.6433f, 0.6041f,
+        0.6501f, 0.7612f, 0.7605f
+    });
+
+    auto exp = NDArrayFactory::create<float>('c', {3, 3}, {
+        1.5504692f, 1.8953944f, 2.2765768f,
+        0.03399149f, 0.2883001f, 0.5377323f,
+        -0.8774802f, -1.2155888f, -1.8049058f
+    });
+
+    nd4j::ops::solve op;
+
+    auto res = op.evaluate({&a, &b}, {true, true});
+    ASSERT_EQ(res->status(), ND4J_STATUS_OK);
+    auto z = res->at(0);
+
+    // z->printBuffer("4_5 Solve 3x3");
+    // exp.printBuffer("4_5 Expec 3x3");
+
+    ASSERT_TRUE(exp.equalsTo(z));
+    delete res;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests11, Solve_Test_4_6) {
+
+    auto a = NDArrayFactory::create<float>('c', {3, 3}, {
+        0.7788f, 0.8012f, 0.7244f,
+        0.2309f, 0.7271f, 0.1804f,
+        0.5056f, 0.8925f, 0.5461f
+    });
+
+    auto b = NDArrayFactory::create<float>('c', {3, 3}, {
+        0.7717f, 0.9281f, 0.9846f,
+        0.4838f, 0.6433f, 0.6041f,
+        0.6501f, 0.7612f, 0.7605f
+    });
+
+    auto exp = NDArrayFactory::create<float>('c', {3, 3}, {
+        0.99088347f, 1.1917052f, 1.2642528f,
+        -0.426483f, -0.42840624f, -0.5622601f,
+        0.01692283f, -0.04538865f, -0.09868701f
+    });
+
+    nd4j::ops::triangular_solve op;
+
+    auto res = op.evaluate({&a, &b}, {false, true});
+    ASSERT_EQ(res->status(), ND4J_STATUS_OK);
+    auto z = res->at(0);
+
+    z->printBuffer("4_6 Solve 3x3");
+    exp.printBuffer("4_6 Expec 3x3");
+
+    ASSERT_TRUE(exp.equalsTo(z));
+    delete res;
+}
+
+////////////////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests11, Solve_Test_4_7) {
+
+    auto a = NDArrayFactory::create<float>('c', {3, 3}, {
+        // 0.7788f, 0.2309f, 0.5056f,
+        // 0.8012f, 0.7271f, 0.8925f,
+        // 0.7244f, 0.1804f, 0.5461f
+
+        0.7788f, 0.2309f, 0.5056f,
+        0.8012f, 0.7271f, 0.8925f,
+        0.7244f, 0.1804f, 0.5461f
+    });
+
+    auto b = NDArrayFactory::create<float>('c', {3, 3}, {
+        0.7717f, 0.9281f, 0.9846f,
+        0.4838f, 0.6433f, 0.6041f,
+        0.6501f, 0.7612f, 0.7605f
+    });
+
+    auto exp = NDArrayFactory::create<float>('c', {3, 3}, {
+        0.99088347f, 1.1917052f, 1.2642528f,
+        -0.426483f, -0.42840624f, -0.5622601f,
+        0.01692283f, -0.04538865f, -0.09868701f
+    });
+
+    nd4j::ops::triangular_solve op;
+
+    auto res = op.evaluate({&a, &b}, {true, false});
+    ASSERT_EQ(res->status(), ND4J_STATUS_OK);
+    auto z = res->at(0);
+
+    z->printBuffer("4_7 Solve 3x3");
+    exp.printBuffer("4_7 Expec 3x3");
+
+    ASSERT_TRUE(exp.equalsTo(z));
+    delete res;
+}

 ////////////////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests11, Solve_Test_5) {
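Editor's note: the seven added cases drive `solve` and `triangular_solve` through the combinations of their boolean arguments. The diff never names those flags; reading the calls, the first appears to select the lower-triangular case and the second an adjoint (transposed) solve, but treat that mapping as an assumption. For reference, a self-contained forward-substitution sketch of the lower-triangular case these tests exercise (illustration only, not the library's kernel):

#include <array>
#include <cassert>
#include <cmath>

// Solve L * x = b for a lower-triangular 3x3 L by forward substitution,
// the math behind the lower=true path of a triangular solve.
std::array<double, 3> forwardSubstitute(const double L[3][3], const std::array<double, 3>& b) {
    std::array<double, 3> x{};
    for (int i = 0; i < 3; ++i) {
        double s = b[i];
        for (int j = 0; j < i; ++j)
            s -= L[i][j] * x[j];   // subtract already-solved components
        x[i] = s / L[i][i];        // assumes a non-singular L
    }
    return x;
}

int main() {
    const double L[3][3] = {{2, 0, 0}, {1, 3, 0}, {4, 5, 6}};
    auto x = forwardSubstitute(L, {2, 5, 32});
    assert(std::fabs(x[0] - 1.0) < 1e-12);  // first component is b[0]/L[0][0]
    return 0;
}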
@@ -360,7 +360,6 @@ TEST_F(DeclarableOpsTests4, avgpool2d_12) {
     917.5, 918.5, 919.5, 925. , 926. , 927. , 934. , 935. , 936. , 941.5, 942.5, 943.5, 992.5, 993.5, 994.5,1000. , 1001. , 1002. ,1009. , 1010. , 1011. ,1016.5, 1017.5, 1018.5,
     1082.5, 1083.5, 1084.5,1090. , 1091. , 1092. ,1099. , 1100. , 1101. ,1106.5, 1107.5, 1108.5,1157.5, 1158.5, 1159.5,1165. , 1166. , 1167. ,1174. , 1175. , 1176. ,1181.5, 1182.5, 1183.5});

     input.linspace(1.);
-    input.syncToDevice();

     nd4j::ops::avgpool2d op;
     auto results = op.evaluate({&input}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 0, dataFormat});
@@ -377,6 +376,160 @@ TEST_F(DeclarableOpsTests4, avgpool2d_12) {
     delete results;
 }

+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests4, avgpool2d_13) {
+
+    const int bS = 2;  // batch size
+    const int iD = 1;  // input depth (number of picture channels, for example rgb=3)
+    const int iH = 28; // picture height in pixels
+    const int iW = 28; // picture width in pixels
+    const int kH = 5;  // kernel height in pixels
+    const int kW = 5;  // kernel width in pixels
+    const int sH = 1;  // stride step in horizontal direction
+    const int sW = 1;  // stride step in vertical direction
+    const int pH = 0;  // padding height
+    const int pW = 0;  // padding width
+    const int dH = 2;  // dilation height
+    const int dW = 2;  // dilation width
+    const int oH = (iH - kH - (kH-1)*(dH-1) + 2*pH)/sH + 1; // output height
+    const int oW = (iW - kW - (kW-1)*(dW-1) + 2*pW)/sW + 1; // output width
+
+    auto x = NDArrayFactory::create_<float>('c', {bS,iD,iH,iW});
+    auto exp = NDArrayFactory::create<float>('c',{bS,iD,oH,oW});
+    // auto z('c',{bS,iD,oH,oW});
+
+    auto variableSpace = new VariableSpace();
+    variableSpace->putVariable(-1, x);
+    // variableSpace->putVariable(1, &z);
+
+    auto block = new Context(1, variableSpace, false);
+    block->fillInputs({-1});
+    std::vector<int>* argI = block->getIArguments();
+    *argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
+
+    nd4j::ops::avgpool2d pooling;
+    Nd4jStatus status = pooling.execute(block);
+    ASSERT_EQ(ND4J_STATUS_OK, status);
+
+    auto result = variableSpace->getVariable(block->getNodeId())->getNDArray();
+    ASSERT_TRUE(exp.isSameShape(result));
+
+    delete variableSpace;
+    delete block;
+}
+
+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests4, avgpool2d_14) {
+    const int bS = 2;
+    const int iD = 1;
+    const int iH = 28;
+    const int iW = 28;
+    const int kH = 5;
+    const int kW = 5;
+    const int sH = 1;
+    const int sW = 1;
+    const int pH = 0;
+    const int pW = 0;
+    const int dH = 1;
+    const int dW = 1;
+    const int oH = (iH - kH - (kH-1)*(dH-1) + 2*pH)/sH + 1; // output height
+    const int oW = (iW - kW - (kW-1)*(dW-1) + 2*pW)/sW + 1; // output width
+
+    auto x = NDArrayFactory::create_<float>('c', {bS,iD,iH,iW});
+    auto exp = NDArrayFactory::create<float>('c',{bS,iD,oH,oW});
+    // auto z('c',{bS,iD,oH,oW});
+
+    auto variableSpace = new VariableSpace();
+    variableSpace->putVariable(-1, x);
+    // variableSpace->putVariable(1, &z);
+
+    auto block = new Context(1, variableSpace, false);
+    block->fillInputs({-1});
+    std::vector<int>* argI = block->getIArguments();
+    *argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
+
+    nd4j::ops::avgpool2d pooling;
+    Nd4jStatus status = pooling.execute(block);
+    ASSERT_EQ(ND4J_STATUS_OK, status);
+
+    auto result = variableSpace->getVariable(block->getNodeId())->getNDArray();
+    // result->printShapeInfo();
+    ASSERT_TRUE(exp.isSameShape(result));
+
+    delete variableSpace;
+    delete block;
+}
+
+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests4, Avgpool2d_test15) {
+    const int bS = 2;
+    const int iD = 1;
+    const int iH = 28;
+    const int iW = 28;
+    const int kH = 5;
+    const int kW = 5;
+    const int sH = 1;
+    const int sW = 1;
+    const int pH = 0;
+    const int pW = 0;
+    const int dH = 1;
+    const int dW = 1;
+    const int oH = (int) nd4j::math::nd4j_ceil<float, int>(iH * 1.f / sH);
+    const int oW = (int) nd4j::math::nd4j_ceil<float, int>(iW * 1.f / sW);
+
+    auto x = NDArrayFactory::create_<float>('c', {bS,iD,iH,iW});
+    auto exp = NDArrayFactory::create<float>('c',{bS,iD,oH,oW});
+    // auto z('c',{bS,iD,oH,oW});
+
+    auto variableSpace = new VariableSpace();
+    variableSpace->putVariable(-1, x);
+    // variableSpace->putVariable(1, &z);
+
+    auto block = new Context(1, variableSpace, false);
+    block->fillInputs({-1});
+    std::vector<int>* argI = block->getIArguments();
+    *argI = {kH,kW, sH,sW, pH,pW, dW,dH, 1, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode;
+
+    nd4j::ops::avgpool2d pooling;
+    Nd4jStatus status = pooling.execute(block);
+    ASSERT_EQ(ND4J_STATUS_OK, status);
+
+    auto result = variableSpace->getVariable(block->getNodeId())->getNDArray();
+    // result->printShapeInfo();
+    ASSERT_TRUE(exp.isSameShape(result));
+
+    delete variableSpace;
+    delete block;
+}
+
+//////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests4, avgpool2d_16) {
+
+    int bS=2, iH=4,iW=4, iC=2, kH=2,kW=2, sH=2,sW=2, pH=0,pW=0, dH=1,dW=1;
+    int oH=2,oW=2;
+    int paddingMode = 1; // 1-SAME, 0-VALID
+    int dataFormat = 1;  // 1-NHWC, 0-NDHW
+
+    NDArray input('c', {bS, iH, iW, iC}, nd4j::DataType::FLOAT32);
+    NDArray output('f', {bS, oH, oW, iC}, nd4j::DataType::FLOAT32);
+    NDArray expected('c', {bS, oH, oW, iC}, {6.f, 7.f, 10.f, 11.f, 22.f, 23.f, 26.f, 27.f, 38.f, 39.f, 42.f, 43.f, 54.f, 55.f, 58.f, 59.f}, nd4j::DataType::FLOAT32);
+
+    input.linspace(1.);
+
+    nd4j::ops::avgpool2d op;
+    auto status = op.execute({&input}, {&output}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 0, dataFormat}, {});
+
+    ASSERT_EQ(Status::OK(), status);
+
+    // output.printBuffer();
+    //expected.printIndexedBuffer("expected");
+
+    ASSERT_TRUE(expected.equalsTo(output));
+}

 //////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests4, biasadd_1) {
     auto x = NDArrayFactory::create<double>('c', {2, 3, 3, 2});
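Editor's note: the added cases pin down the two output-size conventions the op supports. VALID padding (same-mode flag 0) uses the dilation-aware formula from the tests, while SAME (flag 1) reduces to a ceiling division. A small self-contained check of both rules, using the exact formulas and dimensions above:

#include <cassert>

// VALID output extent, as written in the added tests:
// oH = (iH - kH - (kH-1)*(dH-1) + 2*pH)/sH + 1
int outValid(int i, int k, int s, int p, int d) {
    return (i - k - (k - 1) * (d - 1) + 2 * p) / s + 1;
}

// SAME output extent: oH = ceil(iH / sH), done in integer arithmetic.
int outSame(int i, int s) {
    return (i + s - 1) / s;
}

int main() {
    assert(outValid(28, 5, 1, 0, 2) == 20);  // avgpool2d_13: dilation 2
    assert(outValid(28, 5, 1, 0, 1) == 24);  // avgpool2d_14: dilation 1
    assert(outSame(28, 1) == 28);            // Avgpool2d_test15: SAME, stride 1
    return 0;
}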
@@ -802,6 +802,66 @@ TEST_F(DeclarableOpsTests7, Test_SequenceMask_2) {
     delete result;
 }

+TEST_F(DeclarableOpsTests7, Test_SequenceMask_3) {
+    auto input = NDArrayFactory::create<int>('c', {2, 2, 2}, {10, 20, 30, 4, 0, 6, 7, 8});
+    auto exp = NDArrayFactory::create<int>('c', {2, 2, 2, 30}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
+        1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+        1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+        1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
+
+    nd4j::ops::sequence_mask op;
+    auto result = op.evaluate({&input}, {nd4j::DataType::INT32});
+    ASSERT_EQ(Status::OK(), result->status());
+
+    auto z = result->at(0);
+    // z->printBuffer("Output");
+    // z->printShapeInfo("Shape");
+    ASSERT_TRUE(exp.isSameShape(z));
+    ASSERT_TRUE(exp.equalsTo(z));
+
+    delete result;
+}
+
+TEST_F(DeclarableOpsTests7, Test_SequenceMask_4) {
+    auto input = NDArrayFactory::create<int>({1, 3, 2});
+    auto maxLen = NDArrayFactory::create<int>(5);
+    auto exp = NDArrayFactory::create<float>('c', {3,5}, {
+        1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 1.f, 0.f, 0.f, 1.f, 1.f, 0.f, 0.f, 0.f
+    });
+
+    nd4j::ops::sequence_mask op;
+    auto result = op.evaluate({&input, &maxLen}, {nd4j::DataType::FLOAT32});
+    ASSERT_EQ(Status::OK(), result->status());
+
+    auto z = result->at(0);
+    // z->printBuffer("Output");
+    // z->printShapeInfo("Shape");
+    ASSERT_TRUE(exp.isSameShape(z));
+    ASSERT_TRUE(exp.equalsTo(z));
+
+    delete result;
+}
+
+TEST_F(DeclarableOpsTests7, Test_SequenceMask_5) {
+    auto input = NDArrayFactory::create<int>({1, 3, 2});
+    auto exp = NDArrayFactory::create<float>('c', {3,5}, {
+        1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 1.f, 0.f, 0.f, 1.f, 1.f, 0.f, 0.f, 0.f
+    });
+
+    nd4j::ops::sequence_mask op;
+    auto result = op.evaluate({&input}, {5, (int)nd4j::DataType::FLOAT32});
+    ASSERT_EQ(Status::OK(), result->status());
+
+    auto z = result->at(0);
+    // z->printBuffer("Output");
+    // z->printShapeInfo("Shape");
+    ASSERT_TRUE(exp.isSameShape(z));
+    ASSERT_TRUE(exp.equalsTo(z));
+
+    delete result;
+}

 ////////////////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests7, TestSegmentMax_1) {
     auto x = NDArrayFactory::create<double>({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.});
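Editor's note: all three new cases agree on the op's semantics: for each length l in the input, `sequence_mask` emits a row of maxLen values with the first l set to one, with maxLen taken from a second input, a scalar argument, or the largest length. A self-contained sketch of that rule (the real kernel additionally handles the dtype and shape plumbing shown above):

#include <cassert>
#include <vector>

// mask[i][j] = (j < len[i]) -- the rule the added tests encode.
std::vector<std::vector<int>> sequenceMask(const std::vector<int>& len, int maxLen) {
    std::vector<std::vector<int>> mask(len.size(), std::vector<int>(maxLen, 0));
    for (size_t i = 0; i < len.size(); ++i)
        for (int j = 0; j < maxLen && j < len[i]; ++j)
            mask[i][j] = 1;
    return mask;
}

int main() {
    auto m = sequenceMask({1, 3, 2}, 5);  // same data as Test_SequenceMask_4/5
    assert(m[0] == (std::vector<int>{1, 0, 0, 0, 0}));
    assert(m[1] == (std::vector<int>{1, 1, 1, 0, 0}));
    assert(m[2] == (std::vector<int>{1, 1, 0, 0, 0}));
    return 0;
}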
@@ -422,50 +422,38 @@ TEST_F(PlaygroundTests, my) {
     delete variableSpace;
 }

-#include<ops/declarable/helpers/batchnorm.h>

 TEST_F(PlaygroundTests, my) {

-    const int N = 10000;
+    int N = 100;
-    const Nd4jLong dim0(128), dim1(128), dim2(128);
+    int bS=16, iH=128,iW=128, iC=32,oC=64, kH=4,kW=4, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
+    int oH=128,oW=128;

-    NDArray input('c', {dim0,dim1,dim2}, nd4j::DataType::DOUBLE);
+    int paddingMode = 1; // 1-SAME, 0-VALID;
-    NDArray mean('c', {dim1}, nd4j::DataType::DOUBLE);
+    int dataFormat = 1;  // 1-NHWC, 0-NCHW
-    NDArray variance('c', {dim1}, nd4j::DataType::DOUBLE);
-    NDArray gamma('c', {dim1}, nd4j::DataType::DOUBLE);
-    NDArray beta ('c', {dim1}, nd4j::DataType::DOUBLE);

-    NDArray output('c', {dim0,dim1,dim2}, nd4j::DataType::DOUBLE);
+    // NDArray input('c', {bS, iC, iH, iW}, nd4j::DataType::FLOAT32);
+    // NDArray output('c', {bS, oC, oH, oW}, nd4j::DataType::FLOAT32);
+    NDArray input('c', {bS, iH, iW, iC}, nd4j::DataType::FLOAT32);
+    NDArray output('c', {bS, oH, oW, oC}, nd4j::DataType::FLOAT32);
+    // NDArray weights('c', {kH, kW, iC, oC}, nd4j::DataType::FLOAT32); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW]
+    NDArray weights('c', {oC, iC, kH, kW}, nd4j::DataType::FLOAT32);
+    NDArray bias('c', {oC}, nd4j::DataType::FLOAT32);

-    input.linspace(-100, 0.1);
+    input = 5.;
-    mean.linspace(-50, 0.15);
+    weights = 3.;
-    variance.linspace(-5, 0.2);
+    bias = 1.;
-    gamma = 1.5;
-    beta = -2.5;

-    // warm up
+    nd4j::ops::conv2d op;
-    ops::helpers::batchnorm(&input, &mean, &variance, &gamma, &beta, &output, {1}, 1e-5);
+    auto err = op.execute({&input, &weights, &bias}, {&output}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});

     auto timeStart = std::chrono::system_clock::now();
     for (int i = 0; i < N; ++i)
-        ops::helpers::batchnorm(&input, &mean, &variance, &gamma, &beta, &output, {1}, 1e-5);
+        err = op.execute({&input, &weights, &bias}, {&output}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});

     auto timeEnd = std::chrono::system_clock::now();
     auto time = std::chrono::duration_cast<std::chrono::microseconds> ((timeEnd - timeStart) / N).count();

-    printf("time: %li \n", time);
+    printf("time: %i \n", time);

 }

 */
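Editor's note: both versions of this scratch benchmark share one harness: a warm-up call outside the measurement, then N timed iterations averaged through `duration_cast`. The pattern in isolation, runnable as-is (the `workload` stand-in is illustrative; note the test's own printf uses %i for the count, whereas %lld with a cast is the portable spelling):

#include <chrono>
#include <cstdio>

// Stand-in for the op under test; any callable works here.
static void workload() {
    volatile double sink = 0;
    for (int i = 0; i < 1000; ++i) sink += i * 0.5;
}

int main() {
    const int N = 100;
    workload();  // warm-up call, excluded from the measurement

    auto timeStart = std::chrono::system_clock::now();
    for (int i = 0; i < N; ++i)
        workload();
    auto timeEnd = std::chrono::system_clock::now();

    // Average per-iteration cost, same formula as the test above.
    auto time = std::chrono::duration_cast<std::chrono::microseconds>((timeEnd - timeStart) / N).count();
    std::printf("time: %lld us\n", static_cast<long long>(time));
    return 0;
}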
@@ -780,7 +780,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
     throw new IllegalArgumentException("Unable to create array of length " + length);
     float[] ret = new float[(int) length];
     for (int i = 0; i < length; i++)
-        ret[i] = getFloat(i);
+        ret[i] = getFloatUnsynced(i);
     return ret;
 }

@@ -790,7 +790,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
     throw new IllegalArgumentException("Unable to create array of length " + length);
     double[] ret = new double[(int) length];
     for (int i = 0; i < length; i++)
-        ret[i] = getDouble(i);
+        ret[i] = getDoubleUnsynced(i);
     return ret;
 }

@@ -800,7 +800,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
     throw new IllegalArgumentException("Unable to create array of length " + length);
     int[] ret = new int[(int) length];
     for (int i = 0; i < length; i++)
-        ret[i] = getInt(i);
+        ret[i] = getIntUnsynced(i);
     return ret;
 }

@@ -810,7 +810,7 @@ public abstract class BaseDataBuffer implements DataBuffer {
     throw new IllegalArgumentException("Unable to create array of length " + length);
     long[] ret = new long[(int) length];
     for (int i = 0; i < length; i++)
-        ret[i] = getLong(i);
+        ret[i] = getLongUnsynced(i);
     return ret;
 }

@@ -1662,6 +1662,11 @@ public abstract class BaseDataBuffer implements DataBuffer {
 }

+    protected abstract double getDoubleUnsynced(long index);
+    protected abstract float getFloatUnsynced(long index);
+    protected abstract long getLongUnsynced(long index);
+    protected abstract int getIntUnsynced(long index);
+
 @Override
 public void write(DataOutputStream out) throws IOException {
     out.writeUTF(allocationMode.name());

@@ -1670,43 +1675,43 @@ public abstract class BaseDataBuffer implements DataBuffer {
     switch (dataType()) {
         case DOUBLE:
             for (long i = 0; i < length(); i++)
-                out.writeDouble(getDouble(i));
+                out.writeDouble(getDoubleUnsynced(i));
             break;
         case UINT64:
         case LONG:
             for (long i = 0; i < length(); i++)
-                out.writeLong(getLong(i));
+                out.writeLong(getLongUnsynced(i));
             break;
         case UINT32:
        case INT:
             for (long i = 0; i < length(); i++)
-                out.writeInt(getInt(i));
+                out.writeInt(getIntUnsynced(i));
             break;
         case UINT16:
         case SHORT:
             for (long i = 0; i < length(); i++)
-                out.writeShort((short) getInt(i));
+                out.writeShort((short) getIntUnsynced(i));
             break;
         case UBYTE:
         case BYTE:
             for (long i = 0; i < length(); i++)
-                out.writeByte((byte) getInt(i));
+                out.writeByte((byte) getIntUnsynced(i));
             break;
         case BOOL:
             for (long i = 0; i < length(); i++)
-                out.writeByte(getInt(i) == 0 ? (byte) 0 : (byte) 1);
+                out.writeByte(getIntUnsynced(i) == 0 ? (byte) 0 : (byte) 1);
             break;
         case BFLOAT16:
             for (long i = 0; i < length(); i++)
-                out.writeShort((short) Bfloat16Indexer.fromFloat(getFloat(i)));
+                out.writeShort((short) Bfloat16Indexer.fromFloat(getFloatUnsynced(i)));
             break;
         case HALF:
             for (long i = 0; i < length(); i++)
-                out.writeShort((short) HalfIndexer.fromFloat(getFloat(i)));
+                out.writeShort((short) HalfIndexer.fromFloat(getFloatUnsynced(i)));
             break;
         case FLOAT:
             for (long i = 0; i < length(); i++)
-                out.writeFloat(getFloat(i));
+                out.writeFloat(getFloatUnsynced(i));
             break;
     }
 }
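Editor's note: the *Unsynced accessors let bulk readers such as the array getters and write() pay any host/device synchronization cost once per call instead of once per element; each subclass decides what "unsynced" means (CompressedDataBuffer simply delegates to the plain getters below). A language-neutral sketch of the idea, in C++ here although the diff's code is Java, with all names illustrative only:

#include <cstdio>
#include <vector>

class Buffer {
public:
    explicit Buffer(std::vector<float> d) : data(std::move(d)) {}

    // Public accessor: pessimistically synchronizes on every call.
    float getFloat(size_t i) { syncToHost(); return data[i]; }

    // Bulk read: one sync up front, then raw "unsynced" element reads.
    std::vector<float> asFloat() {
        syncToHost();
        std::vector<float> ret(data.size());
        for (size_t i = 0; i < data.size(); ++i)
            ret[i] = getFloatUnsynced(i);
        return ret;
    }

private:
    void syncToHost() { ++syncCount; }                    // stands in for a device copy
    float getFloatUnsynced(size_t i) { return data[i]; }  // deliberately no sync
    std::vector<float> data;
    int syncCount = 0;
};

int main() {
    Buffer b({1.f, 2.f, 3.f});
    auto v = b.asFloat();  // triggers exactly one sync, not three
    std::printf("%zu values\n", v.size());
    return 0;
}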
@@ -43,7 +43,7 @@ public class DeallocatorService {
     private Map<String, DeallocatableReference> referenceMap = new ConcurrentHashMap<>();
     private List<List<ReferenceQueue<Deallocatable>>> deviceMap = new ArrayList<>();

-    private AtomicLong counter = new AtomicLong(0);
+    private final transient AtomicLong counter = new AtomicLong(0);

     public DeallocatorService() {
         // we need to have at least 2 threads, but for CUDA we'd need at least numDevices threads, due to thread->device affinity
@@ -153,4 +153,10 @@ public abstract class BaseOpContext implements OpContext {
         for (int e = 0; e < arrays.length; e++)
             setOutputArray(e, arrays[e]);
     }

+    @Override
+    public void purge() {
+        fastpath_in.clear();
+        fastpath_out.clear();
+    }
 }
@@ -162,4 +162,9 @@ public interface OpContext extends AutoCloseable {
      * @param mode
      */
     void setExecutionMode(ExecutionMode mode);

+    /**
+     * This method removes all in/out arrays from this OpContext
+     */
+    void purge();
 }
@@ -210,4 +210,24 @@ public class CompressedDataBuffer extends BaseDataBuffer {
     public DataBuffer reallocate(long length) {
         throw new UnsupportedOperationException("This method isn't supported by CompressedDataBuffer");
     }

+    @Override
+    protected double getDoubleUnsynced(long index) {
+        return super.getDouble(index);
+    }
+
+    @Override
+    protected float getFloatUnsynced(long index) {
+        return super.getFloat(index);
+    }
+
+    @Override
+    protected long getLongUnsynced(long index) {
+        return super.getLong(index);
+    }
+
+    @Override
+    protected int getIntUnsynced(long index) {
+        return super.getInt(index);
+    }
 }
@@ -1161,6 +1161,7 @@ public interface NativeOps {
     void ctxAllowHelpers(OpaqueContext ptr, boolean reallyAllow);
     void ctxSetExecutionMode(OpaqueContext ptr, int execMode);
     void ctxShapeFunctionOverride(OpaqueContext ptr, boolean reallyOverride);
+    void ctxPurge(OpaqueContext ptr);
     void deleteGraphContext(OpaqueContext ptr);

     OpaqueRandomGenerator createRandomGenerator(long rootSeed, long nodeSeed);
@@ -60,7 +60,7 @@
     Maximum heap size was set to 6g, as a minimum required value for tests run.
     Depending on a build machine, default value is not always enough.
     -->
-    <argLine>-Ddtype=float -Xmx8g</argLine>
+    <argLine>-Ddtype=float -Dfile.encoding=UTF-8 -Xmx8g</argLine>
 </configuration>
 </plugin>
 <plugin>
Some files were not shown because too many files have changed in this diff.