[WIP] Last set of changes (#234)

* mmul op instead of cublasSgemm

Signed-off-by: raver119 <raver119@gmail.com>

* transB

Signed-off-by: raver119 <raver119@gmail.com>

* jcpp handles

Signed-off-by: raver119 <raver119@gmail.com>

* bitwise and/or/xor

Signed-off-by: raver119 <raver119@gmail.com>

* bitwise and/or/xor mapping

Signed-off-by: raver119 <raver119@gmail.com>

* cuda/cublas version check

Signed-off-by: raver119 <raver119@gmail.com>

* add expected version

Signed-off-by: raver119 <raver119@gmail.com>

* cuda/cublas version check in java

Signed-off-by: raver119 <raver119@gmail.com>

* one more error check

Signed-off-by: raver119 <raver119@gmail.com>

* build fix

Signed-off-by: raver119 <raver119@gmail.com>

* build fix

Signed-off-by: raver119 <raver119@gmail.com>

* build fix

Signed-off-by: raver119 <raver119@gmail.com>

* one more fix

Signed-off-by: raver119 <raver119@gmail.com>

* skip CUDA version check for now

Signed-off-by: raver119 <raver119@gmail.com>

* better wording

Signed-off-by: raver119 <raver119@gmail.com>

* few more tweaks

Signed-off-by: raver119 <raver119@gmail.com>

* few more tweaks

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-09-04 14:41:08 +03:00 committed by GitHub
parent d41018751b
commit a90c7dd995
25 changed files with 646 additions and 16 deletions

View File

@@ -0,0 +1,40 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#ifndef SAMEDIFF_BLASVERSIONHELPER_H
#define SAMEDIFF_BLASVERSIONHELPER_H
#include <dll.h>
#include <cuda.h>
#include <cuda_runtime.h>
namespace nd4j {
class ND4J_EXPORT BlasVersionHelper {
public:
int _blasMajorVersion = 0;
int _blasMinorVersion = 0;
int _blasPatchVersion = 0;
BlasVersionHelper();
~BlasVersionHelper() = default;
};
}
#endif //SAMEDIFF_BLASVERSIONHELPER_H

View File

@@ -253,20 +253,20 @@ if(CUDA_BLAS)
file(GLOB_RECURSE LOOPS_SOURCES_CUDA false ../include/loops/*.cu)
if (NOT BUILD_TESTS)
CUDA_ADD_LIBRARY(${LIBND4J_NAME} SHARED cuda/NativeOps.cu cuda/NativeOpExecutioner.cu ${LOOPS_SOURCES_CUDA}
CUDA_ADD_LIBRARY(${LIBND4J_NAME} SHARED cuda/NativeOps.cu cuda/NativeOpExecutioner.cu cuda/BlasVersionHelper.cu Environment.cpp ${LOOPS_SOURCES_CUDA}
${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES}
../include/cnpy/cnpy.cpp ../include/nd4jmemset.h ../include/nd4jmalloc.h
cpu/GraphExecutioner.cpp cuda/NDArray.cu cpu/NDArrayFactory.cpp
Environment.cpp Environment.h ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
Environment.h ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES})
else()
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DBUILD_TESTS=true")
CUDA_ADD_LIBRARY(${LIBND4J_NAME} STATIC cuda/NativeOps.cu cuda/NativeOpExecutioner.cu ${LOOPS_SOURCES_CUDA}
CUDA_ADD_LIBRARY(${LIBND4J_NAME} STATIC cuda/NativeOps.cu cuda/NativeOpExecutioner.cu cuda/BlasVersionHelper.cu Environment.cpp ${LOOPS_SOURCES_CUDA}
${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES}
../include/cnpy/cnpy.cpp ../include/nd4jmemset.h ../include/nd4jmalloc.h
cpu/GraphExecutioner.cpp cuda/NDArray.cu cpu/NDArrayFactory.cpp
Environment.cpp Environment.h ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
Environment.h ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES})
endif()

View File

@@ -35,7 +35,7 @@
#include <cuda.h>
#include <cuda_runtime.h>
#include "BlasVersionHelper.h"
#endif
namespace nd4j {
@@ -66,6 +66,13 @@ namespace nd4j {
#endif
#ifdef __CUDABLAS__
BlasVersionHelper ver;
_blasMajorVersion = ver._blasMajorVersion;
_blasMinorVersion = ver._blasMinorVersion;
_blasPatchVersion = ver._blasPatchVersion;
printf("ND4J CUDA build version: %i.%i.%i\n", _blasMajorVersion, _blasMinorVersion, _blasPatchVersion);
fflush(stdout);
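// Example output (illustrative assumption, not part of the diff: a binary
// built with an nvcc 10.1.243 toolchain would print "ND4J CUDA build version: 10.1.243").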
int devCnt = 0;
cudaGetDeviceCount(&devCnt);
auto devProperties = new cudaDeviceProp[devCnt];

View File

@@ -56,6 +56,13 @@ namespace nd4j{
Environment();
~Environment();
public:
/**
* These 3 fields are mostly for CUDA/cuBLAS version tracking
*/
int _blasMajorVersion = 0;
int _blasMinorVersion = 0;
int _blasPatchVersion = 0;
static Environment* getInstance();
bool isVerbose();

View File

@@ -647,7 +647,7 @@ ND4J_EXPORT void setOmpNumThreads(int threads);
ND4J_EXPORT void setOmpMinThreads(int threads);
ND4J_EXPORT bool isBlasVersionMatches(int major, int minor, int build);
/**
*

View File

@@ -728,6 +728,10 @@ void execReduce3Tad(Nd4jPointer *extraPointers,
}
}
bool isBlasVersionMatches(int major, int minor, int build) {
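// CPU backend: no CUDA/cuBLAS is involved here, so the version check trivially passes.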
return true;
}
/**
*
* @param opNum

View File

@@ -0,0 +1,29 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include "../BlasVersionHelper.h"
namespace nd4j {
BlasVersionHelper::BlasVersionHelper() {
_blasMajorVersion = __CUDACC_VER_MAJOR__;
_blasMinorVersion = __CUDACC_VER_MINOR__;
_blasPatchVersion = __CUDACC_VER_BUILD__;
}
}

View File

@@ -3357,6 +3357,18 @@ void deleteTadPack(nd4j::TadPack* ptr) {
delete ptr;
}
bool isBlasVersionMatches(int major, int minor, int build) {
auto result = major == Environment::getInstance()->_blasMajorVersion && minor == Environment::getInstance()->_blasMinorVersion && build == Environment::getInstance()->_blasPatchVersion;
if (!result) {
nd4j_printf("CUDA/cuBLAS version mismatch. Expected: %i.%i.%i but got %i.%i.%i instead\n", Environment::getInstance()->_blasMajorVersion, Environment::getInstance()->_blasMinorVersion, Environment::getInstance()->_blasPatchVersion, major, minor, build);
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(152);
nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage("CUDA/cuBLAS version mismatch");
}
return result;
}
nd4j::ConstantDataBuffer* constantBufferLong(nd4j::DataType dtype, Nd4jLong *data, int length) {
return nd4j::ConstantHelper::getInstance()->constantBuffer(ConstantDescriptor(data, length), dtype);
}
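For reference, a hedged Java-side sketch of how this native check is meant to be consumed. It mirrors the JavaCPP call pattern of the disabled block in JCublasNDArrayFactory further down in this diff; nativeOps stands for an assumed NativeOps handle and is not itself part of the diff:

// Illustrative only: query the cuBLAS runtime version (properties 0/1/2 =
// MAJOR_VERSION, MINOR_VERSION, PATCH_LEVEL), then ask the native library
// whether it matches the version libnd4j was built against.
val major = new int[1]; val minor = new int[1]; val build = new int[1];
org.bytedeco.cuda.global.cublas.cublasGetProperty(0, major);
org.bytedeco.cuda.global.cublas.cublasGetProperty(1, minor);
org.bytedeco.cuda.global.cublas.cublasGetProperty(2, build);
if (!nativeOps.isBlasVersionMatches(major[0], minor[0], build[0]))
    throw new RuntimeException(nativeOps.lastErrorMessage()); // native side sets error code 152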

View File

@@ -0,0 +1,50 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_bitwise_and)
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/helpers.h>
#include <ops/declarable/helpers/shift.h>
namespace nd4j {
namespace ops {
BROADCASTABLE_OP_IMPL(bitwise_and, 0, 0) {
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0);
BROADCAST_CHECK_EMPTY(x,y,z);
x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntAnd, pairwise::IntOps::IntAnd, broadcast::IntOps::IntAnd), y, z, false);
return Status::OK();
}
DECLARE_TYPES(bitwise_and) {
getOpDescriptor()
->setAllowedInputTypes({ALL_INTS})
->setSameMode(true);
}
}
}
#endif

View File

@@ -0,0 +1,50 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_bitwise_or)
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/helpers.h>
#include <ops/declarable/helpers/shift.h>
namespace nd4j {
namespace ops {
BROADCASTABLE_OP_IMPL(bitwise_or, 0, 0) {
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0);
BROADCAST_CHECK_EMPTY(x,y,z);
x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntOr, pairwise::IntOps::IntOr, broadcast::IntOps::IntOr), y, z, false);
return Status::OK();
}
DECLARE_TYPES(bitwise_or) {
getOpDescriptor()
->setAllowedInputTypes({ALL_INTS})
->setSameMode(true);
}
}
}
#endif

View File

@@ -0,0 +1,50 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_bitwise_xor)
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/helpers.h>
#include <ops/declarable/helpers/shift.h>
namespace nd4j {
namespace ops {
BROADCASTABLE_OP_IMPL(bitwise_xor, 0, 0) {
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
auto z = OUTPUT_VARIABLE(0);
BROADCAST_CHECK_EMPTY(x,y,z);
x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntXor, pairwise::IntOps::IntXor, broadcast::IntOps::IntXor), y, z, false);
return Status::OK();
}
DECLARE_TYPES(bitwise_xor) {
getOpDescriptor()
->setAllowedInputTypes({ALL_INTS})
->setSameMode(true);
}
}
}
#endif

View File

@@ -81,6 +81,39 @@ namespace nd4j {
DECLARE_BROADCASTABLE_OP(cyclic_rshift_bits, 0, 0);
#endif
/**
* This operation applies bitwise AND
*
* PLEASE NOTE: This operation is applicable only to integer data types
*
* @tparam T
*/
#if NOT_EXCLUDED(OP_bitwise_and)
DECLARE_BROADCASTABLE_OP(bitwise_and, 0, 0);
#endif
/**
* This operation applies bitwise OR
*
* PLEASE NOTE: This operation is applicable only to integer data types
*
* @tparam T
*/
#if NOT_EXCLUDED(OP_bitwise_or)
DECLARE_BROADCASTABLE_OP(bitwise_or, 0, 0);
#endif
/**
* This operation applies bitwise XOR
*
* PLEASE NOTE: This operation is applicable only to integer data types
*
* @tparam T
*/
#if NOT_EXCLUDED(OP_bitwise_xor)
DECLARE_BROADCASTABLE_OP(bitwise_xor, 0, 0);
#endif
/**
* This operation returns hamming distance based on bits
*

View File

@@ -353,6 +353,9 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.transforms.custom.Choose.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseAnd.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseXor.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseOr.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits.class,
org.nd4j.linalg.api.ops.impl.transforms.custom.Dilation2D.class,

View File

@@ -0,0 +1,78 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
import java.util.Collections;
import java.util.List;
/**
* Bit-wise AND operation, broadcastable
*
* @author raver119@gmail.com
*/
public class BitwiseAnd extends BaseDynamicTransformOp {
public BitwiseAnd(SameDiff sameDiff, SDVariable x, SDVariable y) {
super(sameDiff, new SDVariable[] {x, y}, false);
}
public BitwiseAnd(INDArray x, INDArray y, INDArray output) {
super(new INDArray[]{x, y}, new INDArray[]{output});
}
public BitwiseAnd(INDArray x, INDArray y) {
this(x, y, x.ulike());
}
public BitwiseAnd() {}
@Override
public String opName() {
return "bitwise_and";
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName() {
return "bitwise_and";
}
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
throw new UnsupportedOperationException("Not yet implemented: " + opName());
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
Preconditions.checkState(dataTypes.get(0).isIntType(), "Input 0 datatype must be an integer type, got %s", dataTypes.get(0));
return Collections.singletonList(dataTypes.get(0));
}
}

View File

@@ -0,0 +1,78 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
import java.util.Collections;
import java.util.List;
/**
* Bit-wise OR operation, broadcastable
*
* @author raver119@gmail.com
*/
public class BitwiseOr extends BaseDynamicTransformOp {
public BitwiseOr(SameDiff sameDiff, SDVariable x, SDVariable y) {
super(sameDiff, new SDVariable[] {x, y}, false);
}
public BitwiseOr(INDArray x, INDArray y, INDArray output) {
super(new INDArray[]{x, y}, new INDArray[]{output});
}
public BitwiseOr(INDArray x, INDArray y) {
this(x, y, x.ulike());
}
public BitwiseOr() {}
@Override
public String opName() {
return "bitwise_or";
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName() {
return "bitwise_or";
}
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
throw new UnsupportedOperationException("Not yet implemented: " + opName());
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
Preconditions.checkState(dataTypes.get(0).isIntType(), "Input 0 datatype must be an integer type, got %s", dataTypes.get(0));
return Collections.singletonList(dataTypes.get(0));
}
}

View File

@@ -0,0 +1,78 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.transforms.custom;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
import java.util.Collections;
import java.util.List;
/**
* Bit-wise XOR operation, broadcastable
*
* @author raver119@gmail.com
*/
public class BitwiseXor extends BaseDynamicTransformOp {
public BitwiseXor(SameDiff sameDiff, SDVariable x, SDVariable y) {
super(sameDiff, new SDVariable[] {x, y}, false);
}
public BitwiseXor(INDArray x, INDArray y, INDArray output) {
super(new INDArray[]{x, y}, new INDArray[]{output});
}
public BitwiseXor(INDArray x, INDArray y) {
this(x, y, x.ulike());
}
public BitwiseXor() {}
@Override
public String opName() {
return "bitwise_xor";
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName() {
return "bitwise_xor";
}
@Override
public List<SDVariable> doDiff(List<SDVariable> i_v) {
throw new UnsupportedOperationException("Not yet implemented: " + opName());
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
Preconditions.checkState(dataTypes.get(0).isIntType(), "Input 0 datatype must be an integer type, got %s", dataTypes.get(0));
return Collections.singletonList(dataTypes.get(0));
}
}
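A minimal usage sketch for the three new ops (illustrative, not part of the diff; it assumes the standard Nd4j.exec(CustomOp) execution path and the integer inputs these ops require):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.custom.*;
import org.nd4j.linalg.factory.Nd4j;

public class BitwiseOpsExample {
    public static void main(String[] args) {
        // int32 inputs: 12 = 0b1100, 10 = 0b1010 and 10 = 0b1010, 6 = 0b0110
        INDArray x = Nd4j.createFromArray(12, 10);
        INDArray y = Nd4j.createFromArray(10, 6);
        System.out.println(Nd4j.exec(new BitwiseAnd(x, y))[0]); // [8, 2]
        System.out.println(Nd4j.exec(new BitwiseOr(x, y))[0]);  // [14, 14]
        System.out.println(Nd4j.exec(new BitwiseXor(x, y))[0]); // [6, 12]
    }
}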

View File

@@ -54,7 +54,7 @@ public class DeallocatorService {
deallocatorThreads = new Thread[numThreads];
queues = new ReferenceQueue[numThreads];
for (int e = 0; e < numThreads; e++) {
log.debug("Starting deallocator thread {}", e + 1);
log.trace("Starting deallocator thread {}", e + 1);
queues[e] = new ReferenceQueue<>();
int deviceId = e % numDevices;

View File

@@ -1151,4 +1151,6 @@ public interface NativeOps {
int lastErrorCode();
String lastErrorMessage();
boolean isBlasVersionMatches(int major, int minor, int build);
}

View File

@@ -101,7 +101,7 @@ public class NativeOpsHolder {
}
//deviceNativeOps.setOmpNumThreads(4);
log.info("Number of threads used for NativeOps: {}", deviceNativeOps.ompGetMaxThreads());
log.info("Number of threads used for OpenMP: {}", deviceNativeOps.ompGetMaxThreads());
} catch (Exception | Error e) {
throw new RuntimeException(
"ND4J is probably missing dependencies. For more information, please refer to: http://nd4j.org/getstarted.html",

View File

@@ -51,7 +51,8 @@ public abstract class Nd4jBlas implements Blas {
numThreads = NativeOpsHolder.getCores(Runtime.getRuntime().availableProcessors());
setMaxThreads(numThreads);
}
log.info("Number of threads used for BLAS: {}", getMaxThreads());
log.info("Number of threads used for OpenMP BLAS: {}", getMaxThreads());
}
}

View File

@@ -52,6 +52,7 @@ public class JCublasBackend extends Nd4jBackend {
throw new RuntimeException("No CUDA devices were found in system");
}
Loader.load(org.bytedeco.cuda.global.cublas.class);
return true;
}

View File

@@ -108,6 +108,22 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
/*
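// Disabled for now ("skip CUDA version check for now" in the commit log above):
// the CUDA/cuBLAS runtime-vs-build version check is not enforced yet.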
val major = new int[1];
val minor = new int[1];
val build = new int[1];
org.bytedeco.cuda.global.cublas.cublasGetProperty(0, major);
org.bytedeco.cuda.global.cublas.cublasGetProperty(1, minor);
org.bytedeco.cuda.global.cublas.cublasGetProperty(2, build);
val pew = new int[100];
org.bytedeco.cuda.global.cudart.cudaDriverGetVersion(pew);
nativeOps.isBlasVersionMatches(major[0], minor[0], build[0]);
if (nativeOps.lastErrorCode() != 0)
throw new RuntimeException(nativeOps.lastErrorMessage());
*/
}
@Override

View File

@@ -28,10 +28,12 @@ import org.nd4j.jita.allocator.impl.AtomicAllocator;
import org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t;
import org.nd4j.jita.conf.CudaEnvironment;
import org.nd4j.linalg.api.blas.impl.BaseLevel3;
import org.nd4j.linalg.api.blas.params.MMulTranspose;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil;
import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
import org.nd4j.linalg.factory.DataTypeValidation;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.CublasPointer;
@@ -113,16 +115,18 @@ public class JcublasLevel3 extends BaseLevel3 {
@Override
protected void sgemm(char Order, char TransA, char TransB, int M, int N, int K, float alpha, INDArray A, int lda,
INDArray B, int ldb, float beta, INDArray C, int ldc) {
//A = Shape.toOffsetZero(A);
//B = Shape.toOffsetZero(B);
/*
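// Experimental path, kept disabled: route sgemm through the Mmul op instead of
// calling cublasSgemm directly ("mmul op instead of cublasSgemm" in the commit log).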
val ctx = AtomicAllocator.getInstance().getDeviceContext();
val handle = ctx.getCublasHandle();
synchronized (handle) {
Nd4j.exec(new Mmul(A, B, C, MMulTranspose.builder().transposeA(false).transposeB(false).build()));
}
*/
Nd4j.getExecutioner().push();
val ctx = allocator.getFlowController().prepareAction(C, A, B);
//log.info("Synchronizing CUDA stream");
ctx.getOldStream().synchronize();
val cAPointer = new CublasPointer(A, ctx);
val cBPointer = new CublasPointer(B, ctx);
val cCPointer = new CublasPointer(C, ctx);
@@ -141,6 +145,7 @@ }
}
allocator.registerAction(ctx, C, A, B);
OpExecutionerUtil.checkForAny(C);
}

View File

@@ -557,6 +557,13 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper {
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public Environment(Pointer p) { super(p); }
/**
* These 3 fields are mostly for CUDA/cuBLAS version tracking
*/
public native int _blasMajorVersion(); public native Environment _blasMajorVersion(int setter);
public native int _blasMinorVersion(); public native Environment _blasMinorVersion(int setter);
public native int _blasPatchVersion(); public native Environment _blasPatchVersion(int setter);
public static native Environment getInstance();
public native @Cast("bool") boolean isVerbose();
@@ -1874,7 +1881,7 @@ public native void setOmpNumThreads(int threads);
public native void setOmpMinThreads(int threads);
public native @Cast("bool") boolean isBlasVersionMatches(int major, int minor, int build);
/**
*

View File

@@ -557,6 +557,13 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper {
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public Environment(Pointer p) { super(p); }
/**
* These 3 fields are mostly for CUDA/cuBLAS version tracking
*/
public native int _blasMajorVersion(); public native Environment _blasMajorVersion(int setter);
public native int _blasMinorVersion(); public native Environment _blasMinorVersion(int setter);
public native int _blasPatchVersion(); public native Environment _blasPatchVersion(int setter);
public static native Environment getInstance();
public native @Cast("bool") boolean isVerbose();
@@ -1874,7 +1881,7 @@ public native void setOmpNumThreads(int threads);
public native void setOmpMinThreads(int threads);
public native @Cast("bool") boolean isBlasVersionMatches(int major, int minor, int build);
/**
*
@@ -21929,6 +21936,78 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
}
// #endif
/**
* This operation applies bitwise AND
*
* PLEASE NOTE: This operation is applicable only to integer data types
*
* \tparam T
*/
// #if NOT_EXCLUDED(OP_bitwise_and)
@Namespace("nd4j::ops") public static class bitwise_and extends BroadcastableOp {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public bitwise_and(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public bitwise_and(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
@Override public bitwise_and position(long position) {
return (bitwise_and)super.position(position);
}
public bitwise_and() { super((Pointer)null); allocate(); }
private native void allocate();
}
// #endif
/**
* This operation applies bitwise OR
*
* PLEASE NOTE: This operation is applicable only to integer data types
*
* \tparam T
*/
// #if NOT_EXCLUDED(OP_bitwise_or)
@Namespace("nd4j::ops") public static class bitwise_or extends BroadcastableOp {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public bitwise_or(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public bitwise_or(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
@Override public bitwise_or position(long position) {
return (bitwise_or)super.position(position);
}
public bitwise_or() { super((Pointer)null); allocate(); }
private native void allocate();
}
// #endif
/**
* This operation applies bitwise XOR
*
* PLEASE NOTE: This operation is applicable only to integer data types
*
* \tparam T
*/
// #if NOT_EXCLUDED(OP_bitwise_xor)
@Namespace("nd4j::ops") public static class bitwise_xor extends BroadcastableOp {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public bitwise_xor(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public bitwise_xor(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
@Override public bitwise_xor position(long position) {
return (bitwise_xor)super.position(position);
}
public bitwise_xor() { super((Pointer)null); allocate(); }
private native void allocate();
}
// #endif
/**
* This operation returns hamming distance based on bits
*