Merge pull request #9185 from eclipse/sa_tvm
Add nd4j-tvm module with initial inference support using TVMmaster
commit
aded258340
|
@ -23,6 +23,10 @@
|
||||||
#ifndef LIBND4J_BLAS_HELPER_H
|
#ifndef LIBND4J_BLAS_HELPER_H
|
||||||
#define LIBND4J_BLAS_HELPER_H
|
#define LIBND4J_BLAS_HELPER_H
|
||||||
|
|
||||||
|
// work around conflict with OpenBLAS
|
||||||
|
struct bfloat16;
|
||||||
|
#define BFLOAT16 BFLOAT16
|
||||||
|
|
||||||
#include <system/pointercast.h>
|
#include <system/pointercast.h>
|
||||||
#include <types/float16.h>
|
#include <types/float16.h>
|
||||||
#include <cblas.h>
|
#include <cblas.h>
|
||||||
|
|
|
@ -23,6 +23,10 @@
|
||||||
#ifndef LIBND4J_GEMM_H
|
#ifndef LIBND4J_GEMM_H
|
||||||
#define LIBND4J_GEMM_H
|
#define LIBND4J_GEMM_H
|
||||||
|
|
||||||
|
// work around conflict with OpenBLAS
|
||||||
|
struct bfloat16;
|
||||||
|
#define BFLOAT16 BFLOAT16
|
||||||
|
|
||||||
#include <cblas.h>
|
#include <cblas.h>
|
||||||
#include <math/templatemath.h>
|
#include <math/templatemath.h>
|
||||||
#include <system/op_boilerplate.h>
|
#include <system/op_boilerplate.h>
|
||||||
|
|
|
@ -139,6 +139,13 @@ import java.util.Scanner;
|
||||||
"ops/declarable/headers/loss.h",
|
"ops/declarable/headers/loss.h",
|
||||||
"ops/declarable/headers/datatypes.h",
|
"ops/declarable/headers/datatypes.h",
|
||||||
"ops/declarable/headers/third_party.h",
|
"ops/declarable/headers/third_party.h",
|
||||||
|
"openblas_config.h",
|
||||||
|
"cblas.h",
|
||||||
|
"lapacke_config.h",
|
||||||
|
"lapacke_mangling.h",
|
||||||
|
"lapack.h",
|
||||||
|
"lapacke.h",
|
||||||
|
"lapacke_utils.h",
|
||||||
"cnpy/cnpy.h"
|
"cnpy/cnpy.h"
|
||||||
},
|
},
|
||||||
compiler = {"cpp11", "nowarnings"},
|
compiler = {"cpp11", "nowarnings"},
|
||||||
|
@ -166,6 +173,7 @@ public class Nd4jCpuPresets implements InfoMapper, BuildEnabled {
|
||||||
public void map(InfoMap infoMap) {
|
public void map(InfoMap infoMap) {
|
||||||
infoMap.put(new Info("thread_local", "ND4J_EXPORT", "INLINEDEF", "CUBLASWINAPI", "FORCEINLINE",
|
infoMap.put(new Info("thread_local", "ND4J_EXPORT", "INLINEDEF", "CUBLASWINAPI", "FORCEINLINE",
|
||||||
"_CUDA_H", "_CUDA_D", "_CUDA_G", "_CUDA_HD", "LIBND4J_ALL_OPS", "NOT_EXCLUDED").cppTypes().annotations())
|
"_CUDA_H", "_CUDA_D", "_CUDA_G", "_CUDA_HD", "LIBND4J_ALL_OPS", "NOT_EXCLUDED").cppTypes().annotations())
|
||||||
|
.put(new Info("openblas_config.h", "cblas.h", "lapacke_config.h", "lapacke_mangling.h", "lapack.h", "lapacke.h", "lapacke_utils.h").skip())
|
||||||
.put(new Info("NativeOps.h", "build_info.h").objectify())
|
.put(new Info("NativeOps.h", "build_info.h").objectify())
|
||||||
.put(new Info("OpaqueTadPack").pointerTypes("OpaqueTadPack"))
|
.put(new Info("OpaqueTadPack").pointerTypes("OpaqueTadPack"))
|
||||||
.put(new Info("OpaqueResultWrapper").pointerTypes("OpaqueResultWrapper"))
|
.put(new Info("OpaqueResultWrapper").pointerTypes("OpaqueResultWrapper"))
|
||||||
|
|
|
@ -38,7 +38,7 @@
|
||||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||||
<maven.compiler.source>1.8</maven.compiler.source>
|
<maven.compiler.source>1.8</maven.compiler.source>
|
||||||
<maven.compiler.target>1.8</maven.compiler.target>
|
<maven.compiler.target>1.8</maven.compiler.target>
|
||||||
<onnxruntime.version>1.4.0</onnxruntime.version>
|
<onnxruntime.version>1.6.0</onnxruntime.version>
|
||||||
<onnxruntime.javacpp.version>${onnxruntime.version}-${javacpp.version}</onnxruntime.javacpp.version>
|
<onnxruntime.javacpp.version>${onnxruntime.version}-${javacpp.version}</onnxruntime.javacpp.version>
|
||||||
</properties>
|
</properties>
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,81 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
|
||||||
|
<!--
|
||||||
|
~ /* ******************************************************************************
|
||||||
|
~ *
|
||||||
|
~ *
|
||||||
|
~ * This program and the accompanying materials are made available under the
|
||||||
|
~ * terms of the Apache License, Version 2.0 which is available at
|
||||||
|
~ * https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
~ *
|
||||||
|
~ * See the NOTICE file distributed with this work for additional
|
||||||
|
~ * information regarding copyright ownership.
|
||||||
|
~ * Unless required by applicable law or agreed to in writing, software
|
||||||
|
~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
~ * License for the specific language governing permissions and limitations
|
||||||
|
~ * under the License.
|
||||||
|
~ *
|
||||||
|
~ * SPDX-License-Identifier: Apache-2.0
|
||||||
|
~ ******************************************************************************/
|
||||||
|
-->
|
||||||
|
|
||||||
|
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||||
|
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||||
|
<parent>
|
||||||
|
<artifactId>nd4j</artifactId>
|
||||||
|
<groupId>org.nd4j</groupId>
|
||||||
|
<version>1.0.0-SNAPSHOT</version>
|
||||||
|
</parent>
|
||||||
|
<modelVersion>4.0.0</modelVersion>
|
||||||
|
|
||||||
|
<artifactId>nd4j-tvm</artifactId>
|
||||||
|
<name>nd4j-tvm</name>
|
||||||
|
|
||||||
|
<properties>
|
||||||
|
<tvm.version>0.7.0</tvm.version>
|
||||||
|
<tvm.javacpp.version>${tvm.version}-${javacpp-presets.version}</tvm.javacpp.version>
|
||||||
|
</properties>
|
||||||
|
|
||||||
|
<dependencies>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.nd4j</groupId>
|
||||||
|
<artifactId>nd4j-api</artifactId>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.bytedeco</groupId>
|
||||||
|
<artifactId>mkl-platform-redist</artifactId>
|
||||||
|
<version>${mkl.version}-${javacpp-presets.version}</version>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.bytedeco</groupId>
|
||||||
|
<artifactId>tvm-platform</artifactId>
|
||||||
|
<version>${tvm.javacpp.version}</version>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.bytedeco</groupId>
|
||||||
|
<artifactId>tvm</artifactId>
|
||||||
|
<version>${tvm.javacpp.version}</version>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>junit</groupId>
|
||||||
|
<artifactId>junit</artifactId>
|
||||||
|
</dependency>
|
||||||
|
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.nd4j</groupId>
|
||||||
|
<artifactId>nd4j-native</artifactId>
|
||||||
|
<version>${project.version}</version>
|
||||||
|
<scope>test</scope>
|
||||||
|
</dependency>
|
||||||
|
</dependencies>
|
||||||
|
<profiles>
|
||||||
|
<profile>
|
||||||
|
<id>testresources</id>
|
||||||
|
</profile>
|
||||||
|
</profiles>
|
||||||
|
</project>
|
|
@ -0,0 +1,164 @@
|
||||||
|
/*
|
||||||
|
* ******************************************************************************
|
||||||
|
* *
|
||||||
|
* *
|
||||||
|
* * This program and the accompanying materials are made available under the
|
||||||
|
* * terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
* *
|
||||||
|
* * See the NOTICE file distributed with this work for additional
|
||||||
|
* * information regarding copyright ownership.
|
||||||
|
* * Unless required by applicable law or agreed to in writing, software
|
||||||
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* * License for the specific language governing permissions and limitations
|
||||||
|
* * under the License.
|
||||||
|
* *
|
||||||
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
|
* *****************************************************************************
|
||||||
|
*/
|
||||||
|
package org.nd4j.tvm.runner;
|
||||||
|
|
||||||
|
import java.io.Closeable;
|
||||||
|
import java.util.LinkedHashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
import lombok.Builder;
|
||||||
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.bytedeco.javacpp.*;
|
||||||
|
import org.bytedeco.tvm.*;
|
||||||
|
import org.nd4j.common.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.tvm.util.TVMUtils;
|
||||||
|
|
||||||
|
import static org.bytedeco.tvm.global.tvm_runtime.*;
|
||||||
|
import static org.nd4j.tvm.util.TVMUtils.*;
|
||||||
|
|
||||||
|
@Slf4j
|
||||||
|
public class TvmRunner implements Closeable {
|
||||||
|
private static DLContext ctx;
|
||||||
|
private Module modFactory;
|
||||||
|
private TVMValue values;
|
||||||
|
private IntPointer codes;
|
||||||
|
private TVMArgsSetter setter;
|
||||||
|
private TVMRetValue rv;
|
||||||
|
private Module gmod;
|
||||||
|
private PackedFunc getNumInputs;
|
||||||
|
private PackedFunc getNumOutputs;
|
||||||
|
private PackedFunc setInput;
|
||||||
|
private PackedFunc getOutput;
|
||||||
|
private PackedFunc run;
|
||||||
|
|
||||||
|
@Builder
|
||||||
|
public TvmRunner(String modelUri) {
|
||||||
|
if (ctx == null) {
|
||||||
|
ctx = new DLContext().device_type(kDLCPU).device_id(0);
|
||||||
|
ctx.retainReference();
|
||||||
|
}
|
||||||
|
|
||||||
|
// create the runtime module
|
||||||
|
try (PointerScope scope = new PointerScope()) {
|
||||||
|
modFactory = Module.LoadFromFile(modelUri);
|
||||||
|
values = new TVMValue(2);
|
||||||
|
codes = new IntPointer(2);
|
||||||
|
setter = new TVMArgsSetter(values, codes);
|
||||||
|
setter.apply(0, ctx);
|
||||||
|
rv = new TVMRetValue();
|
||||||
|
modFactory.GetFunction("default").CallPacked(new TVMArgs(values, codes, 1), rv);
|
||||||
|
gmod = rv.asModule();
|
||||||
|
getNumInputs = gmod.GetFunction("get_num_inputs");
|
||||||
|
getNumOutputs = gmod.GetFunction("get_num_outputs");
|
||||||
|
setInput = gmod.GetFunction("set_input");
|
||||||
|
getOutput = gmod.GetFunction("get_output");
|
||||||
|
run = gmod.GetFunction("run");
|
||||||
|
// retain the session reference to prevent pre emptive release of the session.
|
||||||
|
modFactory.retainReference();
|
||||||
|
values.retainReference();
|
||||||
|
codes.retainReference();
|
||||||
|
setter.retainReference();
|
||||||
|
rv.retainReference();
|
||||||
|
gmod.retainReference();
|
||||||
|
getNumInputs.retainReference();
|
||||||
|
getNumOutputs.retainReference();
|
||||||
|
setInput.retainReference();
|
||||||
|
getOutput.retainReference();
|
||||||
|
run.retainReference();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void close() {
|
||||||
|
if (run != null) {
|
||||||
|
run.releaseReference();
|
||||||
|
}
|
||||||
|
if (getOutput != null) {
|
||||||
|
getOutput.releaseReference();
|
||||||
|
}
|
||||||
|
if (setInput != null) {
|
||||||
|
setInput.releaseReference();
|
||||||
|
}
|
||||||
|
if (getNumOutputs != null) {
|
||||||
|
getNumOutputs.releaseReference();
|
||||||
|
}
|
||||||
|
if (getNumInputs != null) {
|
||||||
|
getNumInputs.releaseReference();
|
||||||
|
}
|
||||||
|
if (gmod != null) {
|
||||||
|
gmod.releaseReference();
|
||||||
|
}
|
||||||
|
if (rv != null) {
|
||||||
|
rv.releaseReference();
|
||||||
|
}
|
||||||
|
if (setter != null) {
|
||||||
|
setter.releaseReference();
|
||||||
|
}
|
||||||
|
if (codes != null) {
|
||||||
|
codes.releaseReference();
|
||||||
|
}
|
||||||
|
if (values != null) {
|
||||||
|
values.releaseReference();
|
||||||
|
}
|
||||||
|
if (modFactory != null) {
|
||||||
|
modFactory.releaseReference();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Execute the {@link #run} function
|
||||||
|
* using the given input {@link Map}
|
||||||
|
* @param input the input map
|
||||||
|
* @return a map of the names of the ndarrays
|
||||||
|
*/
|
||||||
|
public Map<String,INDArray> exec(Map<String,INDArray> input) {
|
||||||
|
try (PointerScope scope = new PointerScope()) {
|
||||||
|
getNumInputs.CallPacked(new TVMArgs(values, codes, 0), rv);
|
||||||
|
long numInputNodes = rv.asLong();
|
||||||
|
getNumOutputs.CallPacked(new TVMArgs(values, codes, 0), rv);
|
||||||
|
long numOutputNodes = rv.asLong();
|
||||||
|
|
||||||
|
// set the right input
|
||||||
|
for (Map.Entry<String,INDArray> e : input.entrySet()) {
|
||||||
|
String name = e.getKey();
|
||||||
|
INDArray arr = e.getValue();
|
||||||
|
DLTensor inputTensor = getTensor(arr, ctx);
|
||||||
|
Preconditions.checkState(inputTensor != null,"Input must be a tensor.");
|
||||||
|
setter.apply(0, new BytePointer(name));
|
||||||
|
setter.apply(1, inputTensor);
|
||||||
|
setInput.CallPacked(new TVMArgs(values, codes, 2), rv);
|
||||||
|
}
|
||||||
|
|
||||||
|
// run the code
|
||||||
|
run.CallPacked(new TVMArgs(values, codes, 0), rv);
|
||||||
|
|
||||||
|
Map<String, INDArray> ret = new LinkedHashMap<>();
|
||||||
|
|
||||||
|
// get the output
|
||||||
|
for (int i = 0; i < numOutputNodes; i++) {
|
||||||
|
setter.apply(0, i);
|
||||||
|
getOutput.CallPacked(new TVMArgs(values, codes, 1), rv);
|
||||||
|
DLTensor outputTensor = rv.asDLTensor();
|
||||||
|
ret.put(Integer.toString(i), getArray(outputTensor));
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,233 @@
|
||||||
|
/*
|
||||||
|
* ******************************************************************************
|
||||||
|
* *
|
||||||
|
* *
|
||||||
|
* * This program and the accompanying materials are made available under the
|
||||||
|
* * terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
* *
|
||||||
|
* * See the NOTICE file distributed with this work for additional
|
||||||
|
* * information regarding copyright ownership.
|
||||||
|
* * Unless required by applicable law or agreed to in writing, software
|
||||||
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* * License for the specific language governing permissions and limitations
|
||||||
|
* * under the License.
|
||||||
|
* *
|
||||||
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
|
* *****************************************************************************
|
||||||
|
*/
|
||||||
|
package org.nd4j.tvm.util;
|
||||||
|
|
||||||
|
import org.bytedeco.javacpp.*;
|
||||||
|
import org.bytedeco.javacpp.indexer.*;
|
||||||
|
import org.bytedeco.tvm.*;
|
||||||
|
import org.nd4j.common.base.Preconditions;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
import static org.bytedeco.tvm.global.tvm_runtime.*;
|
||||||
|
import static org.nd4j.linalg.api.buffer.DataType.*;
|
||||||
|
|
||||||
|
public class TVMUtils {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Return a {@link DataType}
|
||||||
|
* for the tvm data type
|
||||||
|
* @param dataType the equivalent nd4j data type
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public static DataType dataTypeForTvmType(DLDataType dataType) {
|
||||||
|
if(dataType.code() == kDLInt && dataType.bits() == 8) {
|
||||||
|
return INT8;
|
||||||
|
} else if(dataType.code() == kDLInt && dataType.bits() == 16) {
|
||||||
|
return INT16;
|
||||||
|
} else if(dataType.code() == kDLInt && dataType.bits() == 32) {
|
||||||
|
return INT32;
|
||||||
|
} else if(dataType.code() == kDLInt && dataType.bits() == 64) {
|
||||||
|
return INT64;
|
||||||
|
} else if(dataType.code() == kDLUInt && dataType.bits() == 8) {
|
||||||
|
return UINT8;
|
||||||
|
} else if(dataType.code() == kDLUInt && dataType.bits() == 16) {
|
||||||
|
return UINT16;
|
||||||
|
} else if(dataType.code() == kDLUInt && dataType.bits() == 32) {
|
||||||
|
return UINT32;
|
||||||
|
} else if(dataType.code() == kDLUInt && dataType.bits() == 64) {
|
||||||
|
return UINT64;
|
||||||
|
} else if(dataType.code() == kDLFloat && dataType.bits() == 16) {
|
||||||
|
return FLOAT16;
|
||||||
|
} else if(dataType.code() == kDLFloat && dataType.bits() == 32) {
|
||||||
|
return FLOAT;
|
||||||
|
} else if(dataType.code() == kDLFloat && dataType.bits() == 64) {
|
||||||
|
return DOUBLE;
|
||||||
|
} else if(dataType.code() == kDLBfloat && dataType.bits() == 16) {
|
||||||
|
return BFLOAT16;
|
||||||
|
} else
|
||||||
|
throw new IllegalArgumentException("Illegal data type code " + dataType.code() + " with bits " + dataType.bits());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert the tvm type for the given data type
|
||||||
|
* @param dataType
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public static DLDataType tvmTypeForDataType(DataType dataType) {
|
||||||
|
if(dataType == INT8) {
|
||||||
|
return new DLDataType().code((byte)kDLInt).bits((byte)8).lanes((short)1);
|
||||||
|
} else if(dataType == INT16) {
|
||||||
|
return new DLDataType().code((byte)kDLInt).bits((byte)16).lanes((short)1);
|
||||||
|
} else if(dataType == INT32) {
|
||||||
|
return new DLDataType().code((byte)kDLInt).bits((byte)32).lanes((short)1);
|
||||||
|
} else if(dataType == INT64) {
|
||||||
|
return new DLDataType().code((byte)kDLInt).bits((byte)64).lanes((short)1);
|
||||||
|
} else if(dataType == UINT8) {
|
||||||
|
return new DLDataType().code((byte)kDLUInt).bits((byte)8).lanes((short)1);
|
||||||
|
} else if(dataType == UINT16) {
|
||||||
|
return new DLDataType().code((byte)kDLUInt).bits((byte)16).lanes((short)1);
|
||||||
|
} else if(dataType == UINT32) {
|
||||||
|
return new DLDataType().code((byte)kDLUInt).bits((byte)32).lanes((short)1);
|
||||||
|
} else if(dataType == UINT64) {
|
||||||
|
return new DLDataType().code((byte)kDLUInt).bits((byte)64).lanes((short)1);
|
||||||
|
} else if(dataType == FLOAT16) {
|
||||||
|
return new DLDataType().code((byte)kDLFloat).bits((byte)16).lanes((short)1);
|
||||||
|
} else if(dataType == FLOAT) {
|
||||||
|
return new DLDataType().code((byte)kDLFloat).bits((byte)32).lanes((short)1);
|
||||||
|
} else if(dataType == DOUBLE) {
|
||||||
|
return new DLDataType().code((byte)kDLFloat).bits((byte)64).lanes((short)1);
|
||||||
|
} else if(dataType == BFLOAT16) {
|
||||||
|
return new DLDataType().code((byte)kDLBfloat).bits((byte)16).lanes((short)1);
|
||||||
|
} else
|
||||||
|
throw new IllegalArgumentException("Illegal data type " + dataType);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert an tvm {@link DLTensor}
|
||||||
|
* in to an {@link INDArray}
|
||||||
|
* @param value the tensor to convert
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public static INDArray getArray(DLTensor value) {
|
||||||
|
DataType dataType = dataTypeForTvmType(value.dtype());
|
||||||
|
LongPointer shape = value.shape();
|
||||||
|
LongPointer stride = value.strides();
|
||||||
|
long[] shapeConvert;
|
||||||
|
if(shape != null) {
|
||||||
|
shapeConvert = new long[value.ndim()];
|
||||||
|
shape.get(shapeConvert);
|
||||||
|
} else {
|
||||||
|
shapeConvert = new long[]{1};
|
||||||
|
}
|
||||||
|
long[] strideConvert;
|
||||||
|
if(stride != null) {
|
||||||
|
strideConvert = new long[value.ndim()];
|
||||||
|
stride.get(strideConvert);
|
||||||
|
} else {
|
||||||
|
strideConvert = Nd4j.getStrides(shapeConvert);
|
||||||
|
}
|
||||||
|
long size = 1;
|
||||||
|
for (int i = 0; i < shapeConvert.length; i++) {
|
||||||
|
size *= shapeConvert[i];
|
||||||
|
}
|
||||||
|
size *= value.dtype().bits() / 8;
|
||||||
|
|
||||||
|
DataBuffer getBuffer = getDataBuffer(value,size);
|
||||||
|
Preconditions.checkState(dataType.equals(getBuffer.dataType()),"Data type must be equivalent as specified by the tvm metadata.");
|
||||||
|
return Nd4j.create(getBuffer,shapeConvert,strideConvert,0);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get an tvm tensor from an ndarray.
|
||||||
|
* @param ndArray the ndarray to get the value from
|
||||||
|
* @param ctx the {@link DLContext} to use.
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
public static DLTensor getTensor(INDArray ndArray, DLContext ctx) {
|
||||||
|
DLTensor ret = new DLTensor();
|
||||||
|
ret.data(ndArray.data().pointer());
|
||||||
|
ret.ctx(ctx);
|
||||||
|
ret.ndim(ndArray.rank());
|
||||||
|
ret.dtype(tvmTypeForDataType(ndArray.dataType()));
|
||||||
|
ret.shape(new LongPointer(ndArray.shape()));
|
||||||
|
ret.strides(new LongPointer(ndArray.stride()));
|
||||||
|
ret.byte_offset(ndArray.offset());
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the data buffer from the given value
|
||||||
|
* @param tens the values to get
|
||||||
|
* @return the equivalent data buffer
|
||||||
|
*/
|
||||||
|
public static DataBuffer getDataBuffer(DLTensor tens, long size) {
|
||||||
|
DataBuffer buffer = null;
|
||||||
|
DataType type = dataTypeForTvmType(tens.dtype());
|
||||||
|
switch (type) {
|
||||||
|
case BYTE:
|
||||||
|
BytePointer pInt8 = new BytePointer(tens.data()).capacity(size);
|
||||||
|
Indexer int8Indexer = ByteIndexer.create(pInt8);
|
||||||
|
buffer = Nd4j.createBuffer(pInt8, type, size, int8Indexer);
|
||||||
|
break;
|
||||||
|
case SHORT:
|
||||||
|
ShortPointer pInt16 = new ShortPointer(tens.data()).capacity(size);
|
||||||
|
Indexer int16Indexer = ShortIndexer.create(pInt16);
|
||||||
|
buffer = Nd4j.createBuffer(pInt16, type, size, int16Indexer);
|
||||||
|
break;
|
||||||
|
case INT:
|
||||||
|
IntPointer pInt32 = new IntPointer(tens.data()).capacity(size);
|
||||||
|
Indexer int32Indexer = IntIndexer.create(pInt32);
|
||||||
|
buffer = Nd4j.createBuffer(pInt32, type, size, int32Indexer);
|
||||||
|
break;
|
||||||
|
case LONG:
|
||||||
|
LongPointer pInt64 = new LongPointer(tens.data()).capacity(size);
|
||||||
|
Indexer int64Indexer = LongIndexer.create(pInt64);
|
||||||
|
buffer = Nd4j.createBuffer(pInt64, type, size, int64Indexer);
|
||||||
|
break;
|
||||||
|
case UBYTE:
|
||||||
|
BytePointer pUint8 = new BytePointer(tens.data()).capacity(size);
|
||||||
|
Indexer uint8Indexer = UByteIndexer.create(pUint8);
|
||||||
|
buffer = Nd4j.createBuffer(pUint8, type, size, uint8Indexer);
|
||||||
|
break;
|
||||||
|
case UINT16:
|
||||||
|
ShortPointer pUint16 = new ShortPointer(tens.data()).capacity(size);
|
||||||
|
Indexer uint16Indexer = UShortIndexer.create(pUint16);
|
||||||
|
buffer = Nd4j.createBuffer(pUint16, type, size, uint16Indexer);
|
||||||
|
break;
|
||||||
|
case UINT32:
|
||||||
|
IntPointer pUint32 = new IntPointer(tens.data()).capacity(size);
|
||||||
|
Indexer uint32Indexer = UIntIndexer.create(pUint32);
|
||||||
|
buffer = Nd4j.createBuffer(pUint32, type, size, uint32Indexer);
|
||||||
|
break;
|
||||||
|
case UINT64:
|
||||||
|
LongPointer pUint64 = new LongPointer(tens.data()).capacity(size);
|
||||||
|
Indexer uint64Indexer = LongIndexer.create(pUint64);
|
||||||
|
buffer = Nd4j.createBuffer(pUint64, type, size, uint64Indexer);
|
||||||
|
break;
|
||||||
|
case HALF:
|
||||||
|
ShortPointer pFloat16 = new ShortPointer(tens.data()).capacity(size);
|
||||||
|
Indexer float16Indexer = HalfIndexer.create(pFloat16);
|
||||||
|
buffer = Nd4j.createBuffer(pFloat16, type, size, float16Indexer);
|
||||||
|
break;
|
||||||
|
case FLOAT:
|
||||||
|
FloatPointer pFloat = new FloatPointer(tens.data()).capacity(size);
|
||||||
|
FloatIndexer floatIndexer = FloatIndexer.create(pFloat);
|
||||||
|
buffer = Nd4j.createBuffer(pFloat, type, size, floatIndexer);
|
||||||
|
break;
|
||||||
|
case DOUBLE:
|
||||||
|
DoublePointer pDouble = new DoublePointer(tens.data()).capacity(size);
|
||||||
|
Indexer doubleIndexer = DoubleIndexer.create(pDouble);
|
||||||
|
buffer = Nd4j.createBuffer(pDouble, type, size, doubleIndexer);
|
||||||
|
break;
|
||||||
|
case BFLOAT16:
|
||||||
|
ShortPointer pBfloat16 = new ShortPointer(tens.data()).capacity(size);
|
||||||
|
Indexer bfloat16Indexer = Bfloat16Indexer.create(pBfloat16);
|
||||||
|
buffer = Nd4j.createBuffer(pBfloat16, type, size, bfloat16Indexer);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw new RuntimeException("Unsupported data type encountered");
|
||||||
|
}
|
||||||
|
return buffer;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -0,0 +1,101 @@
|
||||||
|
/*
|
||||||
|
* ******************************************************************************
|
||||||
|
* *
|
||||||
|
* *
|
||||||
|
* * This program and the accompanying materials are made available under the
|
||||||
|
* * terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
* *
|
||||||
|
* * See the NOTICE file distributed with this work for additional
|
||||||
|
* * information regarding copyright ownership.
|
||||||
|
* * Unless required by applicable law or agreed to in writing, software
|
||||||
|
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* * License for the specific language governing permissions and limitations
|
||||||
|
* * under the License.
|
||||||
|
* *
|
||||||
|
* * SPDX-License-Identifier: Apache-2.0
|
||||||
|
* *****************************************************************************
|
||||||
|
*/
|
||||||
|
package org.nd4j.tvm.runner;
|
||||||
|
|
||||||
|
import org.bytedeco.javacpp.*;
|
||||||
|
import org.bytedeco.cpython.*;
|
||||||
|
import org.bytedeco.numpy.*;
|
||||||
|
import org.bytedeco.tvm.*;
|
||||||
|
import org.junit.Rule;
|
||||||
|
import org.junit.Test;
|
||||||
|
import org.junit.rules.TemporaryFolder;
|
||||||
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
import java.util.Arrays;
|
||||||
|
import java.util.LinkedHashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import static org.bytedeco.cpython.global.python.*;
|
||||||
|
import static org.bytedeco.numpy.global.numpy.*;
|
||||||
|
import static org.bytedeco.tvm.global.tvm_runtime.*;
|
||||||
|
import static org.junit.Assert.assertEquals;
|
||||||
|
|
||||||
|
public class TvmRunnerTests {
|
||||||
|
|
||||||
|
@Rule
|
||||||
|
public TemporaryFolder testDir = new TemporaryFolder();
|
||||||
|
|
||||||
|
static void PrepareTestLibs(String libPath) throws Exception {
|
||||||
|
Py_AddPath(org.bytedeco.tvm.presets.tvm.cachePackages());
|
||||||
|
Py_Initialize();
|
||||||
|
if (_import_array() < 0) {
|
||||||
|
System.err.println("numpy.core.multiarray failed to import");
|
||||||
|
PyErr_Print();
|
||||||
|
System.exit(-1);
|
||||||
|
}
|
||||||
|
PyObject globals = PyModule_GetDict(PyImport_AddModule("__main__"));
|
||||||
|
|
||||||
|
PyRun_StringFlags("\"\"\"Script to prepare test_relay_add.so\"\"\"\n"
|
||||||
|
+ "import tvm\n"
|
||||||
|
+ "import numpy as np\n"
|
||||||
|
+ "from tvm import relay\n"
|
||||||
|
+ "import os\n"
|
||||||
|
|
||||||
|
+ "x = relay.var(\"x\", shape=(1, 1), dtype=\"float32\")\n"
|
||||||
|
+ "y = relay.var(\"y\", shape=(1, 1), dtype=\"float32\")\n"
|
||||||
|
+ "params = {\"y\": np.ones((1, 1), dtype=\"float32\")}\n"
|
||||||
|
+ "mod = tvm.IRModule.from_expr(relay.Function([x, y], x + y))\n"
|
||||||
|
+ "# build a module\n"
|
||||||
|
+ "compiled_lib = relay.build(mod, tvm.target.create(\"llvm\"), params=params)\n"
|
||||||
|
+ "# export it as a shared library\n"
|
||||||
|
+ "dylib_path = os.path.join(\"" + libPath + "\", \"test_relay_add.so\")\n"
|
||||||
|
+ "compiled_lib.export_library(dylib_path)\n",
|
||||||
|
|
||||||
|
Py_file_input, globals, globals, null);
|
||||||
|
|
||||||
|
if (PyErr_Occurred() != null) {
|
||||||
|
System.err.println("Python error occurred");
|
||||||
|
PyErr_Print();
|
||||||
|
System.exit(-1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testAdd() throws Exception {
|
||||||
|
/* try to use MKL when available */
|
||||||
|
System.setProperty("org.bytedeco.openblas.load", "mkl");
|
||||||
|
|
||||||
|
File libPath = testDir.newFolder("lib");
|
||||||
|
PrepareTestLibs(libPath.getAbsolutePath().replace(File.separatorChar, '/'));
|
||||||
|
File f = new File(libPath, "test_relay_add.so");
|
||||||
|
INDArray x = Nd4j.scalar(1.0f).reshape(1,1);
|
||||||
|
TvmRunner tvmRunner = TvmRunner.builder()
|
||||||
|
.modelUri(f.getAbsolutePath())
|
||||||
|
.build();
|
||||||
|
Map<String,INDArray> inputs = new LinkedHashMap<>();
|
||||||
|
inputs.put("x",x);
|
||||||
|
Map<String, INDArray> exec = tvmRunner.exec(inputs);
|
||||||
|
INDArray z = exec.get("0");
|
||||||
|
assertEquals(2.0,z.sumNumber().doubleValue(),1e-1);
|
||||||
|
}
|
||||||
|
}
|
|
@ -46,6 +46,7 @@
|
||||||
<module>nd4j-parameter-server-parent</module>
|
<module>nd4j-parameter-server-parent</module>
|
||||||
<module>nd4j-tensorflow</module>
|
<module>nd4j-tensorflow</module>
|
||||||
<module>nd4j-onnxruntime</module>
|
<module>nd4j-onnxruntime</module>
|
||||||
|
<module>nd4j-tvm</module>
|
||||||
<module>nd4j-common-tests</module>
|
<module>nd4j-common-tests</module>
|
||||||
<module>samediff-import</module>
|
<module>samediff-import</module>
|
||||||
</modules>
|
</modules>
|
||||||
|
|
25
pom.xml
25
pom.xml
|
@ -173,32 +173,35 @@
|
||||||
<javacpp.platform.sysroot/> <!-- -Djavacpp.platform.sysroot=$(xcrun -sdk iphoneos -show-sdk-path) -->
|
<javacpp.platform.sysroot/> <!-- -Djavacpp.platform.sysroot=$(xcrun -sdk iphoneos -show-sdk-path) -->
|
||||||
<javacpp.platform.extension/> <!-- -Djavacpp.platform.extension=-avx512 -->
|
<javacpp.platform.extension/> <!-- -Djavacpp.platform.extension=-avx512 -->
|
||||||
<javacpp.platform.properties>${javacpp.platform}</javacpp.platform.properties>
|
<javacpp.platform.properties>${javacpp.platform}</javacpp.platform.properties>
|
||||||
|
|
||||||
|
|
||||||
|
<javacpp.version>1.5.5</javacpp.version>
|
||||||
|
<javacpp-presets.version>1.5.5</javacpp-presets.version>
|
||||||
|
<javacv.version>1.5.5</javacv.version>
|
||||||
<javacpp.platform.additionalIncludePaths />
|
<javacpp.platform.additionalIncludePaths />
|
||||||
<javacpp.platform.cppincludepath />
|
<javacpp.platform.cppincludepath />
|
||||||
<javacpp.platform.library.path />
|
<javacpp.platform.library.path />
|
||||||
<!-- Used in nd4j-backend-impls for directory-maven-plugin to assist with finding native libs for tests -->
|
<!-- Used in nd4j-backend-impls for directory-maven-plugin to assist with finding native libs for tests -->
|
||||||
<nd4j.native.basedir />
|
<nd4j.native.basedir />
|
||||||
<nd4j.cuda.basedir />
|
<nd4j.cuda.basedir />
|
||||||
<javacpp.version>1.5.4</javacpp.version>
|
|
||||||
<javacpp-presets.version>1.5.4</javacpp-presets.version>
|
|
||||||
<javacv.version>1.5.4</javacv.version>
|
|
||||||
|
|
||||||
<python.version>3.7.9</python.version>
|
|
||||||
|
<python.version>3.9.1</python.version>
|
||||||
<cpython-platform.version>${python.version}-${javacpp-presets.version}</cpython-platform.version>
|
<cpython-platform.version>${python.version}-${javacpp-presets.version}</cpython-platform.version>
|
||||||
<numpy.version>1.19.1</numpy.version>
|
<numpy.version>1.20.1</numpy.version>
|
||||||
<numpy.javacpp.version>${numpy.version}-${javacpp-presets.version}</numpy.javacpp.version>
|
<numpy.javacpp.version>${numpy.version}-${javacpp-presets.version}</numpy.javacpp.version>
|
||||||
|
|
||||||
<openblas.version>0.3.10</openblas.version>
|
<openblas.version>0.3.13</openblas.version>
|
||||||
|
<mkl.version>2021.1</mkl.version>
|
||||||
<mkl.version>2020.3</mkl.version>
|
<opencv.version>4.5.1</opencv.version>
|
||||||
<opencv.version>4.4.0</opencv.version>
|
|
||||||
<ffmpeg.version>4.3.1</ffmpeg.version>
|
<ffmpeg.version>4.3.1</ffmpeg.version>
|
||||||
<leptonica.version>1.80.0</leptonica.version>
|
<leptonica.version>1.80.0</leptonica.version>
|
||||||
<hdf5.version>1.12.0</hdf5.version>
|
<hdf5.version>1.12.0</hdf5.version>
|
||||||
<ale.version>0.6.1</ale.version>
|
<ale.version>0.6.1</ale.version>
|
||||||
<gym.version>0.17.2</gym.version>
|
<gym.version>0.18.0</gym.version>
|
||||||
<tensorflow.version>1.15.3</tensorflow.version>
|
<tensorflow.version>1.15.5</tensorflow.version>
|
||||||
<tensorflow.javacpp.version>${tensorflow.version}-${javacpp-presets.version}</tensorflow.javacpp.version>
|
<tensorflow.javacpp.version>${tensorflow.version}-${javacpp-presets.version}</tensorflow.javacpp.version>
|
||||||
|
|
||||||
<archunit.version>0.14.1</archunit.version>
|
<archunit.version>0.14.1</archunit.version>
|
||||||
<commons-compress.version>1.18</commons-compress.version>
|
<commons-compress.version>1.18</commons-compress.version>
|
||||||
<commonsmath.version>3.5</commonsmath.version>
|
<commonsmath.version>3.5</commonsmath.version>
|
||||||
|
|
Loading…
Reference in New Issue