Merge pull request #9185 from eclipse/sa_tvm

Add nd4j-tvm module with initial inference support using TVM
master
Adam Gibson 2021-03-09 07:54:15 +09:00 committed by GitHub
commit aded258340
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 612 additions and 13 deletions

View File

@ -23,6 +23,10 @@
#ifndef LIBND4J_BLAS_HELPER_H #ifndef LIBND4J_BLAS_HELPER_H
#define LIBND4J_BLAS_HELPER_H #define LIBND4J_BLAS_HELPER_H
// work around conflict with OpenBLAS
struct bfloat16;
#define BFLOAT16 BFLOAT16
#include <system/pointercast.h> #include <system/pointercast.h>
#include <types/float16.h> #include <types/float16.h>
#include <cblas.h> #include <cblas.h>

View File

@ -23,6 +23,10 @@
#ifndef LIBND4J_GEMM_H #ifndef LIBND4J_GEMM_H
#define LIBND4J_GEMM_H #define LIBND4J_GEMM_H
// work around conflict with OpenBLAS
struct bfloat16;
#define BFLOAT16 BFLOAT16
#include <cblas.h> #include <cblas.h>
#include <math/templatemath.h> #include <math/templatemath.h>
#include <system/op_boilerplate.h> #include <system/op_boilerplate.h>

View File

@ -139,6 +139,13 @@ import java.util.Scanner;
"ops/declarable/headers/loss.h", "ops/declarable/headers/loss.h",
"ops/declarable/headers/datatypes.h", "ops/declarable/headers/datatypes.h",
"ops/declarable/headers/third_party.h", "ops/declarable/headers/third_party.h",
"openblas_config.h",
"cblas.h",
"lapacke_config.h",
"lapacke_mangling.h",
"lapack.h",
"lapacke.h",
"lapacke_utils.h",
"cnpy/cnpy.h" "cnpy/cnpy.h"
}, },
compiler = {"cpp11", "nowarnings"}, compiler = {"cpp11", "nowarnings"},
@ -166,6 +173,7 @@ public class Nd4jCpuPresets implements InfoMapper, BuildEnabled {
public void map(InfoMap infoMap) { public void map(InfoMap infoMap) {
infoMap.put(new Info("thread_local", "ND4J_EXPORT", "INLINEDEF", "CUBLASWINAPI", "FORCEINLINE", infoMap.put(new Info("thread_local", "ND4J_EXPORT", "INLINEDEF", "CUBLASWINAPI", "FORCEINLINE",
"_CUDA_H", "_CUDA_D", "_CUDA_G", "_CUDA_HD", "LIBND4J_ALL_OPS", "NOT_EXCLUDED").cppTypes().annotations()) "_CUDA_H", "_CUDA_D", "_CUDA_G", "_CUDA_HD", "LIBND4J_ALL_OPS", "NOT_EXCLUDED").cppTypes().annotations())
.put(new Info("openblas_config.h", "cblas.h", "lapacke_config.h", "lapacke_mangling.h", "lapack.h", "lapacke.h", "lapacke_utils.h").skip())
.put(new Info("NativeOps.h", "build_info.h").objectify()) .put(new Info("NativeOps.h", "build_info.h").objectify())
.put(new Info("OpaqueTadPack").pointerTypes("OpaqueTadPack")) .put(new Info("OpaqueTadPack").pointerTypes("OpaqueTadPack"))
.put(new Info("OpaqueResultWrapper").pointerTypes("OpaqueResultWrapper")) .put(new Info("OpaqueResultWrapper").pointerTypes("OpaqueResultWrapper"))

View File

@ -38,7 +38,7 @@
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<maven.compiler.source>1.8</maven.compiler.source> <maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target> <maven.compiler.target>1.8</maven.compiler.target>
<onnxruntime.version>1.4.0</onnxruntime.version> <onnxruntime.version>1.6.0</onnxruntime.version>
<onnxruntime.javacpp.version>${onnxruntime.version}-${javacpp.version}</onnxruntime.javacpp.version> <onnxruntime.javacpp.version>${onnxruntime.version}-${javacpp.version}</onnxruntime.javacpp.version>
</properties> </properties>

81
nd4j/nd4j-tvm/pom.xml Normal file
View File

@ -0,0 +1,81 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
~ /* ******************************************************************************
~ *
~ *
~ * This program and the accompanying materials are made available under the
~ * terms of the Apache License, Version 2.0 which is available at
~ * https://www.apache.org/licenses/LICENSE-2.0.
~ *
~ * See the NOTICE file distributed with this work for additional
~ * information regarding copyright ownership.
~ * Unless required by applicable law or agreed to in writing, software
~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ * License for the specific language governing permissions and limitations
~ * under the License.
~ *
~ * SPDX-License-Identifier: Apache-2.0
~ ******************************************************************************/
-->
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<artifactId>nd4j</artifactId>
<groupId>org.nd4j</groupId>
<version>1.0.0-SNAPSHOT</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>nd4j-tvm</artifactId>
<name>nd4j-tvm</name>
<properties>
<tvm.version>0.7.0</tvm.version>
<tvm.javacpp.version>${tvm.version}-${javacpp-presets.version}</tvm.javacpp.version>
</properties>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-api</artifactId>
</dependency>
<dependency>
<groupId>org.bytedeco</groupId>
<artifactId>mkl-platform-redist</artifactId>
<version>${mkl.version}-${javacpp-presets.version}</version>
</dependency>
<dependency>
<groupId>org.bytedeco</groupId>
<artifactId>tvm-platform</artifactId>
<version>${tvm.javacpp.version}</version>
</dependency>
<dependency>
<groupId>org.bytedeco</groupId>
<artifactId>tvm</artifactId>
<version>${tvm.javacpp.version}</version>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<profiles>
<profile>
<id>testresources</id>
</profile>
</profiles>
</project>

View File

@ -0,0 +1,164 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.nd4j.tvm.runner;
import java.io.Closeable;
import java.util.LinkedHashMap;
import java.util.Map;
import lombok.Builder;
import lombok.extern.slf4j.Slf4j;
import org.bytedeco.javacpp.*;
import org.bytedeco.tvm.*;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.tvm.util.TVMUtils;
import static org.bytedeco.tvm.global.tvm_runtime.*;
import static org.nd4j.tvm.util.TVMUtils.*;
@Slf4j
public class TvmRunner implements Closeable {
private static DLContext ctx;
private Module modFactory;
private TVMValue values;
private IntPointer codes;
private TVMArgsSetter setter;
private TVMRetValue rv;
private Module gmod;
private PackedFunc getNumInputs;
private PackedFunc getNumOutputs;
private PackedFunc setInput;
private PackedFunc getOutput;
private PackedFunc run;
@Builder
public TvmRunner(String modelUri) {
if (ctx == null) {
ctx = new DLContext().device_type(kDLCPU).device_id(0);
ctx.retainReference();
}
// create the runtime module
try (PointerScope scope = new PointerScope()) {
modFactory = Module.LoadFromFile(modelUri);
values = new TVMValue(2);
codes = new IntPointer(2);
setter = new TVMArgsSetter(values, codes);
setter.apply(0, ctx);
rv = new TVMRetValue();
modFactory.GetFunction("default").CallPacked(new TVMArgs(values, codes, 1), rv);
gmod = rv.asModule();
getNumInputs = gmod.GetFunction("get_num_inputs");
getNumOutputs = gmod.GetFunction("get_num_outputs");
setInput = gmod.GetFunction("set_input");
getOutput = gmod.GetFunction("get_output");
run = gmod.GetFunction("run");
// retain the session reference to prevent pre emptive release of the session.
modFactory.retainReference();
values.retainReference();
codes.retainReference();
setter.retainReference();
rv.retainReference();
gmod.retainReference();
getNumInputs.retainReference();
getNumOutputs.retainReference();
setInput.retainReference();
getOutput.retainReference();
run.retainReference();
}
}
@Override
public void close() {
if (run != null) {
run.releaseReference();
}
if (getOutput != null) {
getOutput.releaseReference();
}
if (setInput != null) {
setInput.releaseReference();
}
if (getNumOutputs != null) {
getNumOutputs.releaseReference();
}
if (getNumInputs != null) {
getNumInputs.releaseReference();
}
if (gmod != null) {
gmod.releaseReference();
}
if (rv != null) {
rv.releaseReference();
}
if (setter != null) {
setter.releaseReference();
}
if (codes != null) {
codes.releaseReference();
}
if (values != null) {
values.releaseReference();
}
if (modFactory != null) {
modFactory.releaseReference();
}
}
/**
* Execute the {@link #run} function
* using the given input {@link Map}
* @param input the input map
* @return a map of the names of the ndarrays
*/
public Map<String,INDArray> exec(Map<String,INDArray> input) {
try (PointerScope scope = new PointerScope()) {
getNumInputs.CallPacked(new TVMArgs(values, codes, 0), rv);
long numInputNodes = rv.asLong();
getNumOutputs.CallPacked(new TVMArgs(values, codes, 0), rv);
long numOutputNodes = rv.asLong();
// set the right input
for (Map.Entry<String,INDArray> e : input.entrySet()) {
String name = e.getKey();
INDArray arr = e.getValue();
DLTensor inputTensor = getTensor(arr, ctx);
Preconditions.checkState(inputTensor != null,"Input must be a tensor.");
setter.apply(0, new BytePointer(name));
setter.apply(1, inputTensor);
setInput.CallPacked(new TVMArgs(values, codes, 2), rv);
}
// run the code
run.CallPacked(new TVMArgs(values, codes, 0), rv);
Map<String, INDArray> ret = new LinkedHashMap<>();
// get the output
for (int i = 0; i < numOutputNodes; i++) {
setter.apply(0, i);
getOutput.CallPacked(new TVMArgs(values, codes, 1), rv);
DLTensor outputTensor = rv.asDLTensor();
ret.put(Integer.toString(i), getArray(outputTensor));
}
return ret;
}
}
}

View File

@ -0,0 +1,233 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.nd4j.tvm.util;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.indexer.*;
import org.bytedeco.tvm.*;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import static org.bytedeco.tvm.global.tvm_runtime.*;
import static org.nd4j.linalg.api.buffer.DataType.*;
public class TVMUtils {
/**
* Return a {@link DataType}
* for the tvm data type
* @param dataType the equivalent nd4j data type
* @return
*/
public static DataType dataTypeForTvmType(DLDataType dataType) {
if(dataType.code() == kDLInt && dataType.bits() == 8) {
return INT8;
} else if(dataType.code() == kDLInt && dataType.bits() == 16) {
return INT16;
} else if(dataType.code() == kDLInt && dataType.bits() == 32) {
return INT32;
} else if(dataType.code() == kDLInt && dataType.bits() == 64) {
return INT64;
} else if(dataType.code() == kDLUInt && dataType.bits() == 8) {
return UINT8;
} else if(dataType.code() == kDLUInt && dataType.bits() == 16) {
return UINT16;
} else if(dataType.code() == kDLUInt && dataType.bits() == 32) {
return UINT32;
} else if(dataType.code() == kDLUInt && dataType.bits() == 64) {
return UINT64;
} else if(dataType.code() == kDLFloat && dataType.bits() == 16) {
return FLOAT16;
} else if(dataType.code() == kDLFloat && dataType.bits() == 32) {
return FLOAT;
} else if(dataType.code() == kDLFloat && dataType.bits() == 64) {
return DOUBLE;
} else if(dataType.code() == kDLBfloat && dataType.bits() == 16) {
return BFLOAT16;
} else
throw new IllegalArgumentException("Illegal data type code " + dataType.code() + " with bits " + dataType.bits());
}
/**
* Convert the tvm type for the given data type
* @param dataType
* @return
*/
public static DLDataType tvmTypeForDataType(DataType dataType) {
if(dataType == INT8) {
return new DLDataType().code((byte)kDLInt).bits((byte)8).lanes((short)1);
} else if(dataType == INT16) {
return new DLDataType().code((byte)kDLInt).bits((byte)16).lanes((short)1);
} else if(dataType == INT32) {
return new DLDataType().code((byte)kDLInt).bits((byte)32).lanes((short)1);
} else if(dataType == INT64) {
return new DLDataType().code((byte)kDLInt).bits((byte)64).lanes((short)1);
} else if(dataType == UINT8) {
return new DLDataType().code((byte)kDLUInt).bits((byte)8).lanes((short)1);
} else if(dataType == UINT16) {
return new DLDataType().code((byte)kDLUInt).bits((byte)16).lanes((short)1);
} else if(dataType == UINT32) {
return new DLDataType().code((byte)kDLUInt).bits((byte)32).lanes((short)1);
} else if(dataType == UINT64) {
return new DLDataType().code((byte)kDLUInt).bits((byte)64).lanes((short)1);
} else if(dataType == FLOAT16) {
return new DLDataType().code((byte)kDLFloat).bits((byte)16).lanes((short)1);
} else if(dataType == FLOAT) {
return new DLDataType().code((byte)kDLFloat).bits((byte)32).lanes((short)1);
} else if(dataType == DOUBLE) {
return new DLDataType().code((byte)kDLFloat).bits((byte)64).lanes((short)1);
} else if(dataType == BFLOAT16) {
return new DLDataType().code((byte)kDLBfloat).bits((byte)16).lanes((short)1);
} else
throw new IllegalArgumentException("Illegal data type " + dataType);
}
/**
* Convert an tvm {@link DLTensor}
* in to an {@link INDArray}
* @param value the tensor to convert
* @return
*/
public static INDArray getArray(DLTensor value) {
DataType dataType = dataTypeForTvmType(value.dtype());
LongPointer shape = value.shape();
LongPointer stride = value.strides();
long[] shapeConvert;
if(shape != null) {
shapeConvert = new long[value.ndim()];
shape.get(shapeConvert);
} else {
shapeConvert = new long[]{1};
}
long[] strideConvert;
if(stride != null) {
strideConvert = new long[value.ndim()];
stride.get(strideConvert);
} else {
strideConvert = Nd4j.getStrides(shapeConvert);
}
long size = 1;
for (int i = 0; i < shapeConvert.length; i++) {
size *= shapeConvert[i];
}
size *= value.dtype().bits() / 8;
DataBuffer getBuffer = getDataBuffer(value,size);
Preconditions.checkState(dataType.equals(getBuffer.dataType()),"Data type must be equivalent as specified by the tvm metadata.");
return Nd4j.create(getBuffer,shapeConvert,strideConvert,0);
}
/**
* Get an tvm tensor from an ndarray.
* @param ndArray the ndarray to get the value from
* @param ctx the {@link DLContext} to use.
* @return
*/
public static DLTensor getTensor(INDArray ndArray, DLContext ctx) {
DLTensor ret = new DLTensor();
ret.data(ndArray.data().pointer());
ret.ctx(ctx);
ret.ndim(ndArray.rank());
ret.dtype(tvmTypeForDataType(ndArray.dataType()));
ret.shape(new LongPointer(ndArray.shape()));
ret.strides(new LongPointer(ndArray.stride()));
ret.byte_offset(ndArray.offset());
return ret;
}
/**
* Get the data buffer from the given value
* @param tens the values to get
* @return the equivalent data buffer
*/
public static DataBuffer getDataBuffer(DLTensor tens, long size) {
DataBuffer buffer = null;
DataType type = dataTypeForTvmType(tens.dtype());
switch (type) {
case BYTE:
BytePointer pInt8 = new BytePointer(tens.data()).capacity(size);
Indexer int8Indexer = ByteIndexer.create(pInt8);
buffer = Nd4j.createBuffer(pInt8, type, size, int8Indexer);
break;
case SHORT:
ShortPointer pInt16 = new ShortPointer(tens.data()).capacity(size);
Indexer int16Indexer = ShortIndexer.create(pInt16);
buffer = Nd4j.createBuffer(pInt16, type, size, int16Indexer);
break;
case INT:
IntPointer pInt32 = new IntPointer(tens.data()).capacity(size);
Indexer int32Indexer = IntIndexer.create(pInt32);
buffer = Nd4j.createBuffer(pInt32, type, size, int32Indexer);
break;
case LONG:
LongPointer pInt64 = new LongPointer(tens.data()).capacity(size);
Indexer int64Indexer = LongIndexer.create(pInt64);
buffer = Nd4j.createBuffer(pInt64, type, size, int64Indexer);
break;
case UBYTE:
BytePointer pUint8 = new BytePointer(tens.data()).capacity(size);
Indexer uint8Indexer = UByteIndexer.create(pUint8);
buffer = Nd4j.createBuffer(pUint8, type, size, uint8Indexer);
break;
case UINT16:
ShortPointer pUint16 = new ShortPointer(tens.data()).capacity(size);
Indexer uint16Indexer = UShortIndexer.create(pUint16);
buffer = Nd4j.createBuffer(pUint16, type, size, uint16Indexer);
break;
case UINT32:
IntPointer pUint32 = new IntPointer(tens.data()).capacity(size);
Indexer uint32Indexer = UIntIndexer.create(pUint32);
buffer = Nd4j.createBuffer(pUint32, type, size, uint32Indexer);
break;
case UINT64:
LongPointer pUint64 = new LongPointer(tens.data()).capacity(size);
Indexer uint64Indexer = LongIndexer.create(pUint64);
buffer = Nd4j.createBuffer(pUint64, type, size, uint64Indexer);
break;
case HALF:
ShortPointer pFloat16 = new ShortPointer(tens.data()).capacity(size);
Indexer float16Indexer = HalfIndexer.create(pFloat16);
buffer = Nd4j.createBuffer(pFloat16, type, size, float16Indexer);
break;
case FLOAT:
FloatPointer pFloat = new FloatPointer(tens.data()).capacity(size);
FloatIndexer floatIndexer = FloatIndexer.create(pFloat);
buffer = Nd4j.createBuffer(pFloat, type, size, floatIndexer);
break;
case DOUBLE:
DoublePointer pDouble = new DoublePointer(tens.data()).capacity(size);
Indexer doubleIndexer = DoubleIndexer.create(pDouble);
buffer = Nd4j.createBuffer(pDouble, type, size, doubleIndexer);
break;
case BFLOAT16:
ShortPointer pBfloat16 = new ShortPointer(tens.data()).capacity(size);
Indexer bfloat16Indexer = Bfloat16Indexer.create(pBfloat16);
buffer = Nd4j.createBuffer(pBfloat16, type, size, bfloat16Indexer);
break;
default:
throw new RuntimeException("Unsupported data type encountered");
}
return buffer;
}
}

View File

@ -0,0 +1,101 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.nd4j.tvm.runner;
import org.bytedeco.javacpp.*;
import org.bytedeco.cpython.*;
import org.bytedeco.numpy.*;
import org.bytedeco.tvm.*;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.io.File;
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.Map;
import static org.bytedeco.cpython.global.python.*;
import static org.bytedeco.numpy.global.numpy.*;
import static org.bytedeco.tvm.global.tvm_runtime.*;
import static org.junit.Assert.assertEquals;
public class TvmRunnerTests {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
static void PrepareTestLibs(String libPath) throws Exception {
Py_AddPath(org.bytedeco.tvm.presets.tvm.cachePackages());
Py_Initialize();
if (_import_array() < 0) {
System.err.println("numpy.core.multiarray failed to import");
PyErr_Print();
System.exit(-1);
}
PyObject globals = PyModule_GetDict(PyImport_AddModule("__main__"));
PyRun_StringFlags("\"\"\"Script to prepare test_relay_add.so\"\"\"\n"
+ "import tvm\n"
+ "import numpy as np\n"
+ "from tvm import relay\n"
+ "import os\n"
+ "x = relay.var(\"x\", shape=(1, 1), dtype=\"float32\")\n"
+ "y = relay.var(\"y\", shape=(1, 1), dtype=\"float32\")\n"
+ "params = {\"y\": np.ones((1, 1), dtype=\"float32\")}\n"
+ "mod = tvm.IRModule.from_expr(relay.Function([x, y], x + y))\n"
+ "# build a module\n"
+ "compiled_lib = relay.build(mod, tvm.target.create(\"llvm\"), params=params)\n"
+ "# export it as a shared library\n"
+ "dylib_path = os.path.join(\"" + libPath + "\", \"test_relay_add.so\")\n"
+ "compiled_lib.export_library(dylib_path)\n",
Py_file_input, globals, globals, null);
if (PyErr_Occurred() != null) {
System.err.println("Python error occurred");
PyErr_Print();
System.exit(-1);
}
}
@Test
public void testAdd() throws Exception {
/* try to use MKL when available */
System.setProperty("org.bytedeco.openblas.load", "mkl");
File libPath = testDir.newFolder("lib");
PrepareTestLibs(libPath.getAbsolutePath().replace(File.separatorChar, '/'));
File f = new File(libPath, "test_relay_add.so");
INDArray x = Nd4j.scalar(1.0f).reshape(1,1);
TvmRunner tvmRunner = TvmRunner.builder()
.modelUri(f.getAbsolutePath())
.build();
Map<String,INDArray> inputs = new LinkedHashMap<>();
inputs.put("x",x);
Map<String, INDArray> exec = tvmRunner.exec(inputs);
INDArray z = exec.get("0");
assertEquals(2.0,z.sumNumber().doubleValue(),1e-1);
}
}

View File

@ -46,6 +46,7 @@
<module>nd4j-parameter-server-parent</module> <module>nd4j-parameter-server-parent</module>
<module>nd4j-tensorflow</module> <module>nd4j-tensorflow</module>
<module>nd4j-onnxruntime</module> <module>nd4j-onnxruntime</module>
<module>nd4j-tvm</module>
<module>nd4j-common-tests</module> <module>nd4j-common-tests</module>
<module>samediff-import</module> <module>samediff-import</module>
</modules> </modules>

25
pom.xml
View File

@ -173,32 +173,35 @@
<javacpp.platform.sysroot/> <!-- -Djavacpp.platform.sysroot=$(xcrun -sdk iphoneos -show-sdk-path) --> <javacpp.platform.sysroot/> <!-- -Djavacpp.platform.sysroot=$(xcrun -sdk iphoneos -show-sdk-path) -->
<javacpp.platform.extension/> <!-- -Djavacpp.platform.extension=-avx512 --> <javacpp.platform.extension/> <!-- -Djavacpp.platform.extension=-avx512 -->
<javacpp.platform.properties>${javacpp.platform}</javacpp.platform.properties> <javacpp.platform.properties>${javacpp.platform}</javacpp.platform.properties>
<javacpp.version>1.5.5</javacpp.version>
<javacpp-presets.version>1.5.5</javacpp-presets.version>
<javacv.version>1.5.5</javacv.version>
<javacpp.platform.additionalIncludePaths /> <javacpp.platform.additionalIncludePaths />
<javacpp.platform.cppincludepath /> <javacpp.platform.cppincludepath />
<javacpp.platform.library.path /> <javacpp.platform.library.path />
<!-- Used in nd4j-backend-impls for directory-maven-plugin to assist with finding native libs for tests --> <!-- Used in nd4j-backend-impls for directory-maven-plugin to assist with finding native libs for tests -->
<nd4j.native.basedir /> <nd4j.native.basedir />
<nd4j.cuda.basedir /> <nd4j.cuda.basedir />
<javacpp.version>1.5.4</javacpp.version>
<javacpp-presets.version>1.5.4</javacpp-presets.version>
<javacv.version>1.5.4</javacv.version>
<python.version>3.7.9</python.version>
<python.version>3.9.1</python.version>
<cpython-platform.version>${python.version}-${javacpp-presets.version}</cpython-platform.version> <cpython-platform.version>${python.version}-${javacpp-presets.version}</cpython-platform.version>
<numpy.version>1.19.1</numpy.version> <numpy.version>1.20.1</numpy.version>
<numpy.javacpp.version>${numpy.version}-${javacpp-presets.version}</numpy.javacpp.version> <numpy.javacpp.version>${numpy.version}-${javacpp-presets.version}</numpy.javacpp.version>
<openblas.version>0.3.10</openblas.version> <openblas.version>0.3.13</openblas.version>
<mkl.version>2021.1</mkl.version>
<mkl.version>2020.3</mkl.version> <opencv.version>4.5.1</opencv.version>
<opencv.version>4.4.0</opencv.version>
<ffmpeg.version>4.3.1</ffmpeg.version> <ffmpeg.version>4.3.1</ffmpeg.version>
<leptonica.version>1.80.0</leptonica.version> <leptonica.version>1.80.0</leptonica.version>
<hdf5.version>1.12.0</hdf5.version> <hdf5.version>1.12.0</hdf5.version>
<ale.version>0.6.1</ale.version> <ale.version>0.6.1</ale.version>
<gym.version>0.17.2</gym.version> <gym.version>0.18.0</gym.version>
<tensorflow.version>1.15.3</tensorflow.version> <tensorflow.version>1.15.5</tensorflow.version>
<tensorflow.javacpp.version>${tensorflow.version}-${javacpp-presets.version}</tensorflow.javacpp.version> <tensorflow.javacpp.version>${tensorflow.version}-${javacpp-presets.version}</tensorflow.javacpp.version>
<archunit.version>0.14.1</archunit.version> <archunit.version>0.14.1</archunit.version>
<commons-compress.version>1.18</commons-compress.version> <commons-compress.version>1.18</commons-compress.version>
<commonsmath.version>3.5</commonsmath.version> <commonsmath.version>3.5</commonsmath.version>