cavis/libnd4j/include/array/ArrayOptions.h

373 lines
13 KiB
C++

/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#ifndef ND4J_ARRAY_OPTIONS_H
#define ND4J_ARRAY_OPTIONS_H
#include <op_boilerplate.h>
#include <pointercast.h>
#include <dll.h>
#include <array/DataType.h>
#include <array/ArrayType.h>
#include <array/SpaceType.h>
#include <array/SparseType.h>
#include <initializer_list>
// Property flags stored in the "extra" word of a shapeInfo buffer.
// Each flag is a distinct power of two; listed here in ascending bit order.
// Bits 0 (value 1) and 8 (value 256) are not used by any flag below.

// --- array layout / storage kind ---
#define ARRAY_SPARSE 2            // bit 1
#define ARRAY_COMPRESSED 4        // bit 2
#define ARRAY_EMPTY 8             // bit 3
#define ARRAY_RAGGED 16           // bit 4

// --- sparse formats ---
#define ARRAY_CSR 32              // bit 5
#define ARRAY_CSC 64              // bit 6
#define ARRAY_COO 128             // bit 7

// --- value space ---
#define ARRAY_COMPLEX 512         // bit 9: complex values
#define ARRAY_QUANTIZED 1024      // bit 10: quantized values

// --- element data types ---
#define ARRAY_BHALF 2048          // bit 11: 16 bit bfloat16
#define ARRAY_HALF 4096           // bit 12: 16 bit float FP16
#define ARRAY_FLOAT 8192          // bit 13: regular 32 bit float
#define ARRAY_DOUBLE 16384        // bit 14: regular 64 bit float
#define ARRAY_CHAR 32768          // bit 15: 8 bit integer
#define ARRAY_SHORT 65536         // bit 16: 16 bit integer
#define ARRAY_INT 131072          // bit 17: 32 bit integer
#define ARRAY_LONG 262144         // bit 18: 64 bit integer
#define ARRAY_BOOL 524288         // bit 19: boolean values
#define ARRAY_UTF8 1048576        // bit 20: UTF-8 strings

// bit 21: flag for extras
#define ARRAY_EXTRAS 2097152

#define ARRAY_UTF16 4194304       // bit 22: UTF-16 strings

// bit 23: signed/unsigned marker for the integer width flags above
#define ARRAY_UNSIGNED 8388608

#define ARRAY_UTF32 16777216      // bit 24: UTF-32 strings
namespace nd4j {
/**
 * Static helpers that read and write the property-flag word ("extra") of an
 * nd4j shapeInfo buffer: data type, array kind, value space and sparse format.
 */
class ND4J_EXPORT ArrayOptions {
private:
    // Mutable reference to the flags word inside a shapeInfo buffer.
    static FORCEINLINE _CUDA_HD Nd4jLong& extra(Nd4jLong* shape);

public:
    // ---- raw flag access ----
    static FORCEINLINE _CUDA_HD bool isNewFormat(const Nd4jLong *shapeInfo);
    static FORCEINLINE _CUDA_HD bool hasPropertyBitSet(const Nd4jLong *shapeInfo, int property);
    static FORCEINLINE _CUDA_HD bool togglePropertyBit(Nd4jLong *shapeInfo, int property);
    static FORCEINLINE _CUDA_HD void unsetPropertyBit(Nd4jLong *shapeInfo, int property);
    static FORCEINLINE _CUDA_HD void setPropertyBit(Nd4jLong *shapeInfo, int property);
    static FORCEINLINE _CUDA_HD void setPropertyBits(Nd4jLong *shapeInfo, std::initializer_list<int> properties);

    // ---- simple flag queries ----
    static FORCEINLINE _CUDA_HD bool isSparseArray(Nd4jLong *shapeInfo);
    static FORCEINLINE _CUDA_HD bool isUnsigned(Nd4jLong *shapeInfo);
    static FORCEINLINE _CUDA_HD bool hasExtraProperties(Nd4jLong *shapeInfo);

    // ---- decoded classifications (const and non-const overloads) ----
    static FORCEINLINE _CUDA_HD SpaceType spaceType(Nd4jLong *shapeInfo);
    static FORCEINLINE _CUDA_HD SpaceType spaceType(const Nd4jLong *shapeInfo);
    static FORCEINLINE _CUDA_HD ArrayType arrayType(Nd4jLong *shapeInfo);
    static FORCEINLINE _CUDA_HD ArrayType arrayType(const Nd4jLong *shapeInfo);
    static FORCEINLINE _CUDA_HD SparseType sparseType(Nd4jLong *shapeInfo);
    static FORCEINLINE _CUDA_HD SparseType sparseType(const Nd4jLong *shapeInfo);

    // ---- element data type ----
    static FORCEINLINE _CUDA_HD nd4j::DataType dataType(const Nd4jLong *shapeInfo);
    static FORCEINLINE _CUDA_HD void resetDataType(Nd4jLong *shapeInfo);
    static FORCEINLINE _CUDA_HD void setDataType(Nd4jLong *shapeInfo, const nd4j::DataType dataType);
    static FORCEINLINE _CUDA_HD void copyDataType(Nd4jLong* to, const Nd4jLong* from);
};
// Returns a mutable reference to the flags word of a shapeInfo buffer, located at
// index 2 * shape[0] + 1 (shape[0] is presumably the rank — TODO confirm).
FORCEINLINE _CUDA_HD Nd4jLong& ArrayOptions::extra(Nd4jLong* shape) {
    const Nd4jLong leading = shape[0];
    return shape[2 * leading + 1];
}
// A shapeInfo counts as "new format" when its flags word is non-zero.
FORCEINLINE _CUDA_HD bool ArrayOptions::isNewFormat(const Nd4jLong *shapeInfo) {
    auto writableShape = const_cast<Nd4jLong*>(shapeInfo);
    const Nd4jLong flags = extra(writableShape);
    return flags != 0;
}
// True when the ARRAY_SPARSE flag is present in the flags word.
FORCEINLINE _CUDA_HD bool ArrayOptions::isSparseArray(Nd4jLong *shapeInfo) {
    const bool sparse = hasPropertyBitSet(shapeInfo, ARRAY_SPARSE);
    return sparse;
}
// True when the ARRAY_EXTRAS flag is present in the flags word.
FORCEINLINE _CUDA_HD bool ArrayOptions::hasExtraProperties(Nd4jLong *shapeInfo) {
    const bool extras = hasPropertyBitSet(shapeInfo, ARRAY_EXTRAS);
    return extras;
}
// Returns true only if every bit of `property` is set in the flags word.
// A shapeInfo whose flags word is zero (not "new format") never reports any property.
FORCEINLINE _CUDA_HD bool ArrayOptions::hasPropertyBitSet(const Nd4jLong *shapeInfo, int property) {
    if (isNewFormat(shapeInfo)) {
        const Nd4jLong flags = extra(const_cast<Nd4jLong*>(shapeInfo));
        return (flags & property) == property;
    }
    return false;
}
// True when the flags word is populated and carries the ARRAY_UNSIGNED marker.
FORCEINLINE _CUDA_HD bool ArrayOptions::isUnsigned(Nd4jLong *shapeInfo) {
    return isNewFormat(shapeInfo) && hasPropertyBitSet(shapeInfo, ARRAY_UNSIGNED);
}
// Decodes the element data type encoded in the flags word of shapeInfo.
// Float/bool flags are tested first. If ARRAY_UNSIGNED is set, the integer-width flags
// map to UINT8..UINT64; otherwise they map to INT8..INT64. UTF flags map to UTF8/16/32.
// An undecodable flags word throws std::runtime_error on the host. In device code, where
// the throw is compiled out, nd4j::DataType::INHERIT is returned instead of flowing off
// the end of this non-void function (which is undefined behavior).
FORCEINLINE _CUDA_HD nd4j::DataType ArrayOptions::dataType(const Nd4jLong *shapeInfo) {
/*if (hasPropertyBitSet(shapeInfo, ARRAY_QUANTIZED))
return nd4j::DataType::QINT8;
else */if (hasPropertyBitSet(shapeInfo, ARRAY_FLOAT))
        return nd4j::DataType::FLOAT32;
    else if (hasPropertyBitSet(shapeInfo, ARRAY_DOUBLE))
        return nd4j::DataType::DOUBLE;
    else if (hasPropertyBitSet(shapeInfo, ARRAY_HALF))
        return nd4j::DataType::HALF;
    else if (hasPropertyBitSet(shapeInfo, ARRAY_BHALF))
        return nd4j::DataType::BFLOAT16;
    else if (hasPropertyBitSet(shapeInfo, ARRAY_BOOL))
        return nd4j::DataType ::BOOL;
    else if (hasPropertyBitSet(shapeInfo, ARRAY_UNSIGNED)) {
        if (hasPropertyBitSet(shapeInfo, ARRAY_CHAR))
            return nd4j::DataType ::UINT8;
        else if (hasPropertyBitSet(shapeInfo, ARRAY_SHORT))
            return nd4j::DataType ::UINT16;
        else if (hasPropertyBitSet(shapeInfo, ARRAY_INT))
            return nd4j::DataType ::UINT32;
        else if (hasPropertyBitSet(shapeInfo, ARRAY_LONG))
            return nd4j::DataType ::UINT64;
        else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF8))
            return nd4j::DataType ::UTF8;
        else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF16))
            return nd4j::DataType ::UTF16;
        else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF32))
            return nd4j::DataType ::UTF32;
        else {
            //shape::printShapeInfoLinear("Bad unsigned datatype (not)stored in shape", const_cast<Nd4jLong*>(shapeInfo));
#ifndef __CUDA_ARCH__
            throw std::runtime_error("Bad datatype A");
#else
            // exceptions are unavailable in device code: return a sentinel instead of UB
            return nd4j::DataType::INHERIT;
#endif
        }
    }
    else if (hasPropertyBitSet(shapeInfo, ARRAY_CHAR))
        return nd4j::DataType::INT8;
    else if (hasPropertyBitSet(shapeInfo, ARRAY_SHORT))
        return nd4j::DataType::INT16;
    else if (hasPropertyBitSet(shapeInfo, ARRAY_INT))
        return nd4j::DataType::INT32;
    else if (hasPropertyBitSet(shapeInfo, ARRAY_LONG))
        return nd4j::DataType::INT64;
    else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF8))
        return nd4j::DataType::UTF8;
    else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF16))
        return nd4j::DataType::UTF16;
    else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF32))
        return nd4j::DataType::UTF32;
    else {
        //shape::printShapeInfoLinear("Bad signed datatype (not)stored in shape", const_cast<Nd4jLong*>(shapeInfo));
#ifndef __CUDA_ARCH__
        throw std::runtime_error("Bad datatype B");
#else
        // exceptions are unavailable in device code: return a sentinel instead of UB
        return nd4j::DataType::INHERIT;
#endif
    }
}
// Const overload: forwards to the non-const overload, which only reads flags.
FORCEINLINE _CUDA_HD SpaceType ArrayOptions::spaceType(const Nd4jLong *shapeInfo) {
    Nd4jLong *writableShape = const_cast<Nd4jLong *>(shapeInfo);
    return spaceType(writableShape);
}
// Classifies the value space: QUANTIZED takes precedence over COMPLEX;
// with neither flag present the space is CONTINUOUS.
FORCEINLINE _CUDA_HD SpaceType ArrayOptions::spaceType(Nd4jLong *shapeInfo) {
    if (hasPropertyBitSet(shapeInfo, ARRAY_QUANTIZED))
        return SpaceType::QUANTIZED;
    return hasPropertyBitSet(shapeInfo, ARRAY_COMPLEX) ? SpaceType::COMPLEX
                                                       : SpaceType::CONTINUOUS;
}
// Const overload: forwards to the non-const overload, which only reads flags.
FORCEINLINE _CUDA_HD ArrayType ArrayOptions::arrayType(const Nd4jLong *shapeInfo) {
    Nd4jLong *writableShape = const_cast<Nd4jLong *>(shapeInfo);
    return arrayType(writableShape);
}
// Classifies the array kind. Flags are checked in priority order
// SPARSE > COMPRESSED > EMPTY > RAGGED; with none present the array is DENSE.
FORCEINLINE _CUDA_HD ArrayType ArrayOptions::arrayType(Nd4jLong *shapeInfo) {
    if (hasPropertyBitSet(shapeInfo, ARRAY_SPARSE))
        return ArrayType::SPARSE;
    if (hasPropertyBitSet(shapeInfo, ARRAY_COMPRESSED))
        return ArrayType::COMPRESSED;
    if (hasPropertyBitSet(shapeInfo, ARRAY_EMPTY))
        return ArrayType::EMPTY;
    if (hasPropertyBitSet(shapeInfo, ARRAY_RAGGED))
        return ArrayType::RAGGED;
    return ArrayType::DENSE;
}
// Flips the bit(s) of `property` in the flags word and reports whether
// they are all set afterwards.
FORCEINLINE _CUDA_HD bool ArrayOptions::togglePropertyBit(Nd4jLong *shapeInfo, int property) {
    Nd4jLong &flags = extra(shapeInfo);
    flags ^= property;
    return hasPropertyBitSet(shapeInfo, property);
}
// ORs the bit(s) of `property` into the flags word; other flags are preserved.
FORCEINLINE _CUDA_HD void ArrayOptions::setPropertyBit(Nd4jLong *shapeInfo, int property) {
    Nd4jLong &flags = extra(shapeInfo);
    flags = flags | property;
}
// Clears the bit(s) of `property` in the flags word, leaving every other flag intact.
// Masking with `~property` is required: a plain `&= property` would instead keep ONLY
// the given bits and wipe all other flags (sparse, empty, data type, ...).
FORCEINLINE _CUDA_HD void ArrayOptions::unsetPropertyBit(Nd4jLong *shapeInfo, int property) {
    extra(shapeInfo) &= ~static_cast<Nd4jLong>(property);
}
// Const overload: forwards to the non-const overload, which only reads flags.
FORCEINLINE _CUDA_HD SparseType ArrayOptions::sparseType(const Nd4jLong *shapeInfo) {
    Nd4jLong *writableShape = const_cast<Nd4jLong *>(shapeInfo);
    return sparseType(writableShape);
}
// Returns the sparse storage format: CSC, then CSR, then COO are tested in that
// order, and LIL is the fallback when none of those flags is present.
// On the host, throws std::runtime_error if the array is not flagged as sparse.
FORCEINLINE _CUDA_HD SparseType ArrayOptions::sparseType(Nd4jLong *shapeInfo) {
#ifndef __CUDA_ARCH__
    if (!isSparseArray(shapeInfo))
        throw std::runtime_error("Not a sparse array");
#endif
    return hasPropertyBitSet(shapeInfo, ARRAY_CSC) ? SparseType::CSC
         : hasPropertyBitSet(shapeInfo, ARRAY_CSR) ? SparseType::CSR
         : hasPropertyBitSet(shapeInfo, ARRAY_COO) ? SparseType::COO
         : SparseType::LIL;
}
// Sets every property in `properties`. Setting an already-present bit is a no-op
// for a bitwise OR, so no presence check is needed before each set.
FORCEINLINE _CUDA_HD void ArrayOptions::setPropertyBits(Nd4jLong *shapeInfo, std::initializer_list<int> properties) {
    for (const int property : properties)
        setPropertyBit(shapeInfo, property);
}
// Removes every data-type flag from the flags word: bool, half/bfloat16, float, double,
// the integer widths, the UTF8/UTF16/UTF32 string encodings and the unsigned marker.
// All non-data-type flags (sparse, compressed, empty, ragged, extras, ...) are preserved.
// The UTF bits must be cleared too: dataType() tests UTF8 before UTF16/UTF32, so a stale
// UTF8 bit would shadow a newly set UTF16/UTF32 type.
// The bits are cleared directly with one combined mask.
FORCEINLINE _CUDA_HD void ArrayOptions::resetDataType(Nd4jLong *shapeInfo) {
    const Nd4jLong dataTypeMask = static_cast<Nd4jLong>(ARRAY_BOOL)
                                | ARRAY_HALF | ARRAY_BHALF | ARRAY_FLOAT | ARRAY_DOUBLE
                                | ARRAY_CHAR | ARRAY_SHORT | ARRAY_INT | ARRAY_LONG
                                | ARRAY_UTF8 | ARRAY_UTF16 | ARRAY_UTF32
                                | ARRAY_UNSIGNED;
    extra(shapeInfo) &= ~dataTypeMask;
}
// Encodes `dataType` into the flags word of shapeInfo.
// The previous data type is reset first. Unsigned integer types get ARRAY_UNSIGNED
// plus the matching width flag (CHAR/SHORT/INT/LONG). On the host an unsupported type
// throws std::runtime_error; in device code a message is printed and nothing else is set.
FORCEINLINE _CUDA_HD void ArrayOptions::setDataType(Nd4jLong *shapeInfo, const nd4j::DataType dataType) {
    resetDataType(shapeInfo);

    int typeFlag = 0;
    bool unsignedType = false;
    switch (dataType) {
        case nd4j::DataType::BOOL:     typeFlag = ARRAY_BOOL;   break;
        case nd4j::DataType::HALF:     typeFlag = ARRAY_HALF;   break;
        case nd4j::DataType::BFLOAT16: typeFlag = ARRAY_BHALF;  break;
        case nd4j::DataType::FLOAT32:  typeFlag = ARRAY_FLOAT;  break;
        case nd4j::DataType::DOUBLE:   typeFlag = ARRAY_DOUBLE; break;
        case nd4j::DataType::INT8:     typeFlag = ARRAY_CHAR;   break;
        case nd4j::DataType::INT16:    typeFlag = ARRAY_SHORT;  break;
        case nd4j::DataType::INT32:    typeFlag = ARRAY_INT;    break;
        case nd4j::DataType::INT64:    typeFlag = ARRAY_LONG;   break;
        case nd4j::DataType::UINT8:    typeFlag = ARRAY_CHAR;   unsignedType = true; break;
        case nd4j::DataType::UINT16:   typeFlag = ARRAY_SHORT;  unsignedType = true; break;
        case nd4j::DataType::UINT32:   typeFlag = ARRAY_INT;    unsignedType = true; break;
        case nd4j::DataType::UINT64:   typeFlag = ARRAY_LONG;   unsignedType = true; break;
        case nd4j::DataType::UTF8:     typeFlag = ARRAY_UTF8;   break;
        case nd4j::DataType::UTF16:    typeFlag = ARRAY_UTF16;  break;
        case nd4j::DataType::UTF32:    typeFlag = ARRAY_UTF32;  break;
        default:
#ifndef __CUDA_ARCH__
            throw std::runtime_error("Can't set unknown data type");
#else
            printf("Can't set unknown data type");
            return;
#endif
    }

    // the unsigned marker is applied before the width/type flag
    if (unsignedType)
        setPropertyBit(shapeInfo, ARRAY_UNSIGNED);
    setPropertyBit(shapeInfo, typeFlag);
}
////////////////////////////////////////////////////////////////////////////////
// Decodes the data type stored in `from` and encodes it into `to`.
FORCEINLINE _CUDA_HD void ArrayOptions::copyDataType(Nd4jLong* to, const Nd4jLong* from) {
    const nd4j::DataType sourceType = dataType(from);
    setDataType(to, sourceType);
}
}
#endif // ND4J_ARRAY_OPTIONS_H :)