/******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ // // @author raver119@gmail.com // #ifndef ND4J_ARRAY_OPTIONS_H #define ND4J_ARRAY_OPTIONS_H #include #include #include #include #include #include #include #include #define ARRAY_SPARSE 2 #define ARRAY_COMPRESSED 4 #define ARRAY_EMPTY 8 #define ARRAY_RAGGED 16 #define ARRAY_CSR 32 #define ARRAY_CSC 64 #define ARRAY_COO 128 // complex values #define ARRAY_COMPLEX 512 // quantized values #define ARRAY_QUANTIZED 1024 // 16 bit float FP16 #define ARRAY_HALF 4096 // 16 bit bfloat16 #define ARRAY_BHALF 2048 // regular 32 bit float #define ARRAY_FLOAT 8192 // regular 64 bit float #define ARRAY_DOUBLE 16384 // 8 bit integer #define ARRAY_CHAR 32768 // 16 bit integer #define ARRAY_SHORT 65536 // 32 bit integer #define ARRAY_INT 131072 // 64 bit integer #define ARRAY_LONG 262144 // boolean values #define ARRAY_BOOL 524288 // UTF values #define ARRAY_UTF8 1048576 #define ARRAY_UTF16 4194304 #define ARRAY_UTF32 16777216 // flag for extras #define ARRAY_EXTRAS 2097152 // flag for signed/unsigned integers #define ARRAY_UNSIGNED 8388608 namespace sd { class ND4J_EXPORT ArrayOptions { private: static FORCEINLINE _CUDA_HD Nd4jLong& extra(Nd4jLong* shape); public: static FORCEINLINE _CUDA_HD bool isNewFormat(const Nd4jLong *shapeInfo); static FORCEINLINE _CUDA_HD bool hasPropertyBitSet(const Nd4jLong *shapeInfo, int property); static FORCEINLINE _CUDA_HD bool togglePropertyBit(Nd4jLong *shapeInfo, int property); static FORCEINLINE _CUDA_HD void unsetPropertyBit(Nd4jLong *shapeInfo, int property); static FORCEINLINE _CUDA_HD void setPropertyBit(Nd4jLong *shapeInfo, int property); static FORCEINLINE _CUDA_HD void setPropertyBits(Nd4jLong *shapeInfo, std::initializer_list properties); static FORCEINLINE _CUDA_HD bool isSparseArray(Nd4jLong *shapeInfo); static FORCEINLINE _CUDA_HD bool isUnsigned(Nd4jLong *shapeInfo); static FORCEINLINE _CUDA_HD sd::DataType dataType(const Nd4jLong *shapeInfo); static FORCEINLINE _CUDA_HD SpaceType spaceType(Nd4jLong *shapeInfo); static FORCEINLINE _CUDA_HD SpaceType spaceType(const Nd4jLong *shapeInfo); static FORCEINLINE _CUDA_HD ArrayType arrayType(Nd4jLong *shapeInfo); static FORCEINLINE _CUDA_HD ArrayType arrayType(const Nd4jLong *shapeInfo); static FORCEINLINE _CUDA_HD SparseType sparseType(Nd4jLong *shapeInfo); static FORCEINLINE _CUDA_HD SparseType sparseType(const Nd4jLong *shapeInfo); static FORCEINLINE _CUDA_HD bool hasExtraProperties(Nd4jLong *shapeInfo); static FORCEINLINE _CUDA_HD void resetDataType(Nd4jLong *shapeInfo); static FORCEINLINE _CUDA_HD void setDataType(Nd4jLong *shapeInfo, const sd::DataType dataType); static FORCEINLINE _CUDA_HD void copyDataType(Nd4jLong* to, const Nd4jLong* from); }; FORCEINLINE _CUDA_HD Nd4jLong& ArrayOptions::extra(Nd4jLong* shape) { return shape[shape[0] + shape[0] + 1]; } FORCEINLINE _CUDA_HD bool ArrayOptions::isNewFormat(const Nd4jLong *shapeInfo) { return (extra(const_cast(shapeInfo)) != 0); } FORCEINLINE _CUDA_HD bool ArrayOptions::isSparseArray(Nd4jLong *shapeInfo) { return hasPropertyBitSet(shapeInfo, ARRAY_SPARSE); } FORCEINLINE _CUDA_HD bool ArrayOptions::hasExtraProperties(Nd4jLong *shapeInfo) { return hasPropertyBitSet(shapeInfo, ARRAY_EXTRAS); } FORCEINLINE _CUDA_HD bool ArrayOptions::hasPropertyBitSet(const Nd4jLong *shapeInfo, int property) { if (!isNewFormat(shapeInfo)) return false; return ((extra(const_cast(shapeInfo)) & property) == property); } FORCEINLINE _CUDA_HD bool ArrayOptions::isUnsigned(Nd4jLong *shapeInfo) { if (!isNewFormat(shapeInfo)) return false; return hasPropertyBitSet(shapeInfo, ARRAY_UNSIGNED); } FORCEINLINE _CUDA_HD sd::DataType ArrayOptions::dataType(const Nd4jLong *shapeInfo) { /*if (hasPropertyBitSet(shapeInfo, ARRAY_QUANTIZED)) return sd::DataType::QINT8; else */if (hasPropertyBitSet(shapeInfo, ARRAY_FLOAT)) return sd::DataType::FLOAT32; else if (hasPropertyBitSet(shapeInfo, ARRAY_DOUBLE)) return sd::DataType::DOUBLE; else if (hasPropertyBitSet(shapeInfo, ARRAY_HALF)) return sd::DataType::HALF; else if (hasPropertyBitSet(shapeInfo, ARRAY_BHALF)) return sd::DataType::BFLOAT16; else if (hasPropertyBitSet(shapeInfo, ARRAY_BOOL)) return sd::DataType ::BOOL; else if (hasPropertyBitSet(shapeInfo, ARRAY_UNSIGNED)) { if (hasPropertyBitSet(shapeInfo, ARRAY_CHAR)) return sd::DataType ::UINT8; else if (hasPropertyBitSet(shapeInfo, ARRAY_SHORT)) return sd::DataType ::UINT16; else if (hasPropertyBitSet(shapeInfo, ARRAY_INT)) return sd::DataType ::UINT32; else if (hasPropertyBitSet(shapeInfo, ARRAY_LONG)) return sd::DataType ::UINT64; else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF8)) return sd::DataType ::UTF8; else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF16)) return sd::DataType ::UTF16; else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF32)) return sd::DataType ::UTF32; else { //shape::printShapeInfoLinear("Bad unsigned datatype (not)stored in shape", const_cast(shapeInfo)); #ifndef __CUDA_ARCH__ throw std::runtime_error("Bad datatype A"); #endif } } else if (hasPropertyBitSet(shapeInfo, ARRAY_CHAR)) return sd::DataType::INT8; else if (hasPropertyBitSet(shapeInfo, ARRAY_SHORT)) return sd::DataType::INT16; else if (hasPropertyBitSet(shapeInfo, ARRAY_INT)) return sd::DataType::INT32; else if (hasPropertyBitSet(shapeInfo, ARRAY_LONG)) return sd::DataType::INT64; else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF8)) return sd::DataType::UTF8; else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF16)) return sd::DataType::UTF16; else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF32)) return sd::DataType::UTF32; else { //shape::printShapeInfoLinear("Bad signed datatype (not)stored in shape", const_cast(shapeInfo)); #ifndef __CUDA_ARCH__ throw std::runtime_error("Bad datatype B"); #endif } } FORCEINLINE _CUDA_HD SpaceType ArrayOptions::spaceType(const Nd4jLong *shapeInfo) { return spaceType(const_cast(shapeInfo)); } FORCEINLINE _CUDA_HD SpaceType ArrayOptions::spaceType(Nd4jLong *shapeInfo) { if (hasPropertyBitSet(shapeInfo, ARRAY_QUANTIZED)) return SpaceType::QUANTIZED; if (hasPropertyBitSet(shapeInfo, ARRAY_COMPLEX)) return SpaceType::COMPLEX; else // by default we return continuous type here return SpaceType::CONTINUOUS; } FORCEINLINE _CUDA_HD ArrayType ArrayOptions::arrayType(const Nd4jLong *shapeInfo) { return arrayType(const_cast(shapeInfo)); } FORCEINLINE _CUDA_HD ArrayType ArrayOptions::arrayType(Nd4jLong *shapeInfo) { if (hasPropertyBitSet(shapeInfo, ARRAY_SPARSE)) return ArrayType::SPARSE; else if (hasPropertyBitSet(shapeInfo, ARRAY_COMPRESSED)) return ArrayType::COMPRESSED; else if (hasPropertyBitSet(shapeInfo, ARRAY_EMPTY)) return ArrayType::EMPTY; else if (hasPropertyBitSet(shapeInfo, ARRAY_RAGGED)) return ArrayType::RAGGED; else // by default we return DENSE type here return ArrayType::DENSE; } FORCEINLINE _CUDA_HD bool ArrayOptions::togglePropertyBit(Nd4jLong *shapeInfo, int property) { extra(shapeInfo) ^= property; return hasPropertyBitSet(shapeInfo, property); } FORCEINLINE _CUDA_HD void ArrayOptions::setPropertyBit(Nd4jLong *shapeInfo, int property) { extra(shapeInfo) |= property; } FORCEINLINE _CUDA_HD void ArrayOptions::unsetPropertyBit(Nd4jLong *shapeInfo, int property) { extra(shapeInfo) &= property; } FORCEINLINE _CUDA_HD SparseType ArrayOptions::sparseType(const Nd4jLong *shapeInfo) { return sparseType(const_cast(shapeInfo)); } FORCEINLINE _CUDA_HD SparseType ArrayOptions::sparseType(Nd4jLong *shapeInfo) { #ifndef __CUDA_ARCH__ if (!isSparseArray(shapeInfo)) throw std::runtime_error("Not a sparse array"); #endif if (hasPropertyBitSet(shapeInfo, ARRAY_CSC)) return SparseType::CSC; else if (hasPropertyBitSet(shapeInfo, ARRAY_CSR)) return SparseType::CSR; else if (hasPropertyBitSet(shapeInfo, ARRAY_COO)) return SparseType::COO; else return SparseType::LIL; } FORCEINLINE _CUDA_HD void ArrayOptions::setPropertyBits(Nd4jLong *shapeInfo, std::initializer_list properties) { for (auto v: properties) { if (!hasPropertyBitSet(shapeInfo, v)) setPropertyBit(shapeInfo, v); } } FORCEINLINE _CUDA_HD void ArrayOptions::resetDataType(Nd4jLong *shapeInfo) { unsetPropertyBit(shapeInfo, ARRAY_BOOL); unsetPropertyBit(shapeInfo, ARRAY_HALF); unsetPropertyBit(shapeInfo, ARRAY_BHALF); unsetPropertyBit(shapeInfo, ARRAY_FLOAT); unsetPropertyBit(shapeInfo, ARRAY_DOUBLE); unsetPropertyBit(shapeInfo, ARRAY_INT); unsetPropertyBit(shapeInfo, ARRAY_LONG); unsetPropertyBit(shapeInfo, ARRAY_CHAR); unsetPropertyBit(shapeInfo, ARRAY_SHORT); unsetPropertyBit(shapeInfo, ARRAY_UNSIGNED); } FORCEINLINE _CUDA_HD void ArrayOptions::setDataType(Nd4jLong *shapeInfo, const sd::DataType dataType) { resetDataType(shapeInfo); if (dataType == sd::DataType::UINT8 || dataType == sd::DataType::UINT16 || dataType == sd::DataType::UINT32 || dataType == sd::DataType::UINT64) { setPropertyBit(shapeInfo, ARRAY_UNSIGNED); } switch (dataType) { case sd::DataType::BOOL: setPropertyBit(shapeInfo, ARRAY_BOOL); break; case sd::DataType::HALF: setPropertyBit(shapeInfo, ARRAY_HALF); break; case sd::DataType::BFLOAT16: setPropertyBit(shapeInfo, ARRAY_BHALF); break; case sd::DataType::FLOAT32: setPropertyBit(shapeInfo, ARRAY_FLOAT); break; case sd::DataType::DOUBLE: setPropertyBit(shapeInfo, ARRAY_DOUBLE); break; case sd::DataType::INT8: setPropertyBit(shapeInfo, ARRAY_CHAR); break; case sd::DataType::INT16: setPropertyBit(shapeInfo, ARRAY_SHORT); break; case sd::DataType::INT32: setPropertyBit(shapeInfo, ARRAY_INT); break; case sd::DataType::INT64: setPropertyBit(shapeInfo, ARRAY_LONG); break; case sd::DataType::UINT8: setPropertyBit(shapeInfo, ARRAY_CHAR); break; case sd::DataType::UINT16: setPropertyBit(shapeInfo, ARRAY_SHORT); break; case sd::DataType::UINT32: setPropertyBit(shapeInfo, ARRAY_INT); break; case sd::DataType::UINT64: setPropertyBit(shapeInfo, ARRAY_LONG); break; case sd::DataType::UTF8: setPropertyBit(shapeInfo, ARRAY_UTF8); break; case sd::DataType::UTF16: setPropertyBit(shapeInfo, ARRAY_UTF16); break; case sd::DataType::UTF32: setPropertyBit(shapeInfo, ARRAY_UTF32); break; default: #ifndef __CUDA_ARCH__ throw std::runtime_error("Can't set unknown data type"); #else printf("Can't set unknown data type"); #endif } } //////////////////////////////////////////////////////////////////////////////// FORCEINLINE _CUDA_HD void ArrayOptions::copyDataType(Nd4jLong* to, const Nd4jLong* from) { setDataType(to, dataType(from)); } } #endif // ND4J_ARRAY_OPTIONS_H :)