java cuda compilation fix

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-08-07 21:36:27 +03:00
parent 55066d9c41
commit 62a025439b
4 changed files with 29 additions and 15 deletions

View File

@ -168,7 +168,7 @@ if(CUDA_BLAS)
list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70) list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static -O3 --expt-extended-lambda -gencode arch=compute_35,code=sm_35 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_70,code=sm_70)
endif() endif()
else() else()
list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static --expt-extended-lambda -O3 -arch=compute_${COMPUTE} -code=sm_${COMPUTE}) list(APPEND CUDA_NVCC_FLAGS -DCUDA_10 ${EXPM} -w --cudart=static --expt-extended-lambda -O3 -Xfatbin -compress-all -arch=compute_${COMPUTE} -code=sm_${COMPUTE})
endif() endif()
elseif(CUDA_VERSION VERSION_GREATER "8.0") # cuda 9 elseif(CUDA_VERSION VERSION_GREATER "8.0") # cuda 9
if ("${COMPUTE}" STREQUAL "all") if ("${COMPUTE}" STREQUAL "all")

View File

@ -280,27 +280,27 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
@Override @Override
public INDArray create(long[] data, long[] shape, long[] stride, char order, DataType dataType, MemoryWorkspace workspace) { public INDArray create(long[] data, long[] shape, long[] stride, char order, DataType dataType, MemoryWorkspace workspace) {
return new JCublasNDArray(Nd4j.createTypedBuffer(data, dataType, workspace), shape, stride, order, dataType); return new JCublasNDArray(Nd4j.createTypedBuffer(data, dataType), shape, stride, order, dataType);
} }
@Override @Override
public INDArray create(int[] data, long[] shape, long[] stride, char order, DataType dataType, MemoryWorkspace workspace) { public INDArray create(int[] data, long[] shape, long[] stride, char order, DataType dataType, MemoryWorkspace workspace) {
return new JCublasNDArray(Nd4j.createTypedBuffer(data, dataType, workspace), shape, stride, order, dataType); return new JCublasNDArray(Nd4j.createTypedBuffer(data, dataType), shape, stride, order, dataType);
} }
@Override @Override
public INDArray create(short[] data, long[] shape, long[] stride, char order, DataType dataType, MemoryWorkspace workspace) { public INDArray create(short[] data, long[] shape, long[] stride, char order, DataType dataType, MemoryWorkspace workspace) {
return new JCublasNDArray(Nd4j.createTypedBuffer(data, dataType, workspace), shape, stride, order, dataType); return new JCublasNDArray(Nd4j.createTypedBuffer(data, dataType), shape, stride, order, dataType);
} }
@Override @Override
public INDArray create(byte[] data, long[] shape, long[] stride, char order, DataType dataType, MemoryWorkspace workspace) { public INDArray create(byte[] data, long[] shape, long[] stride, char order, DataType dataType, MemoryWorkspace workspace) {
return new JCublasNDArray(Nd4j.createTypedBuffer(data, dataType, workspace), shape, stride, order, dataType); return new JCublasNDArray(Nd4j.createTypedBuffer(data, dataType), shape, stride, order, dataType);
} }
@Override @Override
public INDArray create(boolean[] data, long[] shape, long[] stride, char order, DataType dataType, MemoryWorkspace workspace) { public INDArray create(boolean[] data, long[] shape, long[] stride, char order, DataType dataType, MemoryWorkspace workspace) {
return new JCublasNDArray(Nd4j.createTypedBuffer(data, dataType, workspace), shape, stride, order, dataType); return new JCublasNDArray(Nd4j.createTypedBuffer(data, dataType), shape, stride, order, dataType);
} }
@Override @Override
@ -1593,27 +1593,27 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
@Override @Override
public INDArray create(long[] data, long[] shape, long[] stride, DataType dataType, MemoryWorkspace workspace) { public INDArray create(long[] data, long[] shape, long[] stride, DataType dataType, MemoryWorkspace workspace) {
return new JCublasNDArray(Nd4j.createTypedBuffer(data, dataType, workspace), shape, stride, Nd4j.order(), dataType); return new JCublasNDArray(Nd4j.createTypedBuffer(data, dataType), shape, stride, Nd4j.order(), dataType);
} }
@Override @Override
public INDArray create(int[] data, long[] shape, long[] stride, DataType dataType, MemoryWorkspace workspace) { public INDArray create(int[] data, long[] shape, long[] stride, DataType dataType, MemoryWorkspace workspace) {
return new JCublasNDArray(Nd4j.createTypedBuffer(data, dataType, workspace), shape, stride, Nd4j.order(), dataType); return new JCublasNDArray(Nd4j.createTypedBuffer(data, dataType), shape, stride, Nd4j.order(), dataType);
} }
@Override @Override
public INDArray create(short[] data, long[] shape, long[] stride, DataType dataType, MemoryWorkspace workspace) { public INDArray create(short[] data, long[] shape, long[] stride, DataType dataType, MemoryWorkspace workspace) {
return new JCublasNDArray(Nd4j.createTypedBuffer(data, dataType, workspace), shape, stride, Nd4j.order(), dataType); return new JCublasNDArray(Nd4j.createTypedBuffer(data, dataType), shape, stride, Nd4j.order(), dataType);
} }
@Override @Override
public INDArray create(byte[] data, long[] shape, long[] stride, DataType dataType, MemoryWorkspace workspace) { public INDArray create(byte[] data, long[] shape, long[] stride, DataType dataType, MemoryWorkspace workspace) {
return new JCublasNDArray(Nd4j.createTypedBuffer(data, dataType, workspace), shape, stride, Nd4j.order(), dataType); return new JCublasNDArray(Nd4j.createTypedBuffer(data, dataType), shape, stride, Nd4j.order(), dataType);
} }
@Override @Override
public INDArray create(boolean[] data, long[] shape, long[] stride, DataType dataType, MemoryWorkspace workspace) { public INDArray create(boolean[] data, long[] shape, long[] stride, DataType dataType, MemoryWorkspace workspace) {
return new JCublasNDArray(Nd4j.createTypedBuffer(data, dataType, workspace), shape, stride, Nd4j.order(), dataType); return new JCublasNDArray(Nd4j.createTypedBuffer(data, dataType), shape, stride, Nd4j.order(), dataType);
} }
@Override @Override

View File

@ -3814,9 +3814,12 @@ public native @Cast("char*") String runFullBenchmarkSuit(@Cast("bool") boolean p
public native void syncToDevice(); public native void syncToDevice();
public native void syncShape(); public native void syncShape();
// #ifndef __JAVACPP_HACK__ /**
* This method can be used on architectures that use special buffers
* @param writeList
* @param readList
*/
// #endif
/** /**
* This method returns buffer pointer offset by given number of elements, wrt own data type * This method returns buffer pointer offset by given number of elements, wrt own data type
@ -5107,12 +5110,16 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) {
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
// #ifndef __JAVACPP_HACK__ // #ifndef __JAVACPP_HACK__
// #endif // #endif

View File

@ -3814,9 +3814,12 @@ public native @Cast("char*") String runFullBenchmarkSuit(@Cast("bool") boolean p
public native void syncToDevice(); public native void syncToDevice();
public native void syncShape(); public native void syncShape();
// #ifndef __JAVACPP_HACK__ /**
* This method can be used on architectures that use special buffers
* @param writeList
* @param readList
*/
// #endif
/** /**
* This method returns buffer pointer offset by given number of elements, wrt own data type * This method returns buffer pointer offset by given number of elements, wrt own data type
@ -5107,12 +5110,16 @@ NDArray& NDArray::operator()(const Nd4jLong* idx) {
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////
// #ifndef __JAVACPP_HACK__ // #ifndef __JAVACPP_HACK__
// #endif // #endif