* initial commit * additional data types & tensor type Signed-off-by: raver119 <raver119@gmail.com> * next step Signed-off-by: raver119 <raver119@gmail.com> * missing include * sparse_to_dense Signed-off-by: raver119 <raver119@gmail.com> * few more tests files Signed-off-by: raver119 <raver119@gmail.com> * draft Signed-off-by: raver119 <raver119@gmail.com> * numeric sparse_to_dense Signed-off-by: raver119 <raver119@gmail.com> * comment Signed-off-by: raver119 <raver119@gmail.com> * string sparse_to_dense version Signed-off-by: raver119 <raver119@gmail.com> * CUDA DataBuffer expand Signed-off-by: raver119 <raver119@gmail.com> * few tweaks for CUDA build Signed-off-by: raver119 <raver119@gmail.com> * shape fn for string_split Signed-off-by: raver119 <raver119@gmail.com> * one more comment Signed-off-by: raver119 <raver119@gmail.com> * string_split indices Signed-off-by: raver119 <raver119@gmail.com> * next step Signed-off-by: raver119 <raver119@gmail.com> * test passes Signed-off-by: raver119 <raver119@gmail.com> * few rearrangements for databuffer implementations Signed-off-by: raver119 <raver119@gmail.com> * DataBuffer: move inline methods to common implementations Signed-off-by: raver119 <raver119@gmail.com> * add native DataBuffer to Nd4j presets Signed-off-by: raver119 <raver119@gmail.com> * DataBuffer creation Signed-off-by: raver119 <raver119@gmail.com> * use DataBuffer for allocation Signed-off-by: raver119 <raver119@gmail.com> * cpu databuffer as deallocatable Signed-off-by: raver119 <raver119@gmail.com> * DataBuffer setters for bufers Signed-off-by: raver119 <raver119@gmail.com> * couple of wrappers Signed-off-by: raver119 <raver119@gmail.com> * DataBuffers being passed around Signed-off-by: raver119 <raver119@gmail.com> * Bunch of ByteBuffer-related signatures gone Signed-off-by: raver119 <raver119@gmail.com> * - few more Nd4j signatures removed - minor fix for bfloat16 Signed-off-by: raver119 <raver119@gmail.com> * nullptr pointer is still a pointer, but 
0 as address :) Signed-off-by: raver119 <raver119@gmail.com> * one special test Signed-off-by: raver119 <raver119@gmail.com> * empty string array init Signed-off-by: raver119 <raver119@gmail.com> * one more test in cpp Signed-off-by: raver119 <raver119@gmail.com> * memcpy instead of databuffer swap Signed-off-by: raver119 <raver119@gmail.com> * special InteropDataBuffer for front-end languages Signed-off-by: raver119 <raver119@gmail.com> * few tweaks for java Signed-off-by: raver119 <raver119@gmail.com> * pointer/indexer actualization Signed-off-by: raver119 <raver119@gmail.com> * CustomOp returns list for inputArumgents and outputArguments instead of array Signed-off-by: raver119 <raver119@gmail.com> * redundant call Signed-off-by: raver119 <raver119@gmail.com> * print_variable op Signed-off-by: raver119 <raver119@gmail.com> * - view handling (but wrong one) - print_variable java wrapper Signed-off-by: raver119 <raver119@gmail.com> * one more test Signed-off-by: raver119 <raver119@gmail.com> * - empty arrays handling Signed-off-by: raver119 <raver119@gmail.com> * - deserialization works now Signed-off-by: raver119 <raver119@gmail.com> * minor fix Signed-off-by: raver119 <raver119@gmail.com> * meh Signed-off-by: raver119 <raver119@gmail.com> * one more fix Signed-off-by: raver119 <raver119@gmail.com> * initial cuda commit Signed-off-by: raver119 <raver119@gmail.com> * print_variable message validation Signed-off-by: raver119 <raver119@gmail.com> * CUDA views Signed-off-by: raver119 <raver119@gmail.com> * CUDA special buffer size Signed-off-by: raver119 <raver119@gmail.com> * minor update to match master changes Signed-off-by: raver119 <raver119@gmail.com> * - consider arrays always actual on device for CUDA - additional PrintVariable constructor - CudaUtf8Buffer now allocates host buffer by default Signed-off-by: raver119 <raver119@gmail.com> * meh Signed-off-by: raver119 <raver119@gmail.com> * - print_variable now allows print from device Signed-off-by: raver119 
<raver119@gmail.com> * InteropDataBuffer data type fix Signed-off-by: raver119 <raver119@gmail.com> * ... Signed-off-by: raver119 <raver119@gmail.com> * disable some debug messages Signed-off-by: raver119 <raver119@gmail.com> * master pulled in Signed-off-by: raver119 <raver119@gmail.com> * couple of new methods for DataBuffer interop Signed-off-by: raver119 <raver119@gmail.com> * java side Signed-off-by: raver119 <raver119@gmail.com> * offsetted constructor Signed-off-by: raver119 <raver119@gmail.com> * new CUDA deallocator Signed-off-by: raver119 <raver119@gmail.com> * CUDA backend torn apart Signed-off-by: raver119 <raver119@gmail.com> * CUDA backend torn apart 2 Signed-off-by: raver119 <raver119@gmail.com> * CUDA backend torn apart 3 Signed-off-by: raver119 <raver119@gmail.com> * - few new tests - few new methods for DataBuffer management Signed-off-by: raver119 <raver119@gmail.com> * few more tests + few more tweaks Signed-off-by: raver119 <raver119@gmail.com> * two failing tests Signed-off-by: raver119 <raver119@gmail.com> * one more test Signed-off-by: raver119 <raver119@gmail.com> * two failing tests pass Signed-off-by: raver119 <raver119@gmail.com> * now we pass DataBuffer to legacy ops too Signed-off-by: raver119 <raver119@gmail.com> * Native DataBuffer for legacy ops, Java side Signed-off-by: raver119 <raver119@gmail.com> * CPU java side update Signed-off-by: raver119 <raver119@gmail.com> * CUDA java side update Signed-off-by: raver119 <raver119@gmail.com> * no more prepare/register action on java side Signed-off-by: raver119 <raver119@gmail.com> * NDArray::prepare/register use now accepts vectors Signed-off-by: raver119 <raver119@gmail.com> * InteropDataBuffer now has few more convenience methods Signed-off-by: raver119 <raver119@gmail.com> * java bindings update Signed-off-by: raver119 <raver119@gmail.com> * tick device in NativeOps Signed-off-by: raver119 <raver119@gmail.com> * Corrected usage of OpaqueBuffer for tests. 
* Corrected usage of OpaqueBuffer for java tests. * NativeOpsTests fixes. * print_variable now returns scalar Signed-off-by: raver119 <raver119@gmail.com> * one more test Signed-off-by: raver119 <raver119@gmail.com> * compat_string_split fix for CUDA Signed-off-by: raver119 <raver119@gmail.com> * - CUDA execScalar fix - CUDA lazyAllocateHostPointer now checks java indexer/pointer instead of native pointer Signed-off-by: raver119 <raver119@gmail.com> * legacy ops DataBuffer migration prototype Signed-off-by: raver119 <raver119@gmail.com> * ignore device shapeinfo coming from java Signed-off-by: raver119 <raver119@gmail.com> * minor fix Signed-off-by: raver119 <raver119@gmail.com> * minor transformAny fix Signed-off-by: raver119 <raver119@gmail.com> * minor tweak for lazy host allocation Signed-off-by: raver119 <raver119@gmail.com> * - DataBuffer::memcpy method - bitcast now uses memcpy Signed-off-by: raver119 <raver119@gmail.com> * - IndexReduce CUDA dimension buffer fix Signed-off-by: raver119 <raver119@gmail.com> * views for CPU and CUDA Signed-off-by: raver119 <raver119@gmail.com> * less spam Signed-off-by: raver119 <raver119@gmail.com> * optional memory init Signed-off-by: raver119 <raver119@gmail.com> * async memset Signed-off-by: raver119 <raver119@gmail.com> * - SummaryStats CUDA fix - DataBuffer.sameUnderlyingData() impl - execBroadcast fix Signed-off-by: raver119 <raver119@gmail.com> * - reduce3All fix switch to CUDA 10 temporarily Signed-off-by: raver119 <raver119@gmail.com> * CUDA version Signed-off-by: raver119 <raver119@gmail.com> * proper memory deallocator registration Signed-off-by: raver119 <raver119@gmail.com> * HOST_ONLY workspace allocation Signed-off-by: raver119 <raver119@gmail.com> * temp commit Signed-off-by: raver119 <raver119@gmail.com> * few conflicts resolved Signed-off-by: raver119 <raver119@gmail.com> * few minor fixes Signed-off-by: raver119 <raver119@gmail.com> * one more minor fix Signed-off-by: raver119 <raver119@gmail.com> * 
NDArray permute should operate on JVM primitives Signed-off-by: raver119 <raver119@gmail.com> * - create InteropDataBuffer for shapes as well - update pointers after view creation in Java Signed-off-by: raver119 <raver119@gmail.com> * - addressPointer temporary moved to C++ Signed-off-by: raver119 <raver119@gmail.com> * CUDA: don't account offset twice Signed-off-by: raver119 <raver119@gmail.com> * CUDA: DataBuffer pointer constructor updated Signed-off-by: raver119 <raver119@gmail.com> * CUDA NDArray.unsafeDuplication() simplified Signed-off-by: raver119 <raver119@gmail.com> * CUDA minor workspace-related fixes Signed-off-by: raver119 <raver119@gmail.com> * CPU DataBuffer.reallocate() Signed-off-by: raver119 <raver119@gmail.com> * print_affinity op Signed-off-by: raver119 <raver119@gmail.com> * print_affinity java side Signed-off-by: raver119 <raver119@gmail.com> * CUDA more tweaks for data locality Signed-off-by: raver119 <raver119@gmail.com> * - compat_string_split tweak - CudaUtf8Buffer update Signed-off-by: raver119 <raver119@gmail.com> * INDArray.close() mechanic restored Signed-off-by: raver119 <raver119@gmail.com> * one more test fixed Signed-off-by: raver119 <raver119@gmail.com> * - CUDA DataBuffer.reallocate() updated - cudaMemcpy (synchronous) restored Signed-off-by: raver119 <raver119@gmail.com> * one last fix Signed-off-by: raver119 <raver119@gmail.com> * bad import removed Signed-off-by: raver119 <raver119@gmail.com> * another small fix Signed-off-by: raver119 <raver119@gmail.com> * one special test Signed-off-by: raver119 <raver119@gmail.com> * fix bad databuffer size Signed-off-by: raver119 <raver119@gmail.com> * release primaryBuffer on replace Signed-off-by: raver119 <raver119@gmail.com> * higher timeout Signed-off-by: raver119 <raver119@gmail.com> * disable timeouts Signed-off-by: raver119 <raver119@gmail.com> * dbCreateView now validates offset and length of a view Signed-off-by: raver119 <raver119@gmail.com> * additional validation for dbExpand 
Signed-off-by: raver119 <raver119@gmail.com> * restore timeout back again Signed-off-by: raver119 <raver119@gmail.com> * smaller distribution for rng test to prevent timeouts Signed-off-by: raver119 <raver119@gmail.com> * CUDA DataBuffer::memcpy now copies to device all the time Signed-off-by: raver119 <raver119@gmail.com> * OpaqueDataBuffer now contains all required methods for interop Signed-off-by: raver119 <raver119@gmail.com> * some javadoc Signed-off-by: raver119 <raver119@gmail.com> * GC on failed allocations Signed-off-by: raver119 <raver119@gmail.com> * minoe memcpu tweak Signed-off-by: raver119 <raver119@gmail.com> * one more bitcast test Signed-off-by: raver119 <raver119@gmail.com> * - NDArray::deviceId() propagation - special multi-threaded test for data locality checks Signed-off-by: raver119 <raver119@gmail.com> * DataBuffer additional syncStream Signed-off-by: raver119 <raver119@gmail.com> * DataBuffer additional syncStream Signed-off-by: raver119 <raver119@gmail.com> * one ignored test Signed-off-by: raver119 <raver119@gmail.com> * skip host alloc for empty arrays Signed-off-by: raver119 <raver119@gmail.com> * ByteBuffer support is back Signed-off-by: raver119 <raver119@gmail.com> * DataBuffer::memcpy minor fix Signed-off-by: raver119 <raver119@gmail.com> * few minor prelu/bp tweaks Signed-off-by: raver119 <raver119@gmail.com> * nullify-related fixes Signed-off-by: raver119 <raver119@gmail.com> * PReLU fixes (#157) Signed-off-by: Alex Black <blacka101@gmail.com> * Build fixed * Fix tests * one more ByteBuffer signature restored Signed-off-by: raver119 <raver119@gmail.com> * nd4j-jdbc-hsql profiles fix Signed-off-by: raver119 <raver119@gmail.com> * nd4j-jdbc-hsql profiles fix Signed-off-by: raver119 <raver119@gmail.com> * PReLU weight init fix Signed-off-by: Alex Black <blacka101@gmail.com> * Small PReLU fix Signed-off-by: Alex Black <blacka101@gmail.com> * - INDArray.migrate() reactivated - DataBuffer::setDeviceId(...) 
added - InteropDataBuffer Z syncToDevice added for views Signed-off-by: raver119 <raver119@gmail.com> * missed file Signed-off-by: raver119 <raver119@gmail.com> * Small tweak Signed-off-by: Alex Black <blacka101@gmail.com> * cuda 10.2 Signed-off-by: raver119 <raver119@gmail.com> * minor fix Signed-off-by: raver119 <raver119@gmail.com> Co-authored-by: shugeo <sgazeos@gmail.com> Co-authored-by: Alex Black <blacka101@gmail.com> Co-authored-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
373 lines
13 KiB
C++
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

#ifndef ND4J_ARRAY_OPTIONS_H
#define ND4J_ARRAY_OPTIONS_H

#include <op_boilerplate.h>
#include <pointercast.h>
#include <dll.h>
#include <array/DataType.h>
#include <array/ArrayType.h>
#include <array/SpaceType.h>
#include <array/SparseType.h>
#include <initializer_list>
#include <stdexcept>    // std::runtime_error
#include <cstdio>       // printf (device-side fallback in setDataType)


#define ARRAY_SPARSE 2
#define ARRAY_COMPRESSED 4
#define ARRAY_EMPTY 8
#define ARRAY_RAGGED 16


#define ARRAY_CSR 32
#define ARRAY_CSC 64
#define ARRAY_COO 128

// complex values
#define ARRAY_COMPLEX 512

// quantized values
#define ARRAY_QUANTIZED 1024

// 16 bit float FP16
#define ARRAY_HALF 4096

// 16 bit bfloat16
#define ARRAY_BHALF 2048

// regular 32 bit float
#define ARRAY_FLOAT 8192

// regular 64 bit float
#define ARRAY_DOUBLE 16384

// 8 bit integer
#define ARRAY_CHAR 32768

// 16 bit integer
#define ARRAY_SHORT 65536

// 32 bit integer
#define ARRAY_INT 131072

// 64 bit integer
#define ARRAY_LONG 262144

// boolean values
#define ARRAY_BOOL 524288

// UTF values
#define ARRAY_UTF8 1048576
#define ARRAY_UTF16 4194304
#define ARRAY_UTF32 16777216

// flag for extras
#define ARRAY_EXTRAS 2097152


// flag for signed/unsigned integers
#define ARRAY_UNSIGNED 8388608


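/*
 * Usage sketch (illustrative only, not part of the original header).
 * Each ARRAY_* value above is a distinct power of two, so several properties
 * can be packed into one integer word with bitwise OR and queried with a
 * masked comparison. The names `Flags`, `hasAll`, `describeSparseCsrFloat`
 * and `describeUint32` are hypothetical; the numeric values are copied from
 * the defines above.
 *
```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for the flags word (Nd4jLong in the real header).
using Flags = std::int64_t;

// Values copied from the ARRAY_* defines.
constexpr Flags kSparse = 2, kCsr = 32, kFloat = 8192;
constexpr Flags kInt = 131072, kUnsigned = 8388608;

// True when every bit of `mask` is set in `flags`.
inline bool hasAll(Flags flags, Flags mask) { return (flags & mask) == mask; }

// A sparse CSR array of 32 bit floats: three independent bits in one word.
inline Flags describeSparseCsrFloat() { return kSparse | kCsr | kFloat; }

// An unsigned 32 bit integer array: width bit plus the unsigned marker.
inline Flags describeUint32() { return kUnsigned | kInt; }
```
 *
 * Because the bits do not overlap, a property can be tested without
 * disturbing the others, for example hasAll(describeUint32(), kInt).
 */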
namespace nd4j {
    class ND4J_EXPORT ArrayOptions {

    private:
        static FORCEINLINE _CUDA_HD Nd4jLong& extra(Nd4jLong* shape);

    public:
        static FORCEINLINE _CUDA_HD bool isNewFormat(const Nd4jLong *shapeInfo);
        static FORCEINLINE _CUDA_HD bool hasPropertyBitSet(const Nd4jLong *shapeInfo, int property);
        static FORCEINLINE _CUDA_HD bool togglePropertyBit(Nd4jLong *shapeInfo, int property);
        static FORCEINLINE _CUDA_HD void unsetPropertyBit(Nd4jLong *shapeInfo, int property);
        static FORCEINLINE _CUDA_HD void setPropertyBit(Nd4jLong *shapeInfo, int property);
        static FORCEINLINE _CUDA_HD void setPropertyBits(Nd4jLong *shapeInfo, std::initializer_list<int> properties);

        static FORCEINLINE _CUDA_HD bool isSparseArray(Nd4jLong *shapeInfo);
        static FORCEINLINE _CUDA_HD bool isUnsigned(Nd4jLong *shapeInfo);

        static FORCEINLINE _CUDA_HD nd4j::DataType dataType(const Nd4jLong *shapeInfo);

        static FORCEINLINE _CUDA_HD SpaceType spaceType(Nd4jLong *shapeInfo);
        static FORCEINLINE _CUDA_HD SpaceType spaceType(const Nd4jLong *shapeInfo);

        static FORCEINLINE _CUDA_HD ArrayType arrayType(Nd4jLong *shapeInfo);
        static FORCEINLINE _CUDA_HD ArrayType arrayType(const Nd4jLong *shapeInfo);

        static FORCEINLINE _CUDA_HD SparseType sparseType(Nd4jLong *shapeInfo);
        static FORCEINLINE _CUDA_HD SparseType sparseType(const Nd4jLong *shapeInfo);

        static FORCEINLINE _CUDA_HD bool hasExtraProperties(Nd4jLong *shapeInfo);


        static FORCEINLINE _CUDA_HD void resetDataType(Nd4jLong *shapeInfo);
        static FORCEINLINE _CUDA_HD void setDataType(Nd4jLong *shapeInfo, const nd4j::DataType dataType);

        static FORCEINLINE _CUDA_HD void copyDataType(Nd4jLong* to, const Nd4jLong* from);
    };

    // shape[0] holds the rank; the flags word is stored at index 2 * rank + 1
    FORCEINLINE _CUDA_HD Nd4jLong& ArrayOptions::extra(Nd4jLong* shape) {
        return shape[shape[0] + shape[0] + 1];
    }

    // a shapeInfo buffer counts as "new format" once any flag bit is set
    FORCEINLINE _CUDA_HD bool ArrayOptions::isNewFormat(const Nd4jLong *shapeInfo) {
        return (extra(const_cast<Nd4jLong*>(shapeInfo)) != 0);
    }


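    /*
     * Layout sketch (illustrative only, not part of the original header).
     * extra() reads the flags word at index 2 * rank + 1, which implies the
     * buffer starts with the rank, followed by `rank` entries and another
     * `rank` entries (taken here to be dimensions and strides). The two
     * trailing slots in the sample buffer below (element-wise stride and
     * order) are an assumption about the rest of the layout, and `Long`,
     * `extraIndex` and `makeShapeInfo2x3` are hypothetical names.
     *
```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Long = std::int64_t;  // stand-in for Nd4jLong

// Index of the flags word: slot 0 is the rank, then 2 * rank further slots.
inline std::size_t extraIndex(const std::vector<Long>& shapeInfo) {
    return static_cast<std::size_t>(2 * shapeInfo[0] + 1);
}

// Sample buffer for a 2x3 array with strides {3, 1} and an empty flags word.
// The last two values (1 and 99) are assumed trailing fields.
inline std::vector<Long> makeShapeInfo2x3() {
    return {2, 2, 3, 3, 1, 0, 1, 99};
}
```
     *
     * With rank 2 the flags word lands at index 5, directly after the two
     * dimension slots and the two stride slots.
     */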
    FORCEINLINE _CUDA_HD bool ArrayOptions::isSparseArray(Nd4jLong *shapeInfo) {
        return hasPropertyBitSet(shapeInfo, ARRAY_SPARSE);
    }

    FORCEINLINE _CUDA_HD bool ArrayOptions::hasExtraProperties(Nd4jLong *shapeInfo) {
        return hasPropertyBitSet(shapeInfo, ARRAY_EXTRAS);
    }

    FORCEINLINE _CUDA_HD bool ArrayOptions::hasPropertyBitSet(const Nd4jLong *shapeInfo, int property) {
        if (!isNewFormat(shapeInfo))
            return false;

        return ((extra(const_cast<Nd4jLong*>(shapeInfo)) & property) == property);
    }

    FORCEINLINE _CUDA_HD bool ArrayOptions::isUnsigned(Nd4jLong *shapeInfo) {
        if (!isNewFormat(shapeInfo))
            return false;

        return hasPropertyBitSet(shapeInfo, ARRAY_UNSIGNED);
    }

    FORCEINLINE _CUDA_HD nd4j::DataType ArrayOptions::dataType(const Nd4jLong *shapeInfo) {
        /*if (hasPropertyBitSet(shapeInfo, ARRAY_QUANTIZED))
            return nd4j::DataType::QINT8;
        else */if (hasPropertyBitSet(shapeInfo, ARRAY_FLOAT))
            return nd4j::DataType::FLOAT32;
        else if (hasPropertyBitSet(shapeInfo, ARRAY_DOUBLE))
            return nd4j::DataType::DOUBLE;
        else if (hasPropertyBitSet(shapeInfo, ARRAY_HALF))
            return nd4j::DataType::HALF;
        else if (hasPropertyBitSet(shapeInfo, ARRAY_BHALF))
            return nd4j::DataType::BFLOAT16;
        else if (hasPropertyBitSet(shapeInfo, ARRAY_BOOL))
            return nd4j::DataType::BOOL;
        else if (hasPropertyBitSet(shapeInfo, ARRAY_UNSIGNED)) {
            if (hasPropertyBitSet(shapeInfo, ARRAY_CHAR))
                return nd4j::DataType::UINT8;
            else if (hasPropertyBitSet(shapeInfo, ARRAY_SHORT))
                return nd4j::DataType::UINT16;
            else if (hasPropertyBitSet(shapeInfo, ARRAY_INT))
                return nd4j::DataType::UINT32;
            else if (hasPropertyBitSet(shapeInfo, ARRAY_LONG))
                return nd4j::DataType::UINT64;
            else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF8))
                return nd4j::DataType::UTF8;
            else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF16))
                return nd4j::DataType::UTF16;
            else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF32))
                return nd4j::DataType::UTF32;
            else {
                //shape::printShapeInfoLinear("Bad unsigned datatype (not)stored in shape", const_cast<Nd4jLong*>(shapeInfo));
#ifndef __CUDA_ARCH__
                throw std::runtime_error("Bad datatype A");
#endif
            }
        }
        else if (hasPropertyBitSet(shapeInfo, ARRAY_CHAR))
            return nd4j::DataType::INT8;
        else if (hasPropertyBitSet(shapeInfo, ARRAY_SHORT))
            return nd4j::DataType::INT16;
        else if (hasPropertyBitSet(shapeInfo, ARRAY_INT))
            return nd4j::DataType::INT32;
        else if (hasPropertyBitSet(shapeInfo, ARRAY_LONG))
            return nd4j::DataType::INT64;
        else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF8))
            return nd4j::DataType::UTF8;
        else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF16))
            return nd4j::DataType::UTF16;
        else if (hasPropertyBitSet(shapeInfo, ARRAY_UTF32))
            return nd4j::DataType::UTF32;
        else {
            //shape::printShapeInfoLinear("Bad signed datatype (not)stored in shape", const_cast<Nd4jLong*>(shapeInfo));
#ifndef __CUDA_ARCH__
            throw std::runtime_error("Bad datatype B");
#endif
        }
    }

    FORCEINLINE _CUDA_HD SpaceType ArrayOptions::spaceType(const Nd4jLong *shapeInfo) {
        return spaceType(const_cast<Nd4jLong *>(shapeInfo));
    }

    FORCEINLINE _CUDA_HD SpaceType ArrayOptions::spaceType(Nd4jLong *shapeInfo) {
        if (hasPropertyBitSet(shapeInfo, ARRAY_QUANTIZED))
            return SpaceType::QUANTIZED;
        if (hasPropertyBitSet(shapeInfo, ARRAY_COMPLEX))
            return SpaceType::COMPLEX;
        else // by default we return continuous type here
            return SpaceType::CONTINUOUS;
    }

    FORCEINLINE _CUDA_HD ArrayType ArrayOptions::arrayType(const Nd4jLong *shapeInfo) {
        return arrayType(const_cast<Nd4jLong *>(shapeInfo));
    }

    FORCEINLINE _CUDA_HD ArrayType ArrayOptions::arrayType(Nd4jLong *shapeInfo) {
        if (hasPropertyBitSet(shapeInfo, ARRAY_SPARSE))
            return ArrayType::SPARSE;
        else if (hasPropertyBitSet(shapeInfo, ARRAY_COMPRESSED))
            return ArrayType::COMPRESSED;
        else if (hasPropertyBitSet(shapeInfo, ARRAY_EMPTY))
            return ArrayType::EMPTY;
        else if (hasPropertyBitSet(shapeInfo, ARRAY_RAGGED))
            return ArrayType::RAGGED;
        else // by default we return DENSE type here
            return ArrayType::DENSE;
    }

    // flips the given bit(s) and returns whether they are set afterwards
    FORCEINLINE _CUDA_HD bool ArrayOptions::togglePropertyBit(Nd4jLong *shapeInfo, int property) {
        extra(shapeInfo) ^= property;

        return hasPropertyBitSet(shapeInfo, property);
    }

    FORCEINLINE _CUDA_HD void ArrayOptions::setPropertyBit(Nd4jLong *shapeInfo, int property) {
        extra(shapeInfo) |= property;
    }

    FORCEINLINE _CUDA_HD void ArrayOptions::unsetPropertyBit(Nd4jLong *shapeInfo, int property) {
        // AND with the complement clears only the requested bit(s);
        // a plain `&= property` would wipe every other flag instead
        extra(shapeInfo) &= ~static_cast<Nd4jLong>(property);
    }

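    /*
     * Bit manipulation sketch (illustrative only, not part of the original
     * header). The three mutators above map onto the conventional bitmask
     * idioms: OR sets a bit, XOR flips it, and AND with the complement of the
     * mask clears it while leaving all other bits alone. `Long`, `setBit`,
     * `clearBit`, `toggleBit` and `hasBits` are hypothetical stand-ins that
     * operate on a plain integer rather than on a shapeInfo buffer.
     *
```cpp
#include <cassert>
#include <cstdint>

using Long = std::int64_t;  // stand-in for Nd4jLong

// Sets every bit of `mask` in `word`.
inline void setBit(Long& word, int mask) { word |= mask; }

// Clears every bit of `mask`; the cast widens the mask before inverting it,
// so the upper bits of the 64 bit word are preserved.
inline void clearBit(Long& word, int mask) { word &= ~static_cast<Long>(mask); }

// Flips every bit of `mask`.
inline void toggleBit(Long& word, int mask) { word ^= mask; }

// True when all bits of `mask` are set.
inline bool hasBits(Long word, int mask) { return (word & mask) == mask; }
```
     *
     * Starting from a word with the FLOAT (8192) and EMPTY (8) bits set,
     * clearing 8192 should leave exactly 8 behind.
     */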
    FORCEINLINE _CUDA_HD SparseType ArrayOptions::sparseType(const Nd4jLong *shapeInfo) {
        return sparseType(const_cast<Nd4jLong *>(shapeInfo));
    }

    FORCEINLINE _CUDA_HD SparseType ArrayOptions::sparseType(Nd4jLong *shapeInfo) {
#ifndef __CUDA_ARCH__
        if (!isSparseArray(shapeInfo))
            throw std::runtime_error("Not a sparse array");
#endif

        if (hasPropertyBitSet(shapeInfo, ARRAY_CSC))
            return SparseType::CSC;
        else if (hasPropertyBitSet(shapeInfo, ARRAY_CSR))
            return SparseType::CSR;
        else if (hasPropertyBitSet(shapeInfo, ARRAY_COO))
            return SparseType::COO;
        else
            return SparseType::LIL;
    }

    FORCEINLINE _CUDA_HD void ArrayOptions::setPropertyBits(Nd4jLong *shapeInfo, std::initializer_list<int> properties) {
        for (auto v: properties) {
            if (!hasPropertyBitSet(shapeInfo, v))
                setPropertyBit(shapeInfo, v);
        }
    }

    FORCEINLINE _CUDA_HD void ArrayOptions::resetDataType(Nd4jLong *shapeInfo) {
        unsetPropertyBit(shapeInfo, ARRAY_BOOL);
        unsetPropertyBit(shapeInfo, ARRAY_HALF);
        unsetPropertyBit(shapeInfo, ARRAY_BHALF);
        unsetPropertyBit(shapeInfo, ARRAY_FLOAT);
        unsetPropertyBit(shapeInfo, ARRAY_DOUBLE);
        unsetPropertyBit(shapeInfo, ARRAY_INT);
        unsetPropertyBit(shapeInfo, ARRAY_LONG);
        unsetPropertyBit(shapeInfo, ARRAY_CHAR);
        unsetPropertyBit(shapeInfo, ARRAY_SHORT);
        unsetPropertyBit(shapeInfo, ARRAY_UNSIGNED);
        // string types are data types too; clear them so that a later
        // setDataType() cannot leave two type bits set at once
        unsetPropertyBit(shapeInfo, ARRAY_UTF8);
        unsetPropertyBit(shapeInfo, ARRAY_UTF16);
        unsetPropertyBit(shapeInfo, ARRAY_UTF32);
    }

    FORCEINLINE _CUDA_HD void ArrayOptions::setDataType(Nd4jLong *shapeInfo, const nd4j::DataType dataType) {
        resetDataType(shapeInfo);
        if (dataType == nd4j::DataType::UINT8 ||
                dataType == nd4j::DataType::UINT16 ||
                dataType == nd4j::DataType::UINT32 ||
                dataType == nd4j::DataType::UINT64) {

            setPropertyBit(shapeInfo, ARRAY_UNSIGNED);
        }

        switch (dataType) {
            case nd4j::DataType::BOOL:
                setPropertyBit(shapeInfo, ARRAY_BOOL);
                break;
            case nd4j::DataType::HALF:
                setPropertyBit(shapeInfo, ARRAY_HALF);
                break;
            case nd4j::DataType::BFLOAT16:
                setPropertyBit(shapeInfo, ARRAY_BHALF);
                break;
            case nd4j::DataType::FLOAT32:
                setPropertyBit(shapeInfo, ARRAY_FLOAT);
                break;
            case nd4j::DataType::DOUBLE:
                setPropertyBit(shapeInfo, ARRAY_DOUBLE);
                break;
            case nd4j::DataType::INT8:
                setPropertyBit(shapeInfo, ARRAY_CHAR);
                break;
            case nd4j::DataType::INT16:
                setPropertyBit(shapeInfo, ARRAY_SHORT);
                break;
            case nd4j::DataType::INT32:
                setPropertyBit(shapeInfo, ARRAY_INT);
                break;
            case nd4j::DataType::INT64:
                setPropertyBit(shapeInfo, ARRAY_LONG);
                break;
            case nd4j::DataType::UINT8:
                setPropertyBit(shapeInfo, ARRAY_CHAR);
                break;
            case nd4j::DataType::UINT16:
                setPropertyBit(shapeInfo, ARRAY_SHORT);
                break;
            case nd4j::DataType::UINT32:
                setPropertyBit(shapeInfo, ARRAY_INT);
                break;
            case nd4j::DataType::UINT64:
                setPropertyBit(shapeInfo, ARRAY_LONG);
                break;
            case nd4j::DataType::UTF8:
                setPropertyBit(shapeInfo, ARRAY_UTF8);
                break;
            case nd4j::DataType::UTF16:
                setPropertyBit(shapeInfo, ARRAY_UTF16);
                break;
            case nd4j::DataType::UTF32:
                setPropertyBit(shapeInfo, ARRAY_UTF32);
                break;
            default:
#ifndef __CUDA_ARCH__
                throw std::runtime_error("Can't set unknown data type");
#else
                printf("Can't set unknown data type");
#endif
        }
    }

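    /*
     * Encoding sketch (illustrative only, not part of the original header).
     * setDataType() and dataType() store integer types as one width bit
     * (CHAR, SHORT, INT or LONG) plus the UNSIGNED bit for the UINT* variants,
     * so signed and unsigned types of the same width share a bit. The reduced
     * enum `IntKind` and the helpers `encode` and `decode` are hypothetical;
     * only the numeric flag values come from this header.
     *
```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>

using Long = std::int64_t;  // stand-in for Nd4jLong

// Hypothetical reduced enum; the real nd4j::DataType has more members.
enum class IntKind { INT8, INT16, INT32, INT64, UINT8, UINT16, UINT32, UINT64 };

// Values copied from the ARRAY_* defines in this header.
constexpr Long kChar = 32768, kShort = 65536, kInt = 131072, kLong = 262144;
constexpr Long kUnsigned = 8388608;

// Encode: one width bit, plus the unsigned bit for the UINT* kinds.
inline Long encode(IntKind k) {
    switch (k) {
        case IntKind::INT8:   return kChar;
        case IntKind::INT16:  return kShort;
        case IntKind::INT32:  return kInt;
        case IntKind::INT64:  return kLong;
        case IntKind::UINT8:  return kUnsigned | kChar;
        case IntKind::UINT16: return kUnsigned | kShort;
        case IntKind::UINT32: return kUnsigned | kInt;
        case IntKind::UINT64: return kUnsigned | kLong;
    }
    throw std::invalid_argument("unknown kind");
}

// Decode: read the unsigned bit, then pick the type by width bit.
inline IntKind decode(Long flags) {
    const bool isUnsigned = (flags & kUnsigned) != 0;
    if (flags & kChar)  return isUnsigned ? IntKind::UINT8  : IntKind::INT8;
    if (flags & kShort) return isUnsigned ? IntKind::UINT16 : IntKind::INT16;
    if (flags & kInt)   return isUnsigned ? IntKind::UINT32 : IntKind::INT32;
    if (flags & kLong)  return isUnsigned ? IntKind::UINT64 : IntKind::INT64;
    throw std::runtime_error("no integer width bit set");
}
```
     *
     * A round trip such as decode(encode(IntKind::UINT32)) should give back
     * the kind it started from, which is the property copyDataType() relies
     * on when it chains dataType() into setDataType().
     */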
////////////////////////////////////////////////////////////////////////////////
    FORCEINLINE _CUDA_HD void ArrayOptions::copyDataType(Nd4jLong* to, const Nd4jLong* from) {
        setDataType(to, dataType(from));
    }
}

#endif // ND4J_ARRAY_OPTIONS_H :)