Add nd4j-tvm module with initial inference support using TVM
Also update versions for JavaCPP and JavaCV to 1.5.5-SNAPSHOT.

Signed-off-by: Samuel Audet <samuel.audet@gmail.com>
master
parent
c523c4f0c7
commit
8891d4d3bc
|
@ -23,6 +23,10 @@
|
|||
#ifndef LIBND4J_BLAS_HELPER_H
|
||||
#define LIBND4J_BLAS_HELPER_H
|
||||
|
||||
// work around conflict with OpenBLAS
|
||||
struct bfloat16;
|
||||
#define BFLOAT16 BFLOAT16
|
||||
|
||||
#include <system/pointercast.h>
|
||||
#include <types/float16.h>
|
||||
#include <cblas.h>
|
||||
|
|
|
@ -23,6 +23,10 @@
|
|||
#ifndef LIBND4J_GEMM_H
|
||||
#define LIBND4J_GEMM_H
|
||||
|
||||
// work around conflict with OpenBLAS
|
||||
struct bfloat16;
|
||||
#define BFLOAT16 BFLOAT16
|
||||
|
||||
#include <cblas.h>
|
||||
#include <math/templatemath.h>
|
||||
#include <system/op_boilerplate.h>
|
||||
|
|
|
@ -139,6 +139,13 @@ import java.util.Scanner;
|
|||
"ops/declarable/headers/loss.h",
|
||||
"ops/declarable/headers/datatypes.h",
|
||||
"ops/declarable/headers/third_party.h",
|
||||
"openblas_config.h",
|
||||
"cblas.h",
|
||||
"lapacke_config.h",
|
||||
"lapacke_mangling.h",
|
||||
"lapack.h",
|
||||
"lapacke.h",
|
||||
"lapacke_utils.h",
|
||||
"cnpy/cnpy.h"
|
||||
},
|
||||
compiler = {"cpp11", "nowarnings"},
|
||||
|
@ -166,6 +173,7 @@ public class Nd4jCpuPresets implements InfoMapper, BuildEnabled {
|
|||
public void map(InfoMap infoMap) {
|
||||
infoMap.put(new Info("thread_local", "ND4J_EXPORT", "INLINEDEF", "CUBLASWINAPI", "FORCEINLINE",
|
||||
"_CUDA_H", "_CUDA_D", "_CUDA_G", "_CUDA_HD", "LIBND4J_ALL_OPS", "NOT_EXCLUDED").cppTypes().annotations())
|
||||
.put(new Info("openblas_config.h", "cblas.h", "lapacke_config.h", "lapacke_mangling.h", "lapack.h", "lapacke.h", "lapacke_utils.h").skip())
|
||||
.put(new Info("NativeOps.h", "build_info.h").objectify())
|
||||
.put(new Info("OpaqueTadPack").pointerTypes("OpaqueTadPack"))
|
||||
.put(new Info("OpaqueResultWrapper").pointerTypes("OpaqueResultWrapper"))
|
||||
|
|
|
@ -38,7 +38,7 @@
|
|||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
||||
<maven.compiler.source>1.8</maven.compiler.source>
|
||||
<maven.compiler.target>1.8</maven.compiler.target>
|
||||
<onnxruntime.version>1.4.0</onnxruntime.version>
|
||||
<onnxruntime.version>1.6.0</onnxruntime.version>
|
||||
<onnxruntime.javacpp.version>${onnxruntime.version}-${javacpp.version}</onnxruntime.javacpp.version>
|
||||
</properties>
|
||||
|
||||
|
|
|
@ -0,0 +1,81 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
|
||||
<!--
|
||||
~ /* ******************************************************************************
|
||||
~ *
|
||||
~ *
|
||||
~ * This program and the accompanying materials are made available under the
|
||||
~ * terms of the Apache License, Version 2.0 which is available at
|
||||
~ * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
~ *
|
||||
~ * See the NOTICE file distributed with this work for additional
|
||||
~ * information regarding copyright ownership.
|
||||
~ * Unless required by applicable law or agreed to in writing, software
|
||||
~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
~ * License for the specific language governing permissions and limitations
|
||||
~ * under the License.
|
||||
~ *
|
||||
~ * SPDX-License-Identifier: Apache-2.0
|
||||
~ ******************************************************************************/
|
||||
-->
|
||||
|
||||
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
||||
<parent>
|
||||
<artifactId>nd4j</artifactId>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<version>1.0.0-SNAPSHOT</version>
|
||||
</parent>
|
||||
<modelVersion>4.0.0</modelVersion>
|
||||
|
||||
<artifactId>nd4j-tvm</artifactId>
|
||||
<name>nd4j-tvm</name>
|
||||
|
||||
<properties>
|
||||
<tvm.version>0.7.0</tvm.version>
|
||||
<tvm.javacpp.version>${tvm.version}-${javacpp-presets.version}</tvm.javacpp.version>
|
||||
</properties>
|
||||
|
||||
<dependencies>
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>nd4j-api</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.bytedeco</groupId>
|
||||
<artifactId>mkl-platform-redist</artifactId>
|
||||
<version>${mkl.version}-${javacpp-presets.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.bytedeco</groupId>
|
||||
<artifactId>tvm-platform</artifactId>
|
||||
<version>${tvm.javacpp.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.bytedeco</groupId>
|
||||
<artifactId>tvm</artifactId>
|
||||
<version>${tvm.javacpp.version}</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>junit</groupId>
|
||||
<artifactId>junit</artifactId>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.nd4j</groupId>
|
||||
<artifactId>nd4j-native</artifactId>
|
||||
<version>${project.version}</version>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
<profiles>
|
||||
<profile>
|
||||
<id>testresources</id>
|
||||
</profile>
|
||||
</profiles>
|
||||
</project>
|
|
@ -0,0 +1,164 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
package org.nd4j.tvm.runner;
|
||||
|
||||
import java.io.Closeable;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.Map;
|
||||
import lombok.Builder;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.bytedeco.javacpp.*;
|
||||
import org.bytedeco.tvm.*;
|
||||
import org.nd4j.common.base.Preconditions;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.tvm.util.TVMUtils;
|
||||
|
||||
import static org.bytedeco.tvm.global.tvm_runtime.*;
|
||||
import static org.nd4j.tvm.util.TVMUtils.*;
|
||||
|
||||
@Slf4j
|
||||
public class TvmRunner implements Closeable {
|
||||
private static DLContext ctx;
|
||||
private Module modFactory;
|
||||
private TVMValue values;
|
||||
private IntPointer codes;
|
||||
private TVMArgsSetter setter;
|
||||
private TVMRetValue rv;
|
||||
private Module gmod;
|
||||
private PackedFunc getNumInputs;
|
||||
private PackedFunc getNumOutputs;
|
||||
private PackedFunc setInput;
|
||||
private PackedFunc getOutput;
|
||||
private PackedFunc run;
|
||||
|
||||
@Builder
|
||||
public TvmRunner(String modelUri) {
|
||||
if (ctx == null) {
|
||||
ctx = new DLContext().device_type(kDLCPU).device_id(0);
|
||||
ctx.retainReference();
|
||||
}
|
||||
|
||||
// create the runtime module
|
||||
try (PointerScope scope = new PointerScope()) {
|
||||
modFactory = Module.LoadFromFile(modelUri);
|
||||
values = new TVMValue(2);
|
||||
codes = new IntPointer(2);
|
||||
setter = new TVMArgsSetter(values, codes);
|
||||
setter.apply(0, ctx);
|
||||
rv = new TVMRetValue();
|
||||
modFactory.GetFunction("default").CallPacked(new TVMArgs(values, codes, 1), rv);
|
||||
gmod = rv.asModule();
|
||||
getNumInputs = gmod.GetFunction("get_num_inputs");
|
||||
getNumOutputs = gmod.GetFunction("get_num_outputs");
|
||||
setInput = gmod.GetFunction("set_input");
|
||||
getOutput = gmod.GetFunction("get_output");
|
||||
run = gmod.GetFunction("run");
|
||||
// retain the session reference to prevent pre emptive release of the session.
|
||||
modFactory.retainReference();
|
||||
values.retainReference();
|
||||
codes.retainReference();
|
||||
setter.retainReference();
|
||||
rv.retainReference();
|
||||
gmod.retainReference();
|
||||
getNumInputs.retainReference();
|
||||
getNumOutputs.retainReference();
|
||||
setInput.retainReference();
|
||||
getOutput.retainReference();
|
||||
run.retainReference();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() {
|
||||
if (run != null) {
|
||||
run.releaseReference();
|
||||
}
|
||||
if (getOutput != null) {
|
||||
getOutput.releaseReference();
|
||||
}
|
||||
if (setInput != null) {
|
||||
setInput.releaseReference();
|
||||
}
|
||||
if (getNumOutputs != null) {
|
||||
getNumOutputs.releaseReference();
|
||||
}
|
||||
if (getNumInputs != null) {
|
||||
getNumInputs.releaseReference();
|
||||
}
|
||||
if (gmod != null) {
|
||||
gmod.releaseReference();
|
||||
}
|
||||
if (rv != null) {
|
||||
rv.releaseReference();
|
||||
}
|
||||
if (setter != null) {
|
||||
setter.releaseReference();
|
||||
}
|
||||
if (codes != null) {
|
||||
codes.releaseReference();
|
||||
}
|
||||
if (values != null) {
|
||||
values.releaseReference();
|
||||
}
|
||||
if (modFactory != null) {
|
||||
modFactory.releaseReference();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute the {@link #run} function
|
||||
* using the given input {@link Map}
|
||||
* @param input the input map
|
||||
* @return a map of the names of the ndarrays
|
||||
*/
|
||||
public Map<String,INDArray> exec(Map<String,INDArray> input) {
|
||||
try (PointerScope scope = new PointerScope()) {
|
||||
getNumInputs.CallPacked(new TVMArgs(values, codes, 0), rv);
|
||||
long numInputNodes = rv.asLong();
|
||||
getNumOutputs.CallPacked(new TVMArgs(values, codes, 0), rv);
|
||||
long numOutputNodes = rv.asLong();
|
||||
|
||||
// set the right input
|
||||
for (Map.Entry<String,INDArray> e : input.entrySet()) {
|
||||
String name = e.getKey();
|
||||
INDArray arr = e.getValue();
|
||||
DLTensor inputTensor = getTensor(arr, ctx);
|
||||
Preconditions.checkState(inputTensor != null,"Input must be a tensor.");
|
||||
setter.apply(0, new BytePointer(name));
|
||||
setter.apply(1, inputTensor);
|
||||
setInput.CallPacked(new TVMArgs(values, codes, 2), rv);
|
||||
}
|
||||
|
||||
// run the code
|
||||
run.CallPacked(new TVMArgs(values, codes, 0), rv);
|
||||
|
||||
Map<String, INDArray> ret = new LinkedHashMap<>();
|
||||
|
||||
// get the output
|
||||
for (int i = 0; i < numOutputNodes; i++) {
|
||||
setter.apply(0, i);
|
||||
getOutput.CallPacked(new TVMArgs(values, codes, 1), rv);
|
||||
DLTensor outputTensor = rv.asDLTensor();
|
||||
ret.put(Integer.toString(i), getArray(outputTensor));
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,233 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
package org.nd4j.tvm.util;
|
||||
|
||||
import org.bytedeco.javacpp.*;
|
||||
import org.bytedeco.javacpp.indexer.*;
|
||||
import org.bytedeco.tvm.*;
|
||||
import org.nd4j.common.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import static org.bytedeco.tvm.global.tvm_runtime.*;
|
||||
import static org.nd4j.linalg.api.buffer.DataType.*;
|
||||
|
||||
public class TVMUtils {
|
||||
|
||||
/**
|
||||
* Return a {@link DataType}
|
||||
* for the tvm data type
|
||||
* @param dataType the equivalent nd4j data type
|
||||
* @return
|
||||
*/
|
||||
public static DataType dataTypeForTvmType(DLDataType dataType) {
|
||||
if(dataType.code() == kDLInt && dataType.bits() == 8) {
|
||||
return INT8;
|
||||
} else if(dataType.code() == kDLInt && dataType.bits() == 16) {
|
||||
return INT16;
|
||||
} else if(dataType.code() == kDLInt && dataType.bits() == 32) {
|
||||
return INT32;
|
||||
} else if(dataType.code() == kDLInt && dataType.bits() == 64) {
|
||||
return INT64;
|
||||
} else if(dataType.code() == kDLUInt && dataType.bits() == 8) {
|
||||
return UINT8;
|
||||
} else if(dataType.code() == kDLUInt && dataType.bits() == 16) {
|
||||
return UINT16;
|
||||
} else if(dataType.code() == kDLUInt && dataType.bits() == 32) {
|
||||
return UINT32;
|
||||
} else if(dataType.code() == kDLUInt && dataType.bits() == 64) {
|
||||
return UINT64;
|
||||
} else if(dataType.code() == kDLFloat && dataType.bits() == 16) {
|
||||
return FLOAT16;
|
||||
} else if(dataType.code() == kDLFloat && dataType.bits() == 32) {
|
||||
return FLOAT;
|
||||
} else if(dataType.code() == kDLFloat && dataType.bits() == 64) {
|
||||
return DOUBLE;
|
||||
} else if(dataType.code() == kDLBfloat && dataType.bits() == 16) {
|
||||
return BFLOAT16;
|
||||
} else
|
||||
throw new IllegalArgumentException("Illegal data type code " + dataType.code() + " with bits " + dataType.bits());
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert the tvm type for the given data type
|
||||
* @param dataType
|
||||
* @return
|
||||
*/
|
||||
public static DLDataType tvmTypeForDataType(DataType dataType) {
|
||||
if(dataType == INT8) {
|
||||
return new DLDataType().code((byte)kDLInt).bits((byte)8).lanes((short)1);
|
||||
} else if(dataType == INT16) {
|
||||
return new DLDataType().code((byte)kDLInt).bits((byte)16).lanes((short)1);
|
||||
} else if(dataType == INT32) {
|
||||
return new DLDataType().code((byte)kDLInt).bits((byte)32).lanes((short)1);
|
||||
} else if(dataType == INT64) {
|
||||
return new DLDataType().code((byte)kDLInt).bits((byte)64).lanes((short)1);
|
||||
} else if(dataType == UINT8) {
|
||||
return new DLDataType().code((byte)kDLUInt).bits((byte)8).lanes((short)1);
|
||||
} else if(dataType == UINT16) {
|
||||
return new DLDataType().code((byte)kDLUInt).bits((byte)16).lanes((short)1);
|
||||
} else if(dataType == UINT32) {
|
||||
return new DLDataType().code((byte)kDLUInt).bits((byte)32).lanes((short)1);
|
||||
} else if(dataType == UINT64) {
|
||||
return new DLDataType().code((byte)kDLUInt).bits((byte)64).lanes((short)1);
|
||||
} else if(dataType == FLOAT16) {
|
||||
return new DLDataType().code((byte)kDLFloat).bits((byte)16).lanes((short)1);
|
||||
} else if(dataType == FLOAT) {
|
||||
return new DLDataType().code((byte)kDLFloat).bits((byte)32).lanes((short)1);
|
||||
} else if(dataType == DOUBLE) {
|
||||
return new DLDataType().code((byte)kDLFloat).bits((byte)64).lanes((short)1);
|
||||
} else if(dataType == BFLOAT16) {
|
||||
return new DLDataType().code((byte)kDLBfloat).bits((byte)16).lanes((short)1);
|
||||
} else
|
||||
throw new IllegalArgumentException("Illegal data type " + dataType);
|
||||
}
|
||||
|
||||
/**
|
||||
* Convert an tvm {@link DLTensor}
|
||||
* in to an {@link INDArray}
|
||||
* @param value the tensor to convert
|
||||
* @return
|
||||
*/
|
||||
public static INDArray getArray(DLTensor value) {
|
||||
DataType dataType = dataTypeForTvmType(value.dtype());
|
||||
LongPointer shape = value.shape();
|
||||
LongPointer stride = value.strides();
|
||||
long[] shapeConvert;
|
||||
if(shape != null) {
|
||||
shapeConvert = new long[value.ndim()];
|
||||
shape.get(shapeConvert);
|
||||
} else {
|
||||
shapeConvert = new long[]{1};
|
||||
}
|
||||
long[] strideConvert;
|
||||
if(stride != null) {
|
||||
strideConvert = new long[value.ndim()];
|
||||
stride.get(strideConvert);
|
||||
} else {
|
||||
strideConvert = Nd4j.getStrides(shapeConvert);
|
||||
}
|
||||
long size = 1;
|
||||
for (int i = 0; i < shapeConvert.length; i++) {
|
||||
size *= shapeConvert[i];
|
||||
}
|
||||
size *= value.dtype().bits() / 8;
|
||||
|
||||
DataBuffer getBuffer = getDataBuffer(value,size);
|
||||
Preconditions.checkState(dataType.equals(getBuffer.dataType()),"Data type must be equivalent as specified by the tvm metadata.");
|
||||
return Nd4j.create(getBuffer,shapeConvert,strideConvert,0);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get an tvm tensor from an ndarray.
|
||||
* @param ndArray the ndarray to get the value from
|
||||
* @param ctx the {@link DLContext} to use.
|
||||
* @return
|
||||
*/
|
||||
public static DLTensor getTensor(INDArray ndArray, DLContext ctx) {
|
||||
DLTensor ret = new DLTensor();
|
||||
ret.data(ndArray.data().pointer());
|
||||
ret.ctx(ctx);
|
||||
ret.ndim(ndArray.rank());
|
||||
ret.dtype(tvmTypeForDataType(ndArray.dataType()));
|
||||
ret.shape(new LongPointer(ndArray.shape()));
|
||||
ret.strides(new LongPointer(ndArray.stride()));
|
||||
ret.byte_offset(ndArray.offset());
|
||||
return ret;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the data buffer from the given value
|
||||
* @param tens the values to get
|
||||
* @return the equivalent data buffer
|
||||
*/
|
||||
public static DataBuffer getDataBuffer(DLTensor tens, long size) {
|
||||
DataBuffer buffer = null;
|
||||
DataType type = dataTypeForTvmType(tens.dtype());
|
||||
switch (type) {
|
||||
case BYTE:
|
||||
BytePointer pInt8 = new BytePointer(tens.data()).capacity(size);
|
||||
Indexer int8Indexer = ByteIndexer.create(pInt8);
|
||||
buffer = Nd4j.createBuffer(pInt8, type, size, int8Indexer);
|
||||
break;
|
||||
case SHORT:
|
||||
ShortPointer pInt16 = new ShortPointer(tens.data()).capacity(size);
|
||||
Indexer int16Indexer = ShortIndexer.create(pInt16);
|
||||
buffer = Nd4j.createBuffer(pInt16, type, size, int16Indexer);
|
||||
break;
|
||||
case INT:
|
||||
IntPointer pInt32 = new IntPointer(tens.data()).capacity(size);
|
||||
Indexer int32Indexer = IntIndexer.create(pInt32);
|
||||
buffer = Nd4j.createBuffer(pInt32, type, size, int32Indexer);
|
||||
break;
|
||||
case LONG:
|
||||
LongPointer pInt64 = new LongPointer(tens.data()).capacity(size);
|
||||
Indexer int64Indexer = LongIndexer.create(pInt64);
|
||||
buffer = Nd4j.createBuffer(pInt64, type, size, int64Indexer);
|
||||
break;
|
||||
case UBYTE:
|
||||
BytePointer pUint8 = new BytePointer(tens.data()).capacity(size);
|
||||
Indexer uint8Indexer = UByteIndexer.create(pUint8);
|
||||
buffer = Nd4j.createBuffer(pUint8, type, size, uint8Indexer);
|
||||
break;
|
||||
case UINT16:
|
||||
ShortPointer pUint16 = new ShortPointer(tens.data()).capacity(size);
|
||||
Indexer uint16Indexer = UShortIndexer.create(pUint16);
|
||||
buffer = Nd4j.createBuffer(pUint16, type, size, uint16Indexer);
|
||||
break;
|
||||
case UINT32:
|
||||
IntPointer pUint32 = new IntPointer(tens.data()).capacity(size);
|
||||
Indexer uint32Indexer = UIntIndexer.create(pUint32);
|
||||
buffer = Nd4j.createBuffer(pUint32, type, size, uint32Indexer);
|
||||
break;
|
||||
case UINT64:
|
||||
LongPointer pUint64 = new LongPointer(tens.data()).capacity(size);
|
||||
Indexer uint64Indexer = LongIndexer.create(pUint64);
|
||||
buffer = Nd4j.createBuffer(pUint64, type, size, uint64Indexer);
|
||||
break;
|
||||
case HALF:
|
||||
ShortPointer pFloat16 = new ShortPointer(tens.data()).capacity(size);
|
||||
Indexer float16Indexer = HalfIndexer.create(pFloat16);
|
||||
buffer = Nd4j.createBuffer(pFloat16, type, size, float16Indexer);
|
||||
break;
|
||||
case FLOAT:
|
||||
FloatPointer pFloat = new FloatPointer(tens.data()).capacity(size);
|
||||
FloatIndexer floatIndexer = FloatIndexer.create(pFloat);
|
||||
buffer = Nd4j.createBuffer(pFloat, type, size, floatIndexer);
|
||||
break;
|
||||
case DOUBLE:
|
||||
DoublePointer pDouble = new DoublePointer(tens.data()).capacity(size);
|
||||
Indexer doubleIndexer = DoubleIndexer.create(pDouble);
|
||||
buffer = Nd4j.createBuffer(pDouble, type, size, doubleIndexer);
|
||||
break;
|
||||
case BFLOAT16:
|
||||
ShortPointer pBfloat16 = new ShortPointer(tens.data()).capacity(size);
|
||||
Indexer bfloat16Indexer = Bfloat16Indexer.create(pBfloat16);
|
||||
buffer = Nd4j.createBuffer(pBfloat16, type, size, bfloat16Indexer);
|
||||
break;
|
||||
default:
|
||||
throw new RuntimeException("Unsupported data type encountered");
|
||||
}
|
||||
return buffer;
|
||||
}
|
||||
|
||||
}
|
|
@ -0,0 +1,101 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
package org.nd4j.tvm.runner;
|
||||
|
||||
import org.bytedeco.javacpp.*;
|
||||
import org.bytedeco.cpython.*;
|
||||
import org.bytedeco.numpy.*;
|
||||
import org.bytedeco.tvm.*;
|
||||
import org.junit.Rule;
|
||||
import org.junit.Test;
|
||||
import org.junit.rules.TemporaryFolder;
|
||||
import org.nd4j.common.io.ClassPathResource;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.Arrays;
|
||||
import java.util.LinkedHashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.bytedeco.cpython.global.python.*;
|
||||
import static org.bytedeco.numpy.global.numpy.*;
|
||||
import static org.bytedeco.tvm.global.tvm_runtime.*;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
|
||||
public class TvmRunnerTests {
|
||||
|
||||
@Rule
|
||||
public TemporaryFolder testDir = new TemporaryFolder();
|
||||
|
||||
static void PrepareTestLibs(String libPath) throws Exception {
|
||||
Py_AddPath(org.bytedeco.tvm.presets.tvm.cachePackages());
|
||||
Py_Initialize();
|
||||
if (_import_array() < 0) {
|
||||
System.err.println("numpy.core.multiarray failed to import");
|
||||
PyErr_Print();
|
||||
System.exit(-1);
|
||||
}
|
||||
PyObject globals = PyModule_GetDict(PyImport_AddModule("__main__"));
|
||||
|
||||
PyRun_StringFlags("\"\"\"Script to prepare test_relay_add.so\"\"\"\n"
|
||||
+ "import tvm\n"
|
||||
+ "import numpy as np\n"
|
||||
+ "from tvm import relay\n"
|
||||
+ "import os\n"
|
||||
|
||||
+ "x = relay.var(\"x\", shape=(1, 1), dtype=\"float32\")\n"
|
||||
+ "y = relay.var(\"y\", shape=(1, 1), dtype=\"float32\")\n"
|
||||
+ "params = {\"y\": np.ones((1, 1), dtype=\"float32\")}\n"
|
||||
+ "mod = tvm.IRModule.from_expr(relay.Function([x, y], x + y))\n"
|
||||
+ "# build a module\n"
|
||||
+ "compiled_lib = relay.build(mod, tvm.target.create(\"llvm\"), params=params)\n"
|
||||
+ "# export it as a shared library\n"
|
||||
+ "dylib_path = os.path.join(\"" + libPath + "\", \"test_relay_add.so\")\n"
|
||||
+ "compiled_lib.export_library(dylib_path)\n",
|
||||
|
||||
Py_file_input, globals, globals, null);
|
||||
|
||||
if (PyErr_Occurred() != null) {
|
||||
System.err.println("Python error occurred");
|
||||
PyErr_Print();
|
||||
System.exit(-1);
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAdd() throws Exception {
|
||||
/* try to use MKL when available */
|
||||
System.setProperty("org.bytedeco.openblas.load", "mkl");
|
||||
|
||||
File libPath = testDir.newFolder("lib");
|
||||
PrepareTestLibs(libPath.getAbsolutePath().replace(File.separatorChar, '/'));
|
||||
File f = new File(libPath, "test_relay_add.so");
|
||||
INDArray x = Nd4j.scalar(1.0f).reshape(1,1);
|
||||
TvmRunner tvmRunner = TvmRunner.builder()
|
||||
.modelUri(f.getAbsolutePath())
|
||||
.build();
|
||||
Map<String,INDArray> inputs = new LinkedHashMap<>();
|
||||
inputs.put("x",x);
|
||||
Map<String, INDArray> exec = tvmRunner.exec(inputs);
|
||||
INDArray z = exec.get("0");
|
||||
assertEquals(2.0,z.sumNumber().doubleValue(),1e-1);
|
||||
}
|
||||
}
|
|
@ -46,6 +46,7 @@
|
|||
<module>nd4j-parameter-server-parent</module>
|
||||
<module>nd4j-tensorflow</module>
|
||||
<module>nd4j-onnxruntime</module>
|
||||
<module>nd4j-tvm</module>
|
||||
<module>nd4j-common-tests</module>
|
||||
<module>samediff-import</module>
|
||||
</modules>
|
||||
|
|
22
pom.xml
22
pom.xml
|
@ -185,26 +185,26 @@
|
|||
<javacpp.platform.extension/> <!-- -Djavacpp.platform.extension=-avx512 -->
|
||||
<javacpp.platform.properties>${javacpp.platform}</javacpp.platform.properties>
|
||||
|
||||
<javacpp.version>1.5.4</javacpp.version>
|
||||
<javacpp-presets.version>1.5.4</javacpp-presets.version>
|
||||
<javacv.version>1.5.4</javacv.version>
|
||||
<javacpp.version>1.5.5-SNAPSHOT</javacpp.version>
|
||||
<javacpp-presets.version>1.5.5-SNAPSHOT</javacpp-presets.version>
|
||||
<javacv.version>1.5.5-SNAPSHOT</javacv.version>
|
||||
|
||||
<python.version>3.7.9</python.version>
|
||||
<python.version>3.9.1</python.version>
|
||||
<cpython-platform.version>${python.version}-${javacpp-presets.version}</cpython-platform.version>
|
||||
<numpy.version>1.19.1</numpy.version>
|
||||
<numpy.version>1.20.1</numpy.version>
|
||||
<numpy.javacpp.version>${numpy.version}-${javacpp-presets.version}</numpy.javacpp.version>
|
||||
|
||||
<openblas.version>0.3.10</openblas.version>
|
||||
|
||||
<mkl.version>2020.3</mkl.version>
|
||||
<opencv.version>4.4.0</opencv.version>
|
||||
<openblas.version>0.3.13</openblas.version>
|
||||
<mkl.version>2021.1</mkl.version>
|
||||
<opencv.version>4.5.1</opencv.version>
|
||||
<ffmpeg.version>4.3.1</ffmpeg.version>
|
||||
<leptonica.version>1.80.0</leptonica.version>
|
||||
<hdf5.version>1.12.0</hdf5.version>
|
||||
<ale.version>0.6.1</ale.version>
|
||||
<gym.version>0.17.2</gym.version>
|
||||
<tensorflow.version>1.15.3</tensorflow.version>
|
||||
<gym.version>0.18.0</gym.version>
|
||||
<tensorflow.version>1.15.5</tensorflow.version>
|
||||
<tensorflow.javacpp.version>${tensorflow.version}-${javacpp-presets.version}</tensorflow.javacpp.version>
|
||||
|
||||
<archunit.version>0.14.1</archunit.version>
|
||||
<commons-compress.version>1.18</commons-compress.version>
|
||||
<commonsmath.version>3.5</commonsmath.version>
|
||||
|
|
Loading…
Reference in New Issue