diff --git a/libnd4j/CMakeLists.txt b/libnd4j/CMakeLists.txt index cf9d4ff88..e54c3ebe6 100755 --- a/libnd4j/CMakeLists.txt +++ b/libnd4j/CMakeLists.txt @@ -71,8 +71,7 @@ if(NOT CUDA_BLAS) # there's a chance, we have no BLAS provided externally if ("${OPENBLAS_PATH}" STREQUAL "") - #we don't want static OpenBLAS on Apple - set(BLA_STATIC ON) + #we don't want OpenBLAS on Apple if (NOT APPLE) set(BLA_VENDOR "OpenBLAS") endif() @@ -80,23 +79,8 @@ if(NOT CUDA_BLAS) # look around for system blas instead find_package(BLAS REQUIRED) if (BLAS_FOUND) - message("Original library: ${BLAS_LIBRARIES}") - # workaround for for cmake being unable to find static blas library - SET(_TMP_B "") - if (APPLE) - string(REGEX REPLACE "\\.dylib$" ".lib" _TMP_B "${BLAS_LIBRARIES}") - elseif (WIN32) - string(REGEX REPLACE "\\.dll" ".lib" _TMP_B "${BLAS_LIBRARIES}") - else() - string(REGEX REPLACE "\\.so$" ".a" _TMP_B "${BLAS_LIBRARIES}") - endif() - set(BLAS_LIBRARIES "${_TMP_B}") - message("Found external BLAS implementation: ${BLAS_LIBRARIES} ") add_definitions(-D__EXTERNAL_BLAS__=true) - elseif(WIN32) - message("BLAS not found, using downloaded OpenBLAS instead") - add_definitions(-D__EXTERNAL_BLAS__=true) endif() else() # if we have externally provided OPENBLAS_PATH - let's use it diff --git a/libnd4j/include/array/ConstantDescriptor.h b/libnd4j/include/array/ConstantDescriptor.h index f32c1c8bf..f05f98dac 100644 --- a/libnd4j/include/array/ConstantDescriptor.h +++ b/libnd4j/include/array/ConstantDescriptor.h @@ -59,5 +59,17 @@ namespace nd4j { }; } +#ifndef __JAVACPP_HACK__ + +namespace std { + template<> + class ND4J_EXPORT hash { + public: + size_t operator()(const nd4j::ConstantDescriptor &k) const; + }; +} + +#endif + #endif //DEV_TESTS_CONSTANTDESCRIPTOR_H diff --git a/libnd4j/include/array/NDArrayList.h b/libnd4j/include/array/NDArrayList.h index 843b69a91..7d6bc12b1 100644 --- a/libnd4j/include/array/NDArrayList.h +++ b/libnd4j/include/array/NDArrayList.h @@ -44,7 +44,7 @@ namespace nd4j { nd4j::DataType _dtype; // stored chunks - std::map _chunks; + MAP_IMPL _chunks; // just a counter, for stored elements std::atomic _elements; diff --git a/libnd4j/include/array/ShapeDescriptor.h b/libnd4j/include/array/ShapeDescriptor.h index ddfd45a38..4eeaf66b9 100644 --- a/libnd4j/include/array/ShapeDescriptor.h +++ b/libnd4j/include/array/ShapeDescriptor.h @@ -85,9 +85,19 @@ class ND4J_EXPORT ShapeDescriptor { static ShapeDescriptor scalarDescriptor(const DataType type); static ShapeDescriptor vectorDescriptor(const Nd4jLong length, const DataType type); }; - - } +#ifndef __JAVACPP_HACK__ + +namespace std { + template<> + class ND4J_EXPORT hash { + public: + size_t operator()(const nd4j::ShapeDescriptor &k) const; + }; +} + +#endif + #endif //DEV_TESTS_SHAPEDESCRIPTOR_H diff --git a/libnd4j/include/array/TadDescriptor.h b/libnd4j/include/array/TadDescriptor.h index 3943a4689..ab05cbfb7 100644 --- a/libnd4j/include/array/TadDescriptor.h +++ b/libnd4j/include/array/TadDescriptor.h @@ -53,9 +53,22 @@ namespace nd4j { std::vector& axis(); ShapeDescriptor& originalShape(); + ShapeDescriptor const& originalShapeConst() const; bool areUnitiesinShape() const; }; } +#ifndef __JAVACPP_HACK__ + +namespace std { + template<> + class ND4J_EXPORT hash { + public: + size_t operator()(const nd4j::TadDescriptor &k) const; + }; +} + +#endif + #endif //DEV_TESTS_TADDESCRIPTOR_H diff --git a/libnd4j/include/array/impl/ConstantDescriptor.cpp b/libnd4j/include/array/impl/ConstantDescriptor.cpp index b64523096..d53ef0adc 100644 --- a/libnd4j/include/array/impl/ConstantDescriptor.cpp +++ b/libnd4j/include/array/impl/ConstantDescriptor.cpp @@ -75,3 +75,25 @@ namespace nd4j { return isInteger() ? _integerValues.size() : isFloat() ? _floatValues.size() : 0L; } } + +namespace std { + size_t hash::operator()(const nd4j::ConstantDescriptor &k) const { + using std::hash; + // Compute individual hash values for first, + // second and third and combine them using XOR + // and bit shifting: + size_t hashVal = 0; + size_t i = 0; + if (k.isInteger()) { + for (auto v: k.integerValues()) { + hashVal ^= std::hash()(v) + 0x9e3779b9 + (hashVal << 6) + (hashVal >> 2); + } + } + else { + for (auto v: k.floatValues()) { + hashVal ^= std::hash()(v) + 0x9e3779b9 + (hashVal << 6) + (hashVal >> 2); + } + } + return hashVal; + } +} diff --git a/libnd4j/include/array/impl/ShapeDescriptor.cpp b/libnd4j/include/array/impl/ShapeDescriptor.cpp index 356177163..3891fcbb8 100644 --- a/libnd4j/include/array/impl/ShapeDescriptor.cpp +++ b/libnd4j/include/array/impl/ShapeDescriptor.cpp @@ -22,127 +22,89 @@ #include #include -using namespace nd4j; +namespace nd4j { ////////////////////////////////////////////////////////////////////////// // equal to operator -bool ShapeDescriptor::operator==(const ShapeDescriptor& other) const { + bool ShapeDescriptor::operator==(const ShapeDescriptor &other) const { - if(_empty != other._empty) - return false; - if(_rank != other._rank) - return false; - if(_order != other._order) - return false; - if(_dataType != other._dataType) - return false; - if(_ews != other._ews) - return false; + if (_empty != other._empty) + return false; + if (_rank != other._rank) + return false; + if (_order != other._order) + return false; + if (_dataType != other._dataType) + return false; + if (_ews != other._ews) + return false; - if(_shape != other._shape) - return false; + if (_shape != other._shape) + return false; - if(_strides != other._strides) - return false; + if (_strides != other._strides) + return false; - return true; -} + return true; + } ////////////////////////////////////////////////////////////////////////// // less than operator -bool ShapeDescriptor::operator<(const ShapeDescriptor& other) const { - return std::tie(_empty, _rank, _dataType, _ews, _order, _shape, _strides) < std::tie(other._empty, other._rank, other._dataType, other._ews, other._order, other._shape, other._strides); -} + bool ShapeDescriptor::operator<(const ShapeDescriptor &other) const { + return std::tie(_empty, _rank, _dataType, _ews, _order, _shape, _strides) < + std::tie(other._empty, other._rank, other._dataType, other._ews, other._order, other._shape, + other._strides); + } -Nd4jLong* ShapeDescriptor::toShapeInfo() const { - if (_empty) { - if (_rank == 0) - return ShapeBuilders::emptyShapeInfo(_dataType); - else { - return ShapeBuilders::emptyShapeInfo(_dataType, _order, _shape); + Nd4jLong *ShapeDescriptor::toShapeInfo() const { + if (_empty) { + if (_rank == 0) + return ShapeBuilders::emptyShapeInfo(_dataType); + else { + return ShapeBuilders::emptyShapeInfo(_dataType, _order, _shape); + } + } + + + switch (_rank) { + case 0: { + auto shapeInfo = ShapeBuilders::createScalarShapeInfo(_dataType); + shapeInfo[2] = _ews; + return shapeInfo; + } + case 1: { + auto shapeInfo = ShapeBuilders::createVectorShapeInfo(_dataType, _shape[0]); + shapeInfo[2 + _rank * 2] = _ews; + shapeInfo[2] = _strides[0]; + shapeInfo[2 + _rank * 2 + 1] = _order; + return shapeInfo; + } + default: { + auto shapeInfo = ShapeBuilders::createShapeInfo(_dataType, _order, _shape); + + for (int e = 0; e < _rank; e++) + shapeInfo[e + 1 + _rank] = _strides[e]; + + shapeInfo[2 + _rank * 2] = _ews; + + return shapeInfo; + } } } + ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const Nd4jLong *shape, const int rank) + : _dataType(type), _order(order), _rank(rank), _ews(1) { + _shape.resize(rank); + _strides.resize(rank); - switch (_rank) { - case 0: { - auto shapeInfo = ShapeBuilders::createScalarShapeInfo(_dataType); - shapeInfo[2] = _ews; - return shapeInfo; - } - case 1: { - auto shapeInfo = ShapeBuilders::createVectorShapeInfo(_dataType, _shape[0]); - shapeInfo[2 + _rank * 2] = _ews; - shapeInfo[2] = _strides[0]; - shapeInfo[2 + _rank * 2 + 1] = _order; - return shapeInfo; - } - default: { - auto shapeInfo = ShapeBuilders::createShapeInfo(_dataType, _order, _shape); + for (int e = 0; e < rank; e++) + _shape[e] = shape[e]; - for (int e = 0; e < _rank; e++) - shapeInfo[e + 1 + _rank] = _strides[e]; + if (order == 'c') + shape::calcStrides(_shape.data(), _shape.size(), _strides.data()); + else + shape::calcStridesFortran(_shape.data(), _shape.size(), _strides.data()); - shapeInfo[2 + _rank * 2] = _ews; - - return shapeInfo; - } - } -} - -ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const Nd4jLong *shape, const int rank) : _dataType(type), _order(order), _rank(rank), _ews(1){ - _shape.resize(rank); - _strides.resize(rank); - - for (int e = 0; e < rank; e++) - _shape[e] = shape[e]; - - if (order == 'c') - shape::calcStrides(_shape.data(), _shape.size(), _strides.data()); - else - shape::calcStridesFortran(_shape.data(), _shape.size(), _strides.data()); - - - for (auto v:_shape) { - if (v == 0) { - _empty = true; - break; - } - } -} - -ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const Nd4jLong *shape, const Nd4jLong *strides, const int rank, Nd4jLong ews, const bool empty) { - _shape.resize(rank); - _strides.resize(rank); - - _dataType = type; - _order = order; - _rank = rank; - _empty = empty; - _ews = ews; - - for (int e = 0; e < rank; e++) - _shape[e] = shape[e]; - - for (int e = 0; e < rank; e++) - _strides[e] = strides[e]; - - - for (auto v:_shape) { - if (v == 0) { - _empty = true; - break; - } - } -} - -////////////////////////////////////////////////////////////////////////// -ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const std::vector &shape): _dataType(type), _order(order), _shape(shape) { - _rank = shape.size(); - _ews = 1; - - if (_rank > 0) { - _strides.resize(_rank); for (auto v:_shape) { if (v == 0) { @@ -150,188 +112,269 @@ ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const st break; } } + } - // no point calculating strides for empty arrays - if (!_empty) { - if (order == 'c') - shape::calcStrides(_shape.data(), shape.size(), _strides.data()); - else - shape::calcStridesFortran(_shape.data(), shape.size(), _strides.data()); - } else { - // all strides set to 0 - memset(_strides.data(), 0, sizeof(Nd4jLong) * shape.size()); + ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const Nd4jLong *shape, + const Nd4jLong *strides, const int rank, Nd4jLong ews, const bool empty) { + _shape.resize(rank); + _strides.resize(rank); + + _dataType = type; + _order = order; + _rank = rank; + _empty = empty; + _ews = ews; + + for (int e = 0; e < rank; e++) + _shape[e] = shape[e]; + + for (int e = 0; e < rank; e++) + _strides[e] = strides[e]; + + + for (auto v:_shape) { + if (v == 0) { + _empty = true; + break; + } } } -} ////////////////////////////////////////////////////////////////////////// -ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const std::initializer_list &shape): _dataType(type), _order(order), _shape(shape) { - _rank = shape.size(); - _ews = 1; + ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const std::vector &shape) + : _dataType(type), _order(order), _shape(shape) { + _rank = shape.size(); + _ews = 1; - _strides.resize(shape.size()); - if (order == 'c') - shape::calcStrides(_shape.data(), shape.size(), _strides.data()); - else - shape::calcStridesFortran(_shape.data(), shape.size(), _strides.data()); + if (_rank > 0) { + _strides.resize(_rank); - for (auto v:_shape) { - if (v == 0) { - _empty = true; - break; + for (auto v:_shape) { + if (v == 0) { + _empty = true; + break; + } + } + + // no point calculating strides for empty arrays + if (!_empty) { + if (order == 'c') + shape::calcStrides(_shape.data(), shape.size(), _strides.data()); + else + shape::calcStridesFortran(_shape.data(), shape.size(), _strides.data()); + } else { + // all strides set to 0 + memset(_strides.data(), 0, sizeof(Nd4jLong) * shape.size()); + } } } -} ////////////////////////////////////////////////////////////////////////// -ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const std::vector &shape, const std::vector &strides, const Nd4jLong ews): ShapeDescriptor(type, order, shape, strides) { - _ews = ews; -} + ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, + const std::initializer_list &shape) : _dataType(type), _order(order), + _shape(shape) { + _rank = shape.size(); + _ews = 1; -ShapeDescriptor::ShapeDescriptor(const DataType type, const Nd4jLong length) : _dataType(type), _ews(1), _order('c'), _rank(1), _empty(false) { - _shape = {length}; - _strides = {1}; -} - -ShapeDescriptor::ShapeDescriptor(const Nd4jLong *shapeInfo, bool inheritDtype) { - _order = shape::order(shapeInfo); - _ews = shape::elementWiseStride(shapeInfo); - _rank = shape::rank(shapeInfo); - - if (inheritDtype) - _dataType = ArrayOptions::dataType(shapeInfo); - - _empty = shape::isEmpty(shapeInfo); - - for (int e = 0; e < _rank; e++) { - _shape.emplace_back(shapeInfo[e + 1]); - if (shapeInfo[e + 1] == 0) - _empty = true; - } - - for (int e = 0; e < _rank; e++) - _strides.emplace_back(shapeInfo[e + 1 + _rank]); -} - -ShapeDescriptor::ShapeDescriptor(const Nd4jLong *shapeInfo, const nd4j::DataType dtypeOverride) : ShapeDescriptor::ShapeDescriptor(shapeInfo, false) { - _dataType = dtypeOverride; -} - -ShapeDescriptor::ShapeDescriptor(const Nd4jLong *shapeInfo, const Nd4jLong *dtypeOverride) : ShapeDescriptor::ShapeDescriptor(shapeInfo, ArrayOptions::dataType(dtypeOverride)) { - // -} - -ShapeDescriptor::ShapeDescriptor(const Nd4jLong *shapeInfo, const Nd4jLong *dtypeOverride, const Nd4jLong *orderOverride) : ShapeDescriptor::ShapeDescriptor(shapeInfo, ArrayOptions::dataType(dtypeOverride)) { - _order = shape::order(orderOverride); -} - -int ShapeDescriptor::rank() const { - return _rank; -} - -Nd4jLong ShapeDescriptor::ews() const { - return _ews; -} - -Nd4jLong ShapeDescriptor::arrLength() const { - - Nd4jLong len = 1; - for(const auto& dim : const_cast(this)->shape()) - len *= dim; - return len; -} - -char ShapeDescriptor::order() const { - return _order; -} - -DataType ShapeDescriptor::dataType() const { - return _dataType; -} - -bool ShapeDescriptor::isEmpty() const { - return _empty; -} -std::vector& ShapeDescriptor::shape() { - return _shape; -} - -std::vector& ShapeDescriptor::strides() { - return _strides; -} - -ShapeDescriptor::ShapeDescriptor(const ShapeDescriptor &other) { - _rank = other._rank; - _ews = other._ews; - _empty = other._empty; - _dataType = other._dataType; - _order = other._order; - _shape = other._shape; - _strides = other._strides; -} - -////////////////////////////////////////////////////////////////////////// -ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const std::vector &shape, const std::vector &strides): _dataType(type), _order(order), _shape(shape) { - - if (strides.empty() && !shape.empty()) { _strides.resize(shape.size()); if (order == 'c') shape::calcStrides(_shape.data(), shape.size(), _strides.data()); else shape::calcStridesFortran(_shape.data(), shape.size(), _strides.data()); - } - else { - _strides = strides; - } - - for (auto v:_shape) { - if (v == 0) { - _empty = true; - break; + for (auto v:_shape) { + if (v == 0) { + _empty = true; + break; + } } } -} -ShapeDescriptor ShapeDescriptor::emptyDescriptor(const DataType type) { - ShapeDescriptor descriptor; - descriptor._dataType = type; - descriptor._empty = true; - descriptor._rank = 0; - descriptor._order = 'c'; - descriptor._ews = 1; - - return descriptor; -} - -ShapeDescriptor ShapeDescriptor::scalarDescriptor(const DataType type) { - ShapeDescriptor descriptor; - descriptor._dataType = type; - descriptor._empty = false; - descriptor._rank = 0; - descriptor._order = 'c'; - descriptor._ews = 1; - - return descriptor; -} - -ShapeDescriptor ShapeDescriptor::vectorDescriptor(const Nd4jLong length, const DataType type) { - ShapeDescriptor descriptor; - descriptor._dataType = type; - descriptor._shape.emplace_back(length); - - if (length > 0) - descriptor._strides.emplace_back(1); - else { - descriptor._strides.emplace_back(0); - descriptor._empty = true; +////////////////////////////////////////////////////////////////////////// + ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const std::vector &shape, + const std::vector &strides, const Nd4jLong ews) : ShapeDescriptor(type, + order, + shape, + strides) { + _ews = ews; } - descriptor._order = 'c'; - descriptor._ews = 1; - descriptor._rank = 1; + ShapeDescriptor::ShapeDescriptor(const DataType type, const Nd4jLong length) : _dataType(type), _ews(1), + _order('c'), _rank(1), + _empty(false) { + _shape = {length}; + _strides = {1}; + } - return descriptor; + ShapeDescriptor::ShapeDescriptor(const Nd4jLong *shapeInfo, bool inheritDtype) { + _order = shape::order(shapeInfo); + _ews = shape::elementWiseStride(shapeInfo); + _rank = shape::rank(shapeInfo); + + if (inheritDtype) + _dataType = ArrayOptions::dataType(shapeInfo); + + _empty = shape::isEmpty(shapeInfo); + + for (int e = 0; e < _rank; e++) { + _shape.emplace_back(shapeInfo[e + 1]); + if (shapeInfo[e + 1] == 0) + _empty = true; + } + + for (int e = 0; e < _rank; e++) + _strides.emplace_back(shapeInfo[e + 1 + _rank]); + } + + ShapeDescriptor::ShapeDescriptor(const Nd4jLong *shapeInfo, const nd4j::DataType dtypeOverride) + : ShapeDescriptor::ShapeDescriptor(shapeInfo, false) { + _dataType = dtypeOverride; + } + + ShapeDescriptor::ShapeDescriptor(const Nd4jLong *shapeInfo, const Nd4jLong *dtypeOverride) + : ShapeDescriptor::ShapeDescriptor(shapeInfo, ArrayOptions::dataType(dtypeOverride)) { + // + } + + ShapeDescriptor::ShapeDescriptor(const Nd4jLong *shapeInfo, const Nd4jLong *dtypeOverride, + const Nd4jLong *orderOverride) : ShapeDescriptor::ShapeDescriptor(shapeInfo, + ArrayOptions::dataType( + dtypeOverride)) { + _order = shape::order(orderOverride); + } + + int ShapeDescriptor::rank() const { + return _rank; + } + + Nd4jLong ShapeDescriptor::ews() const { + return _ews; + } + + Nd4jLong ShapeDescriptor::arrLength() const { + + Nd4jLong len = 1; + for (const auto &dim : const_cast(this)->shape()) + len *= dim; + return len; + } + + char ShapeDescriptor::order() const { + return _order; + } + + DataType ShapeDescriptor::dataType() const { + return _dataType; + } + + bool ShapeDescriptor::isEmpty() const { + return _empty; + } + + std::vector &ShapeDescriptor::shape() { + return _shape; + } + + std::vector &ShapeDescriptor::strides() { + return _strides; + } + + ShapeDescriptor::ShapeDescriptor(const ShapeDescriptor &other) { + _rank = other._rank; + _ews = other._ews; + _empty = other._empty; + _dataType = other._dataType; + _order = other._order; + _shape = other._shape; + _strides = other._strides; + } + +////////////////////////////////////////////////////////////////////////// + ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const std::vector &shape, + const std::vector &strides) : _dataType(type), _order(order), + _shape(shape) { + + if (strides.empty() && !shape.empty()) { + _strides.resize(shape.size()); + if (order == 'c') + shape::calcStrides(_shape.data(), shape.size(), _strides.data()); + else + shape::calcStridesFortran(_shape.data(), shape.size(), _strides.data()); + } else { + _strides = strides; + } + + + for (auto v:_shape) { + if (v == 0) { + _empty = true; + break; + } + } + } + + ShapeDescriptor ShapeDescriptor::emptyDescriptor(const DataType type) { + ShapeDescriptor descriptor; + descriptor._dataType = type; + descriptor._empty = true; + descriptor._rank = 0; + descriptor._order = 'c'; + descriptor._ews = 1; + + return descriptor; + } + + ShapeDescriptor ShapeDescriptor::scalarDescriptor(const DataType type) { + ShapeDescriptor descriptor; + descriptor._dataType = type; + descriptor._empty = false; + descriptor._rank = 0; + descriptor._order = 'c'; + descriptor._ews = 1; + + return descriptor; + } + + ShapeDescriptor ShapeDescriptor::vectorDescriptor(const Nd4jLong length, const DataType type) { + ShapeDescriptor descriptor; + descriptor._dataType = type; + descriptor._shape.emplace_back(length); + + if (length > 0) + descriptor._strides.emplace_back(1); + else { + descriptor._strides.emplace_back(0); + descriptor._empty = true; + } + + descriptor._order = 'c'; + descriptor._ews = 1; + descriptor._rank = 1; + + return descriptor; + } +} + +namespace std { + size_t hash::operator()(const nd4j::ShapeDescriptor &k) const { + auto res = std::hash()(k.arrLength()); + res ^= std::hash()(k.order()) + 0x9e3779b9 + (res << 6) + (res >> 2); + res ^= k.dataType() + 0x9e3779b9 + (res << 6) + (res >> 2); + res ^= std::hash()(k.rank()) + 0x9e3779b9 + (res << 6) + (res >> 2); + res ^= std::hash()(k.ews()) + 0x9e3779b9 + (res << 6) + (res >> 2); + auto shapes = const_cast(k).shape(); + auto strides = const_cast(k).strides(); + for (auto s: shapes) { + res ^= std::hash()(s) + 0x9e3779b9 + (res << 6) + (res >> 2); + } + + for (auto s: strides) { + res ^= std::hash()(s) + 0x9e3779b9 + (res << 6) + (res >> 2); + } + + return res; + } } + diff --git a/libnd4j/include/array/impl/TadDescriptor.cpp b/libnd4j/include/array/impl/TadDescriptor.cpp index b6c8ba69d..a5043bb7c 100644 --- a/libnd4j/include/array/impl/TadDescriptor.cpp +++ b/libnd4j/include/array/impl/TadDescriptor.cpp @@ -65,11 +65,30 @@ namespace nd4j { return _axis; } - ShapeDescriptor& TadDescriptor::originalShape() { + ShapeDescriptor& TadDescriptor::originalShape(){ + return _originalShape; + } + + ShapeDescriptor const& TadDescriptor::originalShapeConst() const{ return _originalShape; } bool TadDescriptor::areUnitiesinShape() const { return _unitiesInShape; } +} + +namespace std { + size_t hash::operator()(const nd4j::TadDescriptor &k) const { + // Compute individual hash values for first, + // second and third and combine them using XOR + // and bit shifting: + auto res = std::hash()((int)k.areUnitiesinShape()); + res ^= std::hash()(k.originalShapeConst()) + 0x9e3779b9 + (res << 6) + (res >> 2); + auto axes = const_cast(k).axis(); + for (auto a: axes) { + res ^= std::hash()(a) + 0x9e3779b9 + (res << 6) + (res >> 2); + } + return res; + } } \ No newline at end of file diff --git a/libnd4j/include/cnpy/cnpy.h b/libnd4j/include/cnpy/cnpy.h index 06ff3336d..d66320cae 100644 --- a/libnd4j/include/cnpy/cnpy.h +++ b/libnd4j/include/cnpy/cnpy.h @@ -38,7 +38,7 @@ #include #include #include -#include +#include #include #include #include @@ -69,7 +69,7 @@ namespace cnpy { } }; - struct ND4J_EXPORT npz_t : public std::map { + struct ND4J_EXPORT npz_t : public std::unordered_map { void destruct() { npz_t::iterator it = this->begin(); for(; it != this->end(); ++it) (*it).second.destruct(); diff --git a/libnd4j/include/graph/ExecutionResult.h b/libnd4j/include/graph/ExecutionResult.h index b1a1b1737..850974943 100644 --- a/libnd4j/include/graph/ExecutionResult.h +++ b/libnd4j/include/graph/ExecutionResult.h @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -33,8 +34,8 @@ namespace nd4j { class ExecutionResult { private: std::vector _variables; - std::map _stringIdMap; - std::map, Variable *> _pairIdMap; + MAP_IMPL _stringIdMap; + MAP_IMPL, Variable *> _pairIdMap; // this flag is used to optionally release variables bool _releasable = false; diff --git a/libnd4j/include/graph/FlowPath.h b/libnd4j/include/graph/FlowPath.h index fae19fd0b..3f72c695c 100644 --- a/libnd4j/include/graph/FlowPath.h +++ b/libnd4j/include/graph/FlowPath.h @@ -21,6 +21,8 @@ #ifndef LIBND4J_FLOWPATH_H #define LIBND4J_FLOWPATH_H +#include +#include #include #include #include @@ -32,8 +34,8 @@ namespace nd4j { namespace graph { class ND4J_EXPORT FlowPath { private: - std::map _states; - std::map _frames; + MAP_IMPL _states; + MAP_IMPL _frames; void ensureNode(int nodeId); void ensureFrame(int nodeId); diff --git a/libnd4j/include/graph/Graph.h b/libnd4j/include/graph/Graph.h index 00efb3c52..5145ba9b0 100644 --- a/libnd4j/include/graph/Graph.h +++ b/libnd4j/include/graph/Graph.h @@ -24,6 +24,7 @@ #include #include #include +#include //#include #include #include @@ -50,10 +51,10 @@ namespace nd4j { // vector holds ID's of top nodes only std::vector *_nodes; - std::map *_mapped; + MAP_IMPL *_mapped; - std::map *> *_onion; - std::map _unmapped; + MAP_IMPL *> *_onion; + MAP_IMPL _unmapped; std::vector _unmappedMap; // macOS? std::mutex _mutexPreprocessing; @@ -63,7 +64,7 @@ namespace nd4j { std::vector _autos; - std::map _mappedScopes; + MAP_IMPL _mappedScopes; std::vector _scopes; //////////////////////////////////////// @@ -124,13 +125,13 @@ namespace nd4j { * * @return */ - std::map *> *getOnion(); + MAP_IMPL *> *getOnion(); /** * This method returns map of all nodes of the graph * @return */ - std::map *getMapped(); + MAP_IMPL* getMapped(); /** * This method returns outputs of this graph @@ -233,7 +234,7 @@ namespace nd4j { return &_output; } - FORCEINLINE std::map* scopes() { + FORCEINLINE MAP_IMPL* scopes() { return &_mappedScopes; } diff --git a/libnd4j/include/graph/GraphHolder.h b/libnd4j/include/graph/GraphHolder.h index f740ad4ca..3465d182e 100644 --- a/libnd4j/include/graph/GraphHolder.h +++ b/libnd4j/include/graph/GraphHolder.h @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -30,9 +31,9 @@ namespace nd4j { class ND4J_EXPORT GraphHolder { private: static GraphHolder *_INSTANCE; - std::map _graphF; + MAP_IMPL _graphF; - std::map _locks; + MAP_IMPL _locks; GraphHolder() = default; ~GraphHolder() = default; diff --git a/libnd4j/include/graph/GraphState.h b/libnd4j/include/graph/GraphState.h index 52c6f9e16..6fc553a09 100644 --- a/libnd4j/include/graph/GraphState.h +++ b/libnd4j/include/graph/GraphState.h @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -43,7 +44,7 @@ namespace graph { Nd4jLong _id = 0; // map of scopes. Scope id is used as key, since it's referred in calls later anyway - std::map _scopes; + MAP_IMPL _scopes; // this variable space holds temp references VariableSpace _variableSpace; diff --git a/libnd4j/include/graph/SessionLocalStorage.h b/libnd4j/include/graph/SessionLocalStorage.h index dd8051fc2..5d6299938 100644 --- a/libnd4j/include/graph/SessionLocalStorage.h +++ b/libnd4j/include/graph/SessionLocalStorage.h @@ -22,6 +22,8 @@ #define LIBND4J_SESSIONLOCALSTORAGE_H #include +#include +#include #include "VariableSpace.h" #include "Context.h" #include "Stash.h" @@ -32,8 +34,8 @@ namespace nd4j{ class ND4J_EXPORT SessionLocalStorage { protected: std::atomic _sessionCounter; - std::map _threadSession; - std::map _threadVariableSpace; + MAP_IMPL _threadSession; + MAP_IMPL _threadVariableSpace; VariableSpace* _variableSpace; Stash* _stash; diff --git a/libnd4j/include/graph/Stash.h b/libnd4j/include/graph/Stash.h index 83a7ec066..b44396819 100644 --- a/libnd4j/include/graph/Stash.h +++ b/libnd4j/include/graph/Stash.h @@ -23,9 +23,11 @@ //#include #include -#include +#include +#include #include #include +#include #include namespace nd4j { @@ -34,11 +36,34 @@ namespace nd4j { int _node; std::string _name; public: - KeyPair(int node = 0, const char * name = nullptr); + KeyPair(int node = 0, const char *name = nullptr); - bool operator<(const KeyPair& other) const; + bool operator<(const KeyPair &other) const; + + bool operator==(const KeyPair &other) const { + return _node == other._node; + } + + int key() const { return _node; } + std::string name() const { return _name; } }; + } +} +#ifndef __JAVACPP_HACK__ + +namespace std { + template <> + class ND4J_EXPORT hash { + public: + size_t operator()(const nd4j::graph::KeyPair& k) const; + }; +}; + +#endif + +namespace nd4j { + namespace graph { class ND4J_EXPORT Stash { protected: std::map _stash; @@ -60,6 +85,7 @@ namespace nd4j { void clear(); }; } + } diff --git a/libnd4j/include/graph/Variable.h b/libnd4j/include/graph/Variable.h index 60f977e97..76ce62fcf 100644 --- a/libnd4j/include/graph/Variable.h +++ b/libnd4j/include/graph/Variable.h @@ -29,6 +29,31 @@ #include #include +#ifndef __JAVACPP_HACK__ + +namespace std { + + template <> + class ND4J_EXPORT hash> { + public: + size_t operator()(const std::pair& k) const; + }; + + template <> + class ND4J_EXPORT hash { + public: + size_t operator()(const bfloat16& k) const; + }; + + template <> + class ND4J_EXPORT hash { + public: + size_t operator()(const float16& k) const; + }; +}; + +#endif + namespace nd4j { namespace graph { class ND4J_EXPORT Variable { diff --git a/libnd4j/include/graph/impl/Graph.cpp b/libnd4j/include/graph/impl/Graph.cpp index 4b337dd0d..3950d5f8a 100644 --- a/libnd4j/include/graph/impl/Graph.cpp +++ b/libnd4j/include/graph/impl/Graph.cpp @@ -51,7 +51,7 @@ namespace nd4j { Nd4jLong lastStep = 0L; std::vector shapes; - std::map, Nd4jLong*> shapesMap; + MAP_IMPL, Nd4jLong*> shapesMap; int cntFD = 0; @@ -249,11 +249,11 @@ namespace nd4j { return res; } - std::map * Graph::getMapped() { + MAP_IMPL * Graph::getMapped() { return _mapped; } - std::map *>* Graph::getOnion() { + MAP_IMPL *>* Graph::getOnion() { return _onion; } @@ -518,7 +518,7 @@ namespace nd4j { return ND4J_STATUS_OK; } - typename std::map::iterator fit; + typename MAP_IMPL::iterator fit; int cnts = 0; for ( fit = _unmapped.begin(); fit != _unmapped.end(); fit++ ) { int tK = fit->first; @@ -535,7 +535,7 @@ namespace nd4j { std::vector queue; // first pass for unmapped nodes, we try to build tale here - typename std::map::iterator it; + typename MAP_IMPL::iterator it; int cntf = 0; nd4j_debug("-----------\n",""); for ( it = _unmapped.begin(); it != _unmapped.end(); it++ ) { @@ -866,8 +866,8 @@ namespace nd4j { } Graph::Graph(const FlatGraph *flatGraph, VariableSpace *variableSpace) { - this->_onion = new std::map *>(); - this->_mapped = new std::map (); + this->_onion = new MAP_IMPL *>(); + this->_mapped = new MAP_IMPL (); this->_nodes = new std::vector(); this->_variableSpace = variableSpace == nullptr ? new VariableSpace() : variableSpace; bool trusted = flatGraph != nullptr; diff --git a/libnd4j/include/graph/impl/Stash.cpp b/libnd4j/include/graph/impl/Stash.cpp index 5a48062f7..c6a573605 100644 --- a/libnd4j/include/graph/impl/Stash.cpp +++ b/libnd4j/include/graph/impl/Stash.cpp @@ -20,6 +20,16 @@ #include + +namespace std { + size_t hash::operator()(const nd4j::graph::KeyPair& k) const { + using std::hash; + auto res = std::hash()(k.name()); + res ^= std::hash()(k.key()) + 0x9e3779b9 + (res << 6) + (res >> 2); + return res; + } +} + namespace nd4j { namespace graph { nd4j::graph::KeyPair::KeyPair(int node, const char * name) { diff --git a/libnd4j/include/graph/impl/Variable.cpp b/libnd4j/include/graph/impl/Variable.cpp index c2c5ff61f..9ff7fbf37 100644 --- a/libnd4j/include/graph/impl/Variable.cpp +++ b/libnd4j/include/graph/impl/Variable.cpp @@ -340,4 +340,21 @@ namespace nd4j { } } } +} + +namespace std { + + size_t hash>::operator()(const std::pair& k) const { + auto v = std::hash()(k.first); + v ^= std::hash()(k.second) + 0x9e3779b9 + (v << 6) + (v >> 2); + return v; + } + + size_t hash::operator()(const bfloat16& k) const { + return std::hash()((float)k); + } + + size_t hash::operator()(const float16& k) const { + return std::hash()((float)k); + } } \ No newline at end of file diff --git a/libnd4j/include/helpers/ConstantHelper.h b/libnd4j/include/helpers/ConstantHelper.h index 6aad7c387..07ae6d156 100644 --- a/libnd4j/include/helpers/ConstantHelper.h +++ b/libnd4j/include/helpers/ConstantHelper.h @@ -38,7 +38,7 @@ namespace nd4j { static ConstantHelper* _INSTANCE; ConstantHelper(); - std::vector> _cache; + std::vector> _cache; // tracking of per-device constant memory buffers (CUDA only atm) std::vector _devicePointers; diff --git a/libnd4j/include/helpers/ConstantShapeHelper.h b/libnd4j/include/helpers/ConstantShapeHelper.h index d5ea9abe9..3184a3675 100644 --- a/libnd4j/include/helpers/ConstantShapeHelper.h +++ b/libnd4j/include/helpers/ConstantShapeHelper.h @@ -38,7 +38,7 @@ namespace nd4j { static ConstantShapeHelper *_INSTANCE; std::mutex _mutex; - std::vector> _cache; + std::vector> _cache; ConstantShapeHelper(); diff --git a/libnd4j/include/helpers/ConstantTadHelper.h b/libnd4j/include/helpers/ConstantTadHelper.h index 79ee7dcd4..3a79a74e3 100644 --- a/libnd4j/include/helpers/ConstantTadHelper.h +++ b/libnd4j/include/helpers/ConstantTadHelper.h @@ -38,7 +38,7 @@ namespace nd4j { static ConstantTadHelper *_INSTANCE; std::mutex _mutex; - std::vector> _cache; + std::vector> _cache; ConstantTadHelper(); public: diff --git a/libnd4j/include/helpers/cpu/ConstantHelper.cpp b/libnd4j/include/helpers/cpu/ConstantHelper.cpp index 2ba2cc4e0..f6981d582 100644 --- a/libnd4j/include/helpers/cpu/ConstantHelper.cpp +++ b/libnd4j/include/helpers/cpu/ConstantHelper.cpp @@ -29,12 +29,13 @@ #include namespace nd4j { + ConstantHelper::ConstantHelper() { int numDevices = getNumberOfDevices(); _cache.resize(numDevices); _counters.resize(numDevices); for (int e = 0; e < numDevices; e++) { - std::map map; + MAP_IMPL map; _cache[e] = map; _counters[e] = 0L; diff --git a/libnd4j/include/helpers/cpu/ConstantShapeHelper.cpp b/libnd4j/include/helpers/cpu/ConstantShapeHelper.cpp index bcedd727e..5ab1e91f7 100644 --- a/libnd4j/include/helpers/cpu/ConstantShapeHelper.cpp +++ b/libnd4j/include/helpers/cpu/ConstantShapeHelper.cpp @@ -29,7 +29,7 @@ namespace nd4j { ConstantShapeHelper::ConstantShapeHelper() { _cache.resize(32); for (int e = 0; e < 32; e++) { - std::map cache; + MAP_IMPL cache; _cache[e] = cache; } } diff --git a/libnd4j/include/helpers/cpu/ConstantTadHelper.cpp b/libnd4j/include/helpers/cpu/ConstantTadHelper.cpp index d48cfca61..9c34cc475 100644 --- a/libnd4j/include/helpers/cpu/ConstantTadHelper.cpp +++ b/libnd4j/include/helpers/cpu/ConstantTadHelper.cpp @@ -24,10 +24,11 @@ #ifndef __CUDABLAS__ + namespace nd4j { ConstantTadHelper::ConstantTadHelper() { - std::map pack; + MAP_IMPL pack; _cache.emplace_back(pack); } diff --git a/libnd4j/include/helpers/cuda/ConstantHelper.cu b/libnd4j/include/helpers/cuda/ConstantHelper.cu index 47e276f4a..678988dd9 100644 --- a/libnd4j/include/helpers/cuda/ConstantHelper.cu +++ b/libnd4j/include/helpers/cuda/ConstantHelper.cu @@ -70,7 +70,7 @@ namespace nd4j { throw cuda_exception::build("cudaSetDevice failed", res); auto constant = getConstantSpace(); - std::map devCache; + MAP_IMPL devCache; _devicePointers[e] = constant; _deviceOffsets[e] = 0; diff --git a/libnd4j/include/helpers/cuda/ConstantShapeHelper.cu b/libnd4j/include/helpers/cuda/ConstantShapeHelper.cu index 4f7a4a485..96f2774cd 100644 --- a/libnd4j/include/helpers/cuda/ConstantShapeHelper.cu +++ b/libnd4j/include/helpers/cuda/ConstantShapeHelper.cu @@ -32,7 +32,7 @@ namespace nd4j { _cache.resize(numDevices); for (int e = 0; e < numDevices; e++) { - std::map cache; + MAP_IMPL cache; _cache[e] = cache; } } diff --git a/libnd4j/include/helpers/cuda/ConstantTadHelper.cu b/libnd4j/include/helpers/cuda/ConstantTadHelper.cu index 747e295e2..a1cd3e89f 100644 --- a/libnd4j/include/helpers/cuda/ConstantTadHelper.cu +++ b/libnd4j/include/helpers/cuda/ConstantTadHelper.cu @@ -31,7 +31,7 @@ namespace nd4j { auto numDevices = AffinityManager::numberOfDevices(); for (int e = 0; e < numDevices; e++) { - std::map pack; + MAP_IMPL pack; _cache.emplace_back(pack); } } diff --git a/libnd4j/include/memory/MemoryRegistrator.h b/libnd4j/include/memory/MemoryRegistrator.h index 53e97d35e..a286923ad 100644 --- a/libnd4j/include/memory/MemoryRegistrator.h +++ b/libnd4j/include/memory/MemoryRegistrator.h @@ -22,6 +22,8 @@ #define LIBND4J_MEMORYREGISTRATOR_H #include "Workspace.h" +#include +#include #include #include #include @@ -32,7 +34,7 @@ namespace nd4j { protected: static MemoryRegistrator* _INSTANCE; Workspace* _workspace; - std::map _footprint; + MAP_IMPL _footprint; std::mutex _lock; MemoryRegistrator(); diff --git a/libnd4j/include/ops/declarable/OpDescriptor.h b/libnd4j/include/ops/declarable/OpDescriptor.h index 302559ad8..72f09f96d 100644 --- a/libnd4j/include/ops/declarable/OpDescriptor.h +++ b/libnd4j/include/ops/declarable/OpDescriptor.h @@ -23,7 +23,6 @@ #include #include -#include #include #include #include @@ -84,8 +83,8 @@ namespace nd4j { std::vector _allowedOuts; // optional per-input configuration - std::map> _outputTypes; - std::map> _inputTypes; + MAP_IMPL> _outputTypes; + MAP_IMPL> _inputTypes; // field for ops that allow data type override at runtime diff --git a/libnd4j/include/ops/declarable/OpRegistrator.h b/libnd4j/include/ops/declarable/OpRegistrator.h index 789b361f3..cd747498f 100644 --- a/libnd4j/include/ops/declarable/OpRegistrator.h +++ b/libnd4j/include/ops/declarable/OpRegistrator.h @@ -33,6 +33,26 @@ #include #include +#ifndef __JAVACPP_HACK__ + +namespace std { + + template <> + class hash> { + public: + size_t operator()(const std::pair& k) const; + }; + + template <> + class hash> { + public: + size_t operator()(const std::pair& k) const; + }; +}; + +#endif + + namespace nd4j { namespace ops { /** @@ -59,16 +79,16 @@ namespace nd4j { #endif }; - std::map _msvc; + MAP_IMPL _msvc; // pointers to our operations - std::map _declarablesLD; - std::map _declarablesD; + MAP_IMPL _declarablesLD; + MAP_IMPL _declarablesD; std::vector _uniqueD; // pointers to platform-specific helpers - std::map, nd4j::ops::platforms::PlatformHelper*> _helpersLH; - std::map, nd4j::ops::platforms::PlatformHelper*> _helpersH; + MAP_IMPL, nd4j::ops::platforms::PlatformHelper*> _helpersLH; + MAP_IMPL, nd4j::ops::platforms::PlatformHelper*> _helpersH; std::vector _uniqueH; std::mutex _locker; diff --git a/libnd4j/include/ops/declarable/helpers/cpu/segment.cpp b/libnd4j/include/ops/declarable/helpers/cpu/segment.cpp index 08aafc98c..1679557af 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/segment.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/segment.cpp @@ -350,7 +350,7 @@ namespace helpers { // if input is a vector: (as if in doc sample) //int idx = static_cast((*indices)(0.)); - std::map> idxs;//(indices->lengthOf()); + MAP_IMPL> idxs;//(indices->lengthOf()); for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) idxs[indices->e(e)].push_back(e); @@ -400,7 +400,7 @@ namespace helpers { static void unsortedSegmentMinFunctor_(NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { // if input is a vector: (as if in doc sample) //int idx = static_cast((*indices)(0.)); - std::map> idxs;//(indices->lengthOf()); + MAP_IMPL> idxs;//(indices->lengthOf()); for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) idxs[indices->e(e)].push_back(e); @@ -452,7 +452,7 @@ namespace helpers { BUILD_SINGLE_TEMPLATE(template void unsortedSegmentMinFunctor_, (NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES); void unsortedSegmentMeanFunctor(nd4j::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - std::map> idxs;//(indices->lengthOf()); + MAP_IMPL> idxs;//(indices->lengthOf()); for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) idxs[indices->e(e)].push_back(e); @@ -494,7 +494,7 @@ namespace helpers { } void unsortedSegmentSumFunctor(nd4j::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - std::map> idxs;//(indices->lengthOf()); + MAP_IMPL> idxs;//(indices->lengthOf()); for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) idxs[indices->e(e)].push_back(e); @@ -534,7 +534,7 @@ namespace helpers { template void unsortedSegmentProdFunctor_(NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - std::map> idxs;//(indices->lengthOf()); + MAP_IMPL> idxs;//(indices->lengthOf()); for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) idxs[indices->e(e)].push_back(e); @@ -575,7 +575,7 @@ namespace helpers { BUILD_SINGLE_TEMPLATE(template void unsortedSegmentProdFunctor_, (NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES); void unsortedSegmentSqrtNFunctor(nd4j::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - std::map> idxs;//(indices->lengthOf()); + MAP_IMPL> idxs;//(indices->lengthOf()); for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) idxs[indices->e(e)].push_back(e); @@ -719,7 +719,7 @@ namespace helpers { // segmen mean int segmentMeanFunctorBP(nd4j::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { int numClasses = output->sizeAt(0); - std::map classCount;//(numClasses); + MAP_IMPL classCount;//(numClasses); for (Nd4jLong count = 0; count < numClasses; ++count) { classCount[count] = 0; @@ -931,7 +931,7 @@ namespace helpers { int unsortedSegmentMeanFunctorBP(nd4j::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - std::map classCount;//(numClasses); + MAP_IMPL classCount;//(numClasses); for (Nd4jLong count = 0; count < numOfClasses; ++count) { classCount[count] = 0; @@ -1040,7 +1040,7 @@ namespace helpers { // template int unsortedSegmentSqrtNFunctorBP(nd4j::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - std::map classCount;//(numClasses); + MAP_IMPL classCount;//(numClasses); for (Nd4jLong count = 0; count < numOfClasses; ++count) { classCount[count] = 0; diff --git a/libnd4j/include/ops/declarable/helpers/impl/unique.cpp b/libnd4j/include/ops/declarable/helpers/impl/unique.cpp index 3bcdea865..d9bee23dc 100644 --- a/libnd4j/include/ops/declarable/helpers/impl/unique.cpp +++ b/libnd4j/include/ops/declarable/helpers/impl/unique.cpp @@ -21,6 +21,7 @@ #include #include #include +#include namespace nd4j { namespace ops { @@ -48,13 +49,12 @@ namespace helpers { BUILD_SINGLE_TEMPLATE(template Nd4jLong uniqueCount_, (NDArray* input), LIBND4J_TYPES); - template static Nd4jStatus uniqueFunctor_(NDArray* input, NDArray* values, NDArray* indices, NDArray* counts) { std::vector valuesVector; - std::map indicesMap; - std::map countsMap; + MAP_IMPL indicesMap; + MAP_IMPL countsMap; for (int e = 0; e < input->lengthOf(); e++) { T v = input->e(e); diff --git a/libnd4j/include/ops/declarable/impl/OpRegistrator.cpp b/libnd4j/include/ops/declarable/impl/OpRegistrator.cpp index 09e4ec58f..64abc4a3a 100644 --- a/libnd4j/include/ops/declarable/impl/OpRegistrator.cpp +++ b/libnd4j/include/ops/declarable/impl/OpRegistrator.cpp @@ -130,7 +130,7 @@ namespace nd4j { _locker.lock(); if (!isInit) { - for (std::map::iterator it=_declarablesD.begin(); it!=_declarablesD.end(); ++it) { + for (MAP_IMPL::iterator it=_declarablesD.begin(); it!=_declarablesD.end(); ++it) { std::string op = it->first + ":" + local_to_string(it->second->getOpDescriptor()->getHash()) + ":" + local_to_string(it->second->getOpDescriptor()->getNumberOfInputs()) + ":" @@ -261,3 +261,19 @@ namespace nd4j { } } +namespace std { + size_t hash>::operator()(const std::pair& k) const { + using std::hash; + auto res = std::hash()(k.first); + res ^= std::hash()((int) k.second) + 0x9e3779b9 + (res << 6) + (res >> 2); + return res; + } + + size_t hash>::operator()(const std::pair& k) const { + using std::hash; + auto res = std::hash()(k.first); + res ^= std::hash()((int) k.second) + 0x9e3779b9 + (res << 6) + (res >> 2); + return res; + } +} + diff --git a/libnd4j/include/pointercast.h b/libnd4j/include/pointercast.h index e080b33b6..66b28693f 100644 --- a/libnd4j/include/pointercast.h +++ b/libnd4j/include/pointercast.h @@ -60,5 +60,28 @@ typedef int Nd4jStatus; #define ND4J_STATUS_MAYBE 119 +#ifdef _MSC_VER + +#include +#define MAP_IMPL std::map + +#elif __clang__ + +#include +#define MAP_IMPL std::unordered_map + +#elif __GNUC__ + +#include +#define MAP_IMPL std::unordered_map + +#else + +#include +#define MAP_IMPL std::unordered_map + +#endif + + #endif //NATIVEOPERATIONS_POINTERCAST_H diff --git a/libnd4j/tests_cpu/layers_tests/ConstantShapeHelperTests.cpp b/libnd4j/tests_cpu/layers_tests/ConstantShapeHelperTests.cpp index e025aaead..aa1491b75 100644 --- a/libnd4j/tests_cpu/layers_tests/ConstantShapeHelperTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConstantShapeHelperTests.cpp @@ -104,6 +104,25 @@ TEST_F(ConstantShapeHelperTests, basic_test_1) { delete []ptr; } +TEST_F(ConstantShapeHelperTests, stress_test_1) { + + for (auto x = 0; x < 1000; x++) { + auto ptr = ShapeBuilders::createShapeInfo(nd4j::DataType::FLOAT32, 'c', {5, x + 10, x + 1}); + ShapeDescriptor descriptor(ptr); + ConstantShapeHelper::getInstance()->createShapeInfo(descriptor); + delete [] ptr; + } + ShapeDescriptor aShape(nd4j::DataType::FLOAT32, 'c', {(Nd4jLong)5, (Nd4jLong)382, (Nd4jLong)373}); +// nd4j_printf("%d\n", ConstantShapeHelper::getInstance()->cachedEntriesForDevice(0)); + + auto timeStart = std::chrono::system_clock::now(); + ASSERT_TRUE(ConstantShapeHelper::getInstance()->checkBufferExistenceForShapeInfo(aShape)); + auto timeEnd = std::chrono::system_clock::now(); + + auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); + nd4j_printf("Total time (us) %lld\n", outerTime); +} + TEST_F(ConstantShapeHelperTests, basic_test_3) { auto array = NDArrayFactory::create_('c', {128}); diff --git a/libnd4j/tests_cpu/layers_tests/StashTests.cpp b/libnd4j/tests_cpu/layers_tests/StashTests.cpp index 170c1763a..bfa1a6ac6 100644 --- a/libnd4j/tests_cpu/layers_tests/StashTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/StashTests.cpp @@ -71,17 +71,17 @@ TEST_F(StashTests, BasicTests_2) { auto cappa = NDArrayFactory::create_('c',{5, 5}); cappa->assign(3.0); - stash.storeArray(1, "alpha1", alpha); - stash.storeArray(1, "alpha2", beta); - stash.storeArray(1, "alpha3", cappa); + stash.storeArray(1, "alpha", alpha); + stash.storeArray(1, "beta", beta); + stash.storeArray(1, "cappa", cappa); - ASSERT_FALSE(stash.checkStash(2, "alpha1")); - ASSERT_FALSE(stash.checkStash(2, "alpha2")); - ASSERT_FALSE(stash.checkStash(2, "alpha3")); + ASSERT_FALSE(stash.checkStash(2, "alpha")); + ASSERT_FALSE(stash.checkStash(2, "beta")); + ASSERT_FALSE(stash.checkStash(2, "cappa")); - ASSERT_TRUE(alpha == stash.extractArray(1, "alpha1")); - ASSERT_TRUE(beta == stash.extractArray(1, "alpha2")); - ASSERT_TRUE(cappa == stash.extractArray(1, "alpha3")); + ASSERT_TRUE(alpha == stash.extractArray(1, "alpha")); + ASSERT_TRUE(beta == stash.extractArray(1, "beta")); + ASSERT_TRUE(cappa == stash.extractArray(1, "cappa")); }