[WIP] Last set of changes (#234)
* mmul op instead of cublasSgemm Signed-off-by: raver119 <raver119@gmail.com> * transB Signed-off-by: raver119 <raver119@gmail.com> * jcpp handles Signed-off-by: raver119 <raver119@gmail.com> * bitwise and/or/xor Signed-off-by: raver119 <raver119@gmail.com> * bitwise and/or/xor mapping Signed-off-by: raver119 <raver119@gmail.com> * cuda/cublas version check Signed-off-by: raver119 <raver119@gmail.com> * add expected version Signed-off-by: raver119 <raver119@gmail.com> * cuda/cublas version check in java Signed-off-by: raver119 <raver119@gmail.com> * one more error check Signed-off-by: raver119 <raver119@gmail.com> * build fix Signed-off-by: raver119 <raver119@gmail.com> * build fix Signed-off-by: raver119 <raver119@gmail.com> * build fix Signed-off-by: raver119 <raver119@gmail.com> * one more fix Signed-off-by: raver119 <raver119@gmail.com> * skip CUDA version check for now Signed-off-by: raver119 <raver119@gmail.com> * better wording Signed-off-by: raver119 <raver119@gmail.com> * few more tweaks Signed-off-by: raver119 <raver119@gmail.com> * few more tweaks Signed-off-by: raver119 <raver119@gmail.com>master
parent
d41018751b
commit
a90c7dd995
|
@ -0,0 +1,40 @@
|
|||
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

#ifndef SAMEDIFF_BLASVERSIONHELPER_H
#define SAMEDIFF_BLASVERSIONHELPER_H

#include <dll.h>
#include <cuda.h>
#include <cuda_runtime.h>

namespace nd4j {
    /**
     * Captures, at construction time, the CUDA toolkit version this native
     * binary was compiled with (see BlasVersionHelper.cu, which reads the
     * NVCC __CUDACC_VER_* macros). The values are copied into Environment
     * so the Java side can compare them against the runtime CUDA/cuBLAS
     * version via isBlasVersionMatches().
     */
    class ND4J_EXPORT BlasVersionHelper {
    public:
        // compile-time CUDA version components; 0 until the constructor runs
        int _blasMajorVersion = 0;
        int _blasMinorVersion = 0;
        int _blasPatchVersion = 0;

        // populates the three version fields (defined in BlasVersionHelper.cu)
        BlasVersionHelper();
        ~BlasVersionHelper() = default;
    };
}

#endif // SAMEDIFF_BLASVERSIONHELPER_H
|
|
@ -253,20 +253,20 @@ if(CUDA_BLAS)
|
|||
file(GLOB_RECURSE LOOPS_SOURCES_CUDA false ../include/loops/*.cu)
|
||||
|
||||
if (NOT BUILD_TESTS)
|
||||
CUDA_ADD_LIBRARY(${LIBND4J_NAME} SHARED cuda/NativeOps.cu cuda/NativeOpExecutioner.cu ${LOOPS_SOURCES_CUDA}
|
||||
CUDA_ADD_LIBRARY(${LIBND4J_NAME} SHARED cuda/NativeOps.cu cuda/NativeOpExecutioner.cu cuda/BlasVersionHelper.cu Environment.cpp ${LOOPS_SOURCES_CUDA}
|
||||
${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES}
|
||||
../include/cnpy/cnpy.cpp ../include/nd4jmemset.h ../include/nd4jmalloc.h
|
||||
cpu/GraphExecutioner.cpp cuda/NDArray.cu cpu/NDArrayFactory.cpp
|
||||
Environment.cpp Environment.h ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
|
||||
Environment.h ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
|
||||
${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES})
|
||||
else()
|
||||
set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DBUILD_TESTS=true")
|
||||
|
||||
CUDA_ADD_LIBRARY(${LIBND4J_NAME} STATIC cuda/NativeOps.cu cuda/NativeOpExecutioner.cu ${LOOPS_SOURCES_CUDA}
|
||||
CUDA_ADD_LIBRARY(${LIBND4J_NAME} STATIC cuda/NativeOps.cu cuda/NativeOpExecutioner.cu cuda/BlasVersionHelper.cu Environment.cpp ${LOOPS_SOURCES_CUDA}
|
||||
${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES}
|
||||
../include/cnpy/cnpy.cpp ../include/nd4jmemset.h ../include/nd4jmalloc.h
|
||||
cpu/GraphExecutioner.cpp cuda/NDArray.cu cpu/NDArrayFactory.cpp
|
||||
Environment.cpp Environment.h ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
|
||||
Environment.h ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES}
|
||||
${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES})
|
||||
endif()
|
||||
|
||||
|
|
|
@ -35,7 +35,7 @@
|
|||
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
#include "BlasVersionHelper.h"
|
||||
#endif
|
||||
|
||||
namespace nd4j {
|
||||
|
@ -66,6 +66,13 @@ namespace nd4j {
|
|||
#endif
|
||||
|
||||
#ifdef __CUDABLAS__
|
||||
BlasVersionHelper ver;
|
||||
_blasMajorVersion = ver._blasMajorVersion;
|
||||
_blasMinorVersion = ver._blasMinorVersion;
|
||||
_blasPatchVersion = ver._blasPatchVersion;
|
||||
printf("ND4J CUDA build version: %i.%i.%i\n", _blasMajorVersion, _blasMinorVersion, _blasPatchVersion);
|
||||
fflush(stdout);
|
||||
|
||||
int devCnt = 0;
|
||||
cudaGetDeviceCount(&devCnt);
|
||||
auto devProperties = new cudaDeviceProp[devCnt];
|
||||
|
|
|
@ -56,6 +56,13 @@ namespace nd4j{
|
|||
Environment();
|
||||
~Environment();
|
||||
public:
|
||||
/**
|
||||
* These 3 fields are mostly for CUDA/cuBLAS version tracking
|
||||
*/
|
||||
int _blasMajorVersion = 0;
|
||||
int _blasMinorVersion = 0;
|
||||
int _blasPatchVersion = 0;
|
||||
|
||||
static Environment* getInstance();
|
||||
|
||||
bool isVerbose();
|
||||
|
|
|
@ -647,7 +647,7 @@ ND4J_EXPORT void setOmpNumThreads(int threads);
|
|||
ND4J_EXPORT void setOmpMinThreads(int threads);
|
||||
|
||||
|
||||
|
||||
ND4J_EXPORT bool isBlasVersionMatches(int major, int minor, int build);
|
||||
|
||||
/**
|
||||
*
|
||||
|
|
|
@ -728,6 +728,10 @@ void execReduce3Tad(Nd4jPointer *extraPointers,
|
|||
}
|
||||
}
|
||||
|
||||
// CPU backend implementation: there is no CUDA/cuBLAS dependency here, so any
// requested version is considered a match by definition. The CUDA backend
// performs the real comparison against Environment's recorded build version.
bool isBlasVersionMatches(int major, int minor, int build) {
    return true;
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param opNum
|
||||
|
|
|
@ -0,0 +1,29 @@
|
|||
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

#include "../BlasVersionHelper.h"

namespace nd4j {
    // Records the NVCC compiler version used to build this translation unit.
    // __CUDACC_VER_MAJOR__/_MINOR__/_BUILD__ are predefined by nvcc, so this
    // must live in a .cu file compiled by nvcc.
    // NOTE(review): despite the "_blas" field names, these values describe the
    // CUDA toolkit (nvcc) version, not a cuBLAS library version — the Java
    // side is expected to compare against the matching quantity.
    BlasVersionHelper::BlasVersionHelper() {
        _blasMajorVersion = __CUDACC_VER_MAJOR__;
        _blasMinorVersion = __CUDACC_VER_MINOR__;
        _blasPatchVersion = __CUDACC_VER_BUILD__;
    }
}
|
|
@ -3357,6 +3357,18 @@ void deleteTadPack(nd4j::TadPack* ptr) {
|
|||
delete ptr;
|
||||
}
|
||||
|
||||
// CUDA backend implementation: compares the CUDA/cuBLAS version reported by
// the caller (the Java side, at runtime) with the version this native binary
// was compiled against, as recorded in Environment at startup.
// On mismatch it both prints a diagnostic and publishes error code 152 via the
// default LaunchContext's error reference, so the JVM side can surface it
// through lastErrorCode()/lastErrorMessage().
bool isBlasVersionMatches(int major, int minor, int build) {
    auto result = major == Environment::getInstance()->_blasMajorVersion && minor == Environment::getInstance()->_blasMinorVersion && build == Environment::getInstance()->_blasPatchVersion;

    if (!result) {
        nd4j_printf("CUDA/cuBLAS version mismatch. Expected: %i.%i.%i but got %i.%i.%i instead\n", Environment::getInstance()->_blasMajorVersion, Environment::getInstance()->_blasMinorVersion, Environment::getInstance()->_blasPatchVersion, major, minor, build);
        // 152 is the error code reserved for version mismatch on this path
        nd4j::LaunchContext::defaultContext()->errorReference()->setErrorCode(152);
        nd4j::LaunchContext::defaultContext()->errorReference()->setErrorMessage("CUDA/cuBLAS version mismatch");
    }

    return result;
}
|
||||
|
||||
// Returns a cached device-side constant buffer for the given Nd4jLong array;
// caching/deduplication is delegated to ConstantHelper via ConstantDescriptor.
nd4j::ConstantDataBuffer* constantBufferLong(nd4j::DataType dtype, Nd4jLong *data, int length) {
    return nd4j::ConstantHelper::getInstance()->constantBuffer(ConstantDescriptor(data, length), dtype);
}
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
/*******************************************************************************
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_bitwise_and)

#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/helpers.h>
// NOTE(review): shift.h looks like a copy/paste leftover from the shift-bits
// ops — nothing here references shift helpers; confirm before removing.
#include <ops/declarable/helpers/shift.h>

namespace nd4j {
    namespace ops {
        // Element-wise bitwise AND of two integer arrays with broadcast
        // support. Inputs: x, y (integer types only, enforced by
        // DECLARE_TYPES below). Output: z, same type as inputs.
        BROADCASTABLE_OP_IMPL(bitwise_and, 0, 0) {
            auto x = INPUT_VARIABLE(0);
            auto y = INPUT_VARIABLE(1);
            auto z = OUTPUT_VARIABLE(0);

            // returns early if any operand is empty
            BROADCAST_CHECK_EMPTY(x,y,z);

            // IntAnd selected for all three execution shapes: scalar,
            // pairwise, and broadcast
            x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntAnd, pairwise::IntOps::IntAnd, broadcast::IntOps::IntAnd), y, z, false);

            return Status::OK();
        }

        DECLARE_TYPES(bitwise_and) {
            getOpDescriptor()
                    ->setAllowedInputTypes({ALL_INTS})
                    ->setSameMode(true);  // output dtype mirrors input dtype
        }
    }
}

#endif
|
|
@ -0,0 +1,50 @@
|
|||
/*******************************************************************************
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_bitwise_or)

#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/helpers.h>
// NOTE(review): shift.h looks like a copy/paste leftover from the shift-bits
// ops — nothing here references shift helpers; confirm before removing.
#include <ops/declarable/helpers/shift.h>

namespace nd4j {
    namespace ops {
        // Element-wise bitwise OR of two integer arrays with broadcast
        // support. Inputs: x, y (integer types only, enforced by
        // DECLARE_TYPES below). Output: z, same type as inputs.
        BROADCASTABLE_OP_IMPL(bitwise_or, 0, 0) {
            auto x = INPUT_VARIABLE(0);
            auto y = INPUT_VARIABLE(1);
            auto z = OUTPUT_VARIABLE(0);

            // returns early if any operand is empty
            BROADCAST_CHECK_EMPTY(x,y,z);

            // IntOr selected for all three execution shapes: scalar,
            // pairwise, and broadcast
            x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntOr, pairwise::IntOps::IntOr, broadcast::IntOps::IntOr), y, z, false);

            return Status::OK();
        }

        DECLARE_TYPES(bitwise_or) {
            getOpDescriptor()
                    ->setAllowedInputTypes({ALL_INTS})
                    ->setSameMode(true);  // output dtype mirrors input dtype
        }
    }
}

#endif
|
|
@ -0,0 +1,50 @@
|
|||
/*******************************************************************************
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_bitwise_xor)

#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/helpers.h>
// NOTE(review): shift.h looks like a copy/paste leftover from the shift-bits
// ops — nothing here references shift helpers; confirm before removing.
#include <ops/declarable/helpers/shift.h>

namespace nd4j {
    namespace ops {
        // Element-wise bitwise XOR of two integer arrays with broadcast
        // support. Inputs: x, y (integer types only, enforced by
        // DECLARE_TYPES below). Output: z, same type as inputs.
        BROADCASTABLE_OP_IMPL(bitwise_xor, 0, 0) {
            auto x = INPUT_VARIABLE(0);
            auto y = INPUT_VARIABLE(1);
            auto z = OUTPUT_VARIABLE(0);

            // returns early if any operand is empty
            BROADCAST_CHECK_EMPTY(x,y,z);

            // IntXor selected for all three execution shapes: scalar,
            // pairwise, and broadcast
            x->applyTrueBroadcast(BroadcastIntOpsTuple::custom(scalar::IntOps::IntXor, pairwise::IntOps::IntXor, broadcast::IntOps::IntXor), y, z, false);

            return Status::OK();
        }

        DECLARE_TYPES(bitwise_xor) {
            getOpDescriptor()
                    ->setAllowedInputTypes({ALL_INTS})
                    ->setSameMode(true);  // output dtype mirrors input dtype
        }
    }
}

#endif
|
|
@ -81,6 +81,39 @@ namespace nd4j {
|
|||
DECLARE_BROADCASTABLE_OP(cyclic_rshift_bits, 0, 0);
|
||||
#endif
|
||||
|
||||
/**
|
||||
* This operation applies bitwise AND
|
||||
*
|
||||
* PLEASE NOTE: This operation is applicable only to integer data types
|
||||
*
|
||||
* @tparam T
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_bitwise_and)
|
||||
DECLARE_BROADCASTABLE_OP(bitwise_and, 0, 0);
|
||||
#endif
|
||||
|
||||
/**
|
||||
* This operation applies bitwise OR
|
||||
*
|
||||
* PLEASE NOTE: This operation is applicable only to integer data types
|
||||
*
|
||||
* @tparam T
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_bitwise_or)
|
||||
DECLARE_BROADCASTABLE_OP(bitwise_or, 0, 0);
|
||||
#endif
|
||||
|
||||
/**
|
||||
* This operation applies bitwise XOR
|
||||
*
|
||||
* PLEASE NOTE: This operation is applicable only to integer data types
|
||||
*
|
||||
* @tparam T
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_bitwise_xor)
|
||||
DECLARE_BROADCASTABLE_OP(bitwise_xor, 0, 0);
|
||||
#endif
|
||||
|
||||
/**
|
||||
* This operation returns hamming distance based on bits
|
||||
*
|
||||
|
|
|
@ -353,6 +353,9 @@ public class ImportClassMapping {
|
|||
org.nd4j.linalg.api.ops.impl.transforms.custom.Choose.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseAnd.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseXor.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.custom.BitwiseOr.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits.class,
|
||||
org.nd4j.linalg.api.ops.impl.transforms.custom.Dilation2D.class,
|
||||
|
|
|
@ -0,0 +1,78 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Bit-wise AND operation, broadcastable
|
||||
*
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
public class BitwiseAnd extends BaseDynamicTransformOp {
|
||||
|
||||
public BitwiseAnd(SameDiff sameDiff, SDVariable x, SDVariable y) {
|
||||
super(sameDiff, new SDVariable[] {x, y} ,false);
|
||||
}
|
||||
|
||||
public BitwiseAnd(INDArray x, INDArray y, INDArray output) {
|
||||
super(new INDArray[]{x, y}, new INDArray[]{output});
|
||||
}
|
||||
|
||||
public BitwiseAnd(INDArray x, INDArray y) {
|
||||
this(x, y,x.ulike());
|
||||
}
|
||||
|
||||
public BitwiseAnd() {}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "bitwise_and";
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public String onnxName() {
|
||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
return "bitwise_and";
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
||||
throw new UnsupportedOperationException("Not yet implemented: " + opName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
|
||||
Preconditions.checkState(dataTypes.get(0).isIntType(), "Input 0 datatype must be a integer type, got %s", dataTypes.get(0));
|
||||
return Collections.singletonList(dataTypes.get(0));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,78 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Bit-wise OR operation, broadcastable
|
||||
*
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
public class BitwiseOr extends BaseDynamicTransformOp {
|
||||
|
||||
public BitwiseOr(SameDiff sameDiff, SDVariable x, SDVariable y) {
|
||||
super(sameDiff, new SDVariable[] {x, y} ,false);
|
||||
}
|
||||
|
||||
public BitwiseOr(INDArray x, INDArray y, INDArray output) {
|
||||
super(new INDArray[]{x, y}, new INDArray[]{output});
|
||||
}
|
||||
|
||||
public BitwiseOr(INDArray x, INDArray y) {
|
||||
this(x, y,x.ulike());
|
||||
}
|
||||
|
||||
public BitwiseOr() {}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "bitwise_or";
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public String onnxName() {
|
||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
return "bitwise_or";
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
||||
throw new UnsupportedOperationException("Not yet implemented: " + opName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
|
||||
Preconditions.checkState(dataTypes.get(0).isIntType(), "Input 0 datatype must be a integer type, got %s", dataTypes.get(0));
|
||||
return Collections.singletonList(dataTypes.get(0));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,78 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.linalg.api.ops.impl.transforms.custom;
|
||||
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.imports.NoOpNameFoundException;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Bit-wise XOR operation, broadcastable
|
||||
*
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
public class BitwiseXor extends BaseDynamicTransformOp {
|
||||
|
||||
public BitwiseXor(SameDiff sameDiff, SDVariable x, SDVariable y) {
|
||||
super(sameDiff, new SDVariable[] {x, y} ,false);
|
||||
}
|
||||
|
||||
public BitwiseXor(INDArray x, INDArray y, INDArray output) {
|
||||
super(new INDArray[]{x, y}, new INDArray[]{output});
|
||||
}
|
||||
|
||||
public BitwiseXor(INDArray x, INDArray y) {
|
||||
this(x, y,x.ulike());
|
||||
}
|
||||
|
||||
public BitwiseXor() {}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "bitwise_xor";
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public String onnxName() {
|
||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
return "bitwise_xor";
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public List<SDVariable> doDiff(List<SDVariable> i_v) {
|
||||
throw new UnsupportedOperationException("Not yet implemented: " + opName());
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes){
|
||||
Preconditions.checkState(dataTypes.get(0).isIntType(), "Input 0 datatype must be a integer type, got %s", dataTypes.get(0));
|
||||
return Collections.singletonList(dataTypes.get(0));
|
||||
}
|
||||
}
|
|
@ -54,7 +54,7 @@ public class DeallocatorService {
|
|||
deallocatorThreads = new Thread[numThreads];
|
||||
queues = new ReferenceQueue[numThreads];
|
||||
for (int e = 0; e < numThreads; e++) {
|
||||
log.debug("Starting deallocator thread {}", e + 1);
|
||||
log.trace("Starting deallocator thread {}", e + 1);
|
||||
queues[e] = new ReferenceQueue<>();
|
||||
|
||||
int deviceId = e % numDevices;
|
||||
|
|
|
@ -1151,4 +1151,6 @@ public interface NativeOps {
|
|||
|
||||
int lastErrorCode();
|
||||
String lastErrorMessage();
|
||||
|
||||
boolean isBlasVersionMatches(int major, int minor, int build);
|
||||
}
|
||||
|
|
|
@ -101,7 +101,7 @@ public class NativeOpsHolder {
|
|||
}
|
||||
//deviceNativeOps.setOmpNumThreads(4);
|
||||
|
||||
log.info("Number of threads used for NativeOps: {}", deviceNativeOps.ompGetMaxThreads());
|
||||
log.info("Number of threads used for OpenMP: {}", deviceNativeOps.ompGetMaxThreads());
|
||||
} catch (Exception | Error e) {
|
||||
throw new RuntimeException(
|
||||
"ND4J is probably missing dependencies. For more information, please refer to: http://nd4j.org/getstarted.html",
|
||||
|
|
|
@ -51,7 +51,8 @@ public abstract class Nd4jBlas implements Blas {
|
|||
numThreads = NativeOpsHolder.getCores(Runtime.getRuntime().availableProcessors());
|
||||
setMaxThreads(numThreads);
|
||||
}
|
||||
log.info("Number of threads used for BLAS: {}", getMaxThreads());
|
||||
|
||||
log.info("Number of threads used for OpenMP BLAS: {}", getMaxThreads());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -52,6 +52,7 @@ public class JCublasBackend extends Nd4jBackend {
|
|||
throw new RuntimeException("No CUDA devices were found in system");
|
||||
}
|
||||
Loader.load(org.bytedeco.cuda.global.cublas.class);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -108,6 +108,22 @@ public class JCublasNDArrayFactory extends BaseNativeNDArrayFactory {
|
|||
|
||||
if (nativeOps.lastErrorCode() != 0)
|
||||
throw new RuntimeException(nativeOps.lastErrorMessage());
|
||||
/*
|
||||
val major = new int[1];
|
||||
val minor = new int[1];
|
||||
val build = new int[1];
|
||||
org.bytedeco.cuda.global.cublas.cublasGetProperty(0, major);
|
||||
org.bytedeco.cuda.global.cublas.cublasGetProperty(1, minor);
|
||||
org.bytedeco.cuda.global.cublas.cublasGetProperty(2, build);
|
||||
|
||||
val pew = new int[100];
|
||||
org.bytedeco.cuda.global.cudart.cudaDriverGetVersion(pew);
|
||||
|
||||
nativeOps.isBlasVersionMatches(major[0], minor[0], build[0]);
|
||||
|
||||
if (nativeOps.lastErrorCode() != 0)
|
||||
throw new RuntimeException(nativeOps.lastErrorMessage());
|
||||
*/
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -28,10 +28,12 @@ import org.nd4j.jita.allocator.impl.AtomicAllocator;
|
|||
import org.nd4j.jita.allocator.pointers.cuda.cublasHandle_t;
|
||||
import org.nd4j.jita.conf.CudaEnvironment;
|
||||
import org.nd4j.linalg.api.blas.impl.BaseLevel3;
|
||||
import org.nd4j.linalg.api.blas.params.MMulTranspose;
|
||||
import org.nd4j.linalg.api.buffer.DataBuffer;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.executioner.OpExecutionerUtil;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
|
||||
import org.nd4j.linalg.factory.DataTypeValidation;
|
||||
import org.nd4j.linalg.factory.Nd4j;
|
||||
import org.nd4j.linalg.jcublas.CublasPointer;
|
||||
|
@ -113,16 +115,18 @@ public class JcublasLevel3 extends BaseLevel3 {
|
|||
@Override
|
||||
protected void sgemm(char Order, char TransA, char TransB, int M, int N, int K, float alpha, INDArray A, int lda,
|
||||
INDArray B, int ldb, float beta, INDArray C, int ldc) {
|
||||
//A = Shape.toOffsetZero(A);
|
||||
//B = Shape.toOffsetZero(B);
|
||||
/*
|
||||
val ctx = AtomicAllocator.getInstance().getDeviceContext();
|
||||
val handle = ctx.getCublasHandle();
|
||||
synchronized (handle) {
|
||||
Nd4j.exec(new Mmul(A, B, C, MMulTranspose.builder().transposeA(false).transposeB(false).build()));
|
||||
}
|
||||
*/
|
||||
|
||||
Nd4j.getExecutioner().push();
|
||||
|
||||
val ctx = allocator.getFlowController().prepareAction(C, A, B);
|
||||
|
||||
//log.info("Synchronizing CUDA stream");
|
||||
ctx.getOldStream().synchronize();
|
||||
|
||||
val cAPointer = new CublasPointer(A, ctx);
|
||||
val cBPointer = new CublasPointer(B, ctx);
|
||||
val cCPointer = new CublasPointer(C, ctx);
|
||||
|
@ -141,6 +145,7 @@ public class JcublasLevel3 extends BaseLevel3 {
|
|||
}
|
||||
|
||||
allocator.registerAction(ctx, C, A, B);
|
||||
|
||||
OpExecutionerUtil.checkForAny(C);
|
||||
}
|
||||
|
||||
|
|
|
@ -557,6 +557,13 @@ public class Nd4jCuda extends org.nd4j.nativeblas.Nd4jCudaHelper {
|
|||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||
public Environment(Pointer p) { super(p); }
|
||||
|
||||
/**
|
||||
* These 3 fields are mostly for CUDA/cuBLAS version tracking
|
||||
*/
|
||||
public native int _blasMajorVersion(); public native Environment _blasMajorVersion(int setter);
|
||||
public native int _blasMinorVersion(); public native Environment _blasMinorVersion(int setter);
|
||||
public native int _blasPatchVersion(); public native Environment _blasPatchVersion(int setter);
|
||||
|
||||
public static native Environment getInstance();
|
||||
|
||||
public native @Cast("bool") boolean isVerbose();
|
||||
|
@ -1874,7 +1881,7 @@ public native void setOmpNumThreads(int threads);
|
|||
public native void setOmpMinThreads(int threads);
|
||||
|
||||
|
||||
|
||||
public native @Cast("bool") boolean isBlasVersionMatches(int major, int minor, int build);
|
||||
|
||||
/**
|
||||
*
|
||||
|
|
|
@ -557,6 +557,13 @@ public class Nd4jCpu extends org.nd4j.nativeblas.Nd4jCpuHelper {
|
|||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||
public Environment(Pointer p) { super(p); }
|
||||
|
||||
/**
|
||||
* These 3 fields are mostly for CUDA/cuBLAS version tracking
|
||||
*/
|
||||
public native int _blasMajorVersion(); public native Environment _blasMajorVersion(int setter);
|
||||
public native int _blasMinorVersion(); public native Environment _blasMinorVersion(int setter);
|
||||
public native int _blasPatchVersion(); public native Environment _blasPatchVersion(int setter);
|
||||
|
||||
public static native Environment getInstance();
|
||||
|
||||
public native @Cast("bool") boolean isVerbose();
|
||||
|
@ -1874,7 +1881,7 @@ public native void setOmpNumThreads(int threads);
|
|||
public native void setOmpMinThreads(int threads);
|
||||
|
||||
|
||||
|
||||
public native @Cast("bool") boolean isBlasVersionMatches(int major, int minor, int build);
|
||||
|
||||
/**
|
||||
*
|
||||
|
@ -21929,6 +21936,78 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
|||
}
|
||||
// #endif
|
||||
|
||||
/**
|
||||
* This operation applies bitwise AND
|
||||
*
|
||||
* PLEASE NOTE: This operation is applicable only to integer data types
|
||||
*
|
||||
* \tparam T
|
||||
*/
|
||||
// #if NOT_EXCLUDED(OP_bitwise_and)
|
||||
@Namespace("nd4j::ops") public static class bitwise_and extends BroadcastableOp {
|
||||
static { Loader.load(); }
|
||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||
public bitwise_and(Pointer p) { super(p); }
|
||||
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
||||
public bitwise_and(long size) { super((Pointer)null); allocateArray(size); }
|
||||
private native void allocateArray(long size);
|
||||
@Override public bitwise_and position(long position) {
|
||||
return (bitwise_and)super.position(position);
|
||||
}
|
||||
|
||||
public bitwise_and() { super((Pointer)null); allocate(); }
|
||||
private native void allocate();
|
||||
}
|
||||
// #endif
|
||||
|
||||
/**
|
||||
* This operation applies bitwise OR
|
||||
*
|
||||
* PLEASE NOTE: This operation is applicable only to integer data types
|
||||
*
|
||||
* \tparam T
|
||||
*/
|
||||
// #if NOT_EXCLUDED(OP_bitwise_or)
|
||||
@Namespace("nd4j::ops") public static class bitwise_or extends BroadcastableOp {
|
||||
static { Loader.load(); }
|
||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||
public bitwise_or(Pointer p) { super(p); }
|
||||
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
||||
public bitwise_or(long size) { super((Pointer)null); allocateArray(size); }
|
||||
private native void allocateArray(long size);
|
||||
@Override public bitwise_or position(long position) {
|
||||
return (bitwise_or)super.position(position);
|
||||
}
|
||||
|
||||
public bitwise_or() { super((Pointer)null); allocate(); }
|
||||
private native void allocate();
|
||||
}
|
||||
// #endif
|
||||
|
||||
/**
|
||||
* This operation applies bitwise XOR
|
||||
*
|
||||
* PLEASE NOTE: This operation is applicable only to integer data types
|
||||
*
|
||||
* \tparam T
|
||||
*/
|
||||
// #if NOT_EXCLUDED(OP_bitwise_xor)
|
||||
@Namespace("nd4j::ops") public static class bitwise_xor extends BroadcastableOp {
|
||||
static { Loader.load(); }
|
||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||
public bitwise_xor(Pointer p) { super(p); }
|
||||
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
||||
public bitwise_xor(long size) { super((Pointer)null); allocateArray(size); }
|
||||
private native void allocateArray(long size);
|
||||
@Override public bitwise_xor position(long position) {
|
||||
return (bitwise_xor)super.position(position);
|
||||
}
|
||||
|
||||
public bitwise_xor() { super((Pointer)null); allocate(); }
|
||||
private native void allocate();
|
||||
}
|
||||
// #endif
|
||||
|
||||
/**
|
||||
* This operation returns hamming distance based on bits
|
||||
*
|
||||
|
|
Loading…
Reference in New Issue