Shugeo unordered map (#256)

* Refactored usage of std::map to std::unordered_map instead.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Eliminated crash with wrong ShapeDescriptor hash.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Eliminated crash with TadDescriptor hash.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored Stash hash.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored hashes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored TadDescriptor hash and top_k mapping.

* Refactored hashes for ShapeDescriptor and TadDescriptor classes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored hash for ConstantDescriptor and ShapeDescriptor classes.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Fixed map using with cuda platform.

Signed-off-by: shugeo <sgazeos@gmail.com>

* - few rearrangements for hash functions
- shared openblas allowed

Signed-off-by: raver119 <raver119@gmail.com>

* exports

Signed-off-by: raver119 <raver119@gmail.com>

* exports

Signed-off-by: raver119 <raver119@gmail.com>

* Stash reverted to std::map

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* Added additional test.

Signed-off-by: shugeo <sgazeos@gmail.com>

* different maps for different compilers

Signed-off-by: raver119 <raver119@gmail.com>

* missing include

Signed-off-by: raver119 <raver119@gmail.com>

* fix the leak

Signed-off-by: raver119 <raver119@gmail.com>

Co-authored-by: raver119 <raver119@gmail.com>
master
shugeo 2020-02-24 06:51:01 +02:00 committed by GitHub
parent a0ed5487ca
commit 1bb3ae4b03
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
38 changed files with 621 additions and 351 deletions

View File

@ -71,8 +71,7 @@ if(NOT CUDA_BLAS)
# there's a chance, we have no BLAS provided externally
if ("${OPENBLAS_PATH}" STREQUAL "")
#we don't want static OpenBLAS on Apple
set(BLA_STATIC ON)
#we don't want OpenBLAS on Apple
if (NOT APPLE)
set(BLA_VENDOR "OpenBLAS")
endif()
@ -80,23 +79,8 @@ if(NOT CUDA_BLAS)
# look around for system blas instead
find_package(BLAS REQUIRED)
if (BLAS_FOUND)
message("Original library: ${BLAS_LIBRARIES}")
# workaround for for cmake being unable to find static blas library
SET(_TMP_B "")
if (APPLE)
string(REGEX REPLACE "\\.dylib$" ".lib" _TMP_B "${BLAS_LIBRARIES}")
elseif (WIN32)
string(REGEX REPLACE "\\.dll" ".lib" _TMP_B "${BLAS_LIBRARIES}")
else()
string(REGEX REPLACE "\\.so$" ".a" _TMP_B "${BLAS_LIBRARIES}")
endif()
set(BLAS_LIBRARIES "${_TMP_B}")
message("Found external BLAS implementation: ${BLAS_LIBRARIES} ")
add_definitions(-D__EXTERNAL_BLAS__=true)
elseif(WIN32)
message("BLAS not found, using downloaded OpenBLAS instead")
add_definitions(-D__EXTERNAL_BLAS__=true)
endif()
else()
# if we have externally provided OPENBLAS_PATH - let's use it

View File

@ -59,5 +59,17 @@ namespace nd4j {
};
}
#ifndef __JAVACPP_HACK__
namespace std {
template<>
class ND4J_EXPORT hash<nd4j::ConstantDescriptor> {
public:
size_t operator()(const nd4j::ConstantDescriptor &k) const;
};
}
#endif
#endif //DEV_TESTS_CONSTANTDESCRIPTOR_H

View File

@ -44,7 +44,7 @@ namespace nd4j {
nd4j::DataType _dtype;
// stored chunks
std::map<int, nd4j::NDArray*> _chunks;
MAP_IMPL<int, nd4j::NDArray*> _chunks;
// just a counter, for stored elements
std::atomic<int> _elements;

View File

@ -85,9 +85,19 @@ class ND4J_EXPORT ShapeDescriptor {
static ShapeDescriptor scalarDescriptor(const DataType type);
static ShapeDescriptor vectorDescriptor(const Nd4jLong length, const DataType type);
};
}
#ifndef __JAVACPP_HACK__
namespace std {
template<>
class ND4J_EXPORT hash<nd4j::ShapeDescriptor> {
public:
size_t operator()(const nd4j::ShapeDescriptor &k) const;
};
}
#endif
#endif //DEV_TESTS_SHAPEDESCRIPTOR_H

View File

@ -53,9 +53,22 @@ namespace nd4j {
std::vector<int>& axis();
ShapeDescriptor& originalShape();
ShapeDescriptor const& originalShapeConst() const;
bool areUnitiesinShape() const;
};
}
#ifndef __JAVACPP_HACK__
namespace std {
template<>
class ND4J_EXPORT hash<nd4j::TadDescriptor> {
public:
size_t operator()(const nd4j::TadDescriptor &k) const;
};
}
#endif
#endif //DEV_TESTS_TADDESCRIPTOR_H

View File

@ -75,3 +75,25 @@ namespace nd4j {
return isInteger() ? _integerValues.size() : isFloat() ? _floatValues.size() : 0L;
}
}
namespace std {
size_t hash<nd4j::ConstantDescriptor>::operator()(const nd4j::ConstantDescriptor &k) const {
using std::hash;
// Compute individual hash values for first,
// second and third and combine them using XOR
// and bit shifting:
size_t hashVal = 0;
size_t i = 0;
if (k.isInteger()) {
for (auto v: k.integerValues()) {
hashVal ^= std::hash<Nd4jLong>()(v) + 0x9e3779b9 + (hashVal << 6) + (hashVal >> 2);
}
}
else {
for (auto v: k.floatValues()) {
hashVal ^= std::hash<double>()(v) + 0x9e3779b9 + (hashVal << 6) + (hashVal >> 2);
}
}
return hashVal;
}
}

View File

@ -22,127 +22,89 @@
#include <shape.h>
#include <ShapeBuilders.h>
using namespace nd4j;
namespace nd4j {
//////////////////////////////////////////////////////////////////////////
// equal to operator
bool ShapeDescriptor::operator==(const ShapeDescriptor& other) const {
bool ShapeDescriptor::operator==(const ShapeDescriptor &other) const {
if(_empty != other._empty)
return false;
if(_rank != other._rank)
return false;
if(_order != other._order)
return false;
if(_dataType != other._dataType)
return false;
if(_ews != other._ews)
return false;
if (_empty != other._empty)
return false;
if (_rank != other._rank)
return false;
if (_order != other._order)
return false;
if (_dataType != other._dataType)
return false;
if (_ews != other._ews)
return false;
if(_shape != other._shape)
return false;
if (_shape != other._shape)
return false;
if(_strides != other._strides)
return false;
if (_strides != other._strides)
return false;
return true;
}
return true;
}
//////////////////////////////////////////////////////////////////////////
// less than operator
bool ShapeDescriptor::operator<(const ShapeDescriptor& other) const {
return std::tie(_empty, _rank, _dataType, _ews, _order, _shape, _strides) < std::tie(other._empty, other._rank, other._dataType, other._ews, other._order, other._shape, other._strides);
}
bool ShapeDescriptor::operator<(const ShapeDescriptor &other) const {
return std::tie(_empty, _rank, _dataType, _ews, _order, _shape, _strides) <
std::tie(other._empty, other._rank, other._dataType, other._ews, other._order, other._shape,
other._strides);
}
Nd4jLong* ShapeDescriptor::toShapeInfo() const {
if (_empty) {
if (_rank == 0)
return ShapeBuilders::emptyShapeInfo(_dataType);
else {
return ShapeBuilders::emptyShapeInfo(_dataType, _order, _shape);
Nd4jLong *ShapeDescriptor::toShapeInfo() const {
if (_empty) {
if (_rank == 0)
return ShapeBuilders::emptyShapeInfo(_dataType);
else {
return ShapeBuilders::emptyShapeInfo(_dataType, _order, _shape);
}
}
switch (_rank) {
case 0: {
auto shapeInfo = ShapeBuilders::createScalarShapeInfo(_dataType);
shapeInfo[2] = _ews;
return shapeInfo;
}
case 1: {
auto shapeInfo = ShapeBuilders::createVectorShapeInfo(_dataType, _shape[0]);
shapeInfo[2 + _rank * 2] = _ews;
shapeInfo[2] = _strides[0];
shapeInfo[2 + _rank * 2 + 1] = _order;
return shapeInfo;
}
default: {
auto shapeInfo = ShapeBuilders::createShapeInfo(_dataType, _order, _shape);
for (int e = 0; e < _rank; e++)
shapeInfo[e + 1 + _rank] = _strides[e];
shapeInfo[2 + _rank * 2] = _ews;
return shapeInfo;
}
}
}
ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const Nd4jLong *shape, const int rank)
: _dataType(type), _order(order), _rank(rank), _ews(1) {
_shape.resize(rank);
_strides.resize(rank);
switch (_rank) {
case 0: {
auto shapeInfo = ShapeBuilders::createScalarShapeInfo(_dataType);
shapeInfo[2] = _ews;
return shapeInfo;
}
case 1: {
auto shapeInfo = ShapeBuilders::createVectorShapeInfo(_dataType, _shape[0]);
shapeInfo[2 + _rank * 2] = _ews;
shapeInfo[2] = _strides[0];
shapeInfo[2 + _rank * 2 + 1] = _order;
return shapeInfo;
}
default: {
auto shapeInfo = ShapeBuilders::createShapeInfo(_dataType, _order, _shape);
for (int e = 0; e < rank; e++)
_shape[e] = shape[e];
for (int e = 0; e < _rank; e++)
shapeInfo[e + 1 + _rank] = _strides[e];
if (order == 'c')
shape::calcStrides(_shape.data(), _shape.size(), _strides.data());
else
shape::calcStridesFortran(_shape.data(), _shape.size(), _strides.data());
shapeInfo[2 + _rank * 2] = _ews;
return shapeInfo;
}
}
}
ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const Nd4jLong *shape, const int rank) : _dataType(type), _order(order), _rank(rank), _ews(1){
_shape.resize(rank);
_strides.resize(rank);
for (int e = 0; e < rank; e++)
_shape[e] = shape[e];
if (order == 'c')
shape::calcStrides(_shape.data(), _shape.size(), _strides.data());
else
shape::calcStridesFortran(_shape.data(), _shape.size(), _strides.data());
for (auto v:_shape) {
if (v == 0) {
_empty = true;
break;
}
}
}
ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const Nd4jLong *shape, const Nd4jLong *strides, const int rank, Nd4jLong ews, const bool empty) {
_shape.resize(rank);
_strides.resize(rank);
_dataType = type;
_order = order;
_rank = rank;
_empty = empty;
_ews = ews;
for (int e = 0; e < rank; e++)
_shape[e] = shape[e];
for (int e = 0; e < rank; e++)
_strides[e] = strides[e];
for (auto v:_shape) {
if (v == 0) {
_empty = true;
break;
}
}
}
//////////////////////////////////////////////////////////////////////////
ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const std::vector<Nd4jLong> &shape): _dataType(type), _order(order), _shape(shape) {
_rank = shape.size();
_ews = 1;
if (_rank > 0) {
_strides.resize(_rank);
for (auto v:_shape) {
if (v == 0) {
@ -150,188 +112,269 @@ ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const st
break;
}
}
}
// no point calculating strides for empty arrays
if (!_empty) {
if (order == 'c')
shape::calcStrides(_shape.data(), shape.size(), _strides.data());
else
shape::calcStridesFortran(_shape.data(), shape.size(), _strides.data());
} else {
// all strides set to 0
memset(_strides.data(), 0, sizeof(Nd4jLong) * shape.size());
ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const Nd4jLong *shape,
const Nd4jLong *strides, const int rank, Nd4jLong ews, const bool empty) {
_shape.resize(rank);
_strides.resize(rank);
_dataType = type;
_order = order;
_rank = rank;
_empty = empty;
_ews = ews;
for (int e = 0; e < rank; e++)
_shape[e] = shape[e];
for (int e = 0; e < rank; e++)
_strides[e] = strides[e];
for (auto v:_shape) {
if (v == 0) {
_empty = true;
break;
}
}
}
}
//////////////////////////////////////////////////////////////////////////
ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const std::initializer_list<Nd4jLong> &shape): _dataType(type), _order(order), _shape(shape) {
_rank = shape.size();
_ews = 1;
ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const std::vector<Nd4jLong> &shape)
: _dataType(type), _order(order), _shape(shape) {
_rank = shape.size();
_ews = 1;
_strides.resize(shape.size());
if (order == 'c')
shape::calcStrides(_shape.data(), shape.size(), _strides.data());
else
shape::calcStridesFortran(_shape.data(), shape.size(), _strides.data());
if (_rank > 0) {
_strides.resize(_rank);
for (auto v:_shape) {
if (v == 0) {
_empty = true;
break;
for (auto v:_shape) {
if (v == 0) {
_empty = true;
break;
}
}
// no point calculating strides for empty arrays
if (!_empty) {
if (order == 'c')
shape::calcStrides(_shape.data(), shape.size(), _strides.data());
else
shape::calcStridesFortran(_shape.data(), shape.size(), _strides.data());
} else {
// all strides set to 0
memset(_strides.data(), 0, sizeof(Nd4jLong) * shape.size());
}
}
}
}
//////////////////////////////////////////////////////////////////////////
ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const std::vector<Nd4jLong> &shape, const std::vector<Nd4jLong> &strides, const Nd4jLong ews): ShapeDescriptor(type, order, shape, strides) {
_ews = ews;
}
ShapeDescriptor::ShapeDescriptor(const DataType type, const char order,
const std::initializer_list<Nd4jLong> &shape) : _dataType(type), _order(order),
_shape(shape) {
_rank = shape.size();
_ews = 1;
ShapeDescriptor::ShapeDescriptor(const DataType type, const Nd4jLong length) : _dataType(type), _ews(1), _order('c'), _rank(1), _empty(false) {
_shape = {length};
_strides = {1};
}
ShapeDescriptor::ShapeDescriptor(const Nd4jLong *shapeInfo, bool inheritDtype) {
_order = shape::order(shapeInfo);
_ews = shape::elementWiseStride(shapeInfo);
_rank = shape::rank(shapeInfo);
if (inheritDtype)
_dataType = ArrayOptions::dataType(shapeInfo);
_empty = shape::isEmpty(shapeInfo);
for (int e = 0; e < _rank; e++) {
_shape.emplace_back(shapeInfo[e + 1]);
if (shapeInfo[e + 1] == 0)
_empty = true;
}
for (int e = 0; e < _rank; e++)
_strides.emplace_back(shapeInfo[e + 1 + _rank]);
}
ShapeDescriptor::ShapeDescriptor(const Nd4jLong *shapeInfo, const nd4j::DataType dtypeOverride) : ShapeDescriptor::ShapeDescriptor(shapeInfo, false) {
_dataType = dtypeOverride;
}
ShapeDescriptor::ShapeDescriptor(const Nd4jLong *shapeInfo, const Nd4jLong *dtypeOverride) : ShapeDescriptor::ShapeDescriptor(shapeInfo, ArrayOptions::dataType(dtypeOverride)) {
//
}
ShapeDescriptor::ShapeDescriptor(const Nd4jLong *shapeInfo, const Nd4jLong *dtypeOverride, const Nd4jLong *orderOverride) : ShapeDescriptor::ShapeDescriptor(shapeInfo, ArrayOptions::dataType(dtypeOverride)) {
_order = shape::order(orderOverride);
}
int ShapeDescriptor::rank() const {
return _rank;
}
Nd4jLong ShapeDescriptor::ews() const {
return _ews;
}
Nd4jLong ShapeDescriptor::arrLength() const {
Nd4jLong len = 1;
for(const auto& dim : const_cast<ShapeDescriptor*>(this)->shape())
len *= dim;
return len;
}
char ShapeDescriptor::order() const {
return _order;
}
DataType ShapeDescriptor::dataType() const {
return _dataType;
}
bool ShapeDescriptor::isEmpty() const {
return _empty;
}
std::vector<Nd4jLong>& ShapeDescriptor::shape() {
return _shape;
}
std::vector<Nd4jLong>& ShapeDescriptor::strides() {
return _strides;
}
ShapeDescriptor::ShapeDescriptor(const ShapeDescriptor &other) {
_rank = other._rank;
_ews = other._ews;
_empty = other._empty;
_dataType = other._dataType;
_order = other._order;
_shape = other._shape;
_strides = other._strides;
}
//////////////////////////////////////////////////////////////////////////
ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const std::vector<Nd4jLong> &shape, const std::vector<Nd4jLong> &strides): _dataType(type), _order(order), _shape(shape) {
if (strides.empty() && !shape.empty()) {
_strides.resize(shape.size());
if (order == 'c')
shape::calcStrides(_shape.data(), shape.size(), _strides.data());
else
shape::calcStridesFortran(_shape.data(), shape.size(), _strides.data());
}
else {
_strides = strides;
}
for (auto v:_shape) {
if (v == 0) {
_empty = true;
break;
for (auto v:_shape) {
if (v == 0) {
_empty = true;
break;
}
}
}
}
ShapeDescriptor ShapeDescriptor::emptyDescriptor(const DataType type) {
ShapeDescriptor descriptor;
descriptor._dataType = type;
descriptor._empty = true;
descriptor._rank = 0;
descriptor._order = 'c';
descriptor._ews = 1;
return descriptor;
}
ShapeDescriptor ShapeDescriptor::scalarDescriptor(const DataType type) {
ShapeDescriptor descriptor;
descriptor._dataType = type;
descriptor._empty = false;
descriptor._rank = 0;
descriptor._order = 'c';
descriptor._ews = 1;
return descriptor;
}
ShapeDescriptor ShapeDescriptor::vectorDescriptor(const Nd4jLong length, const DataType type) {
ShapeDescriptor descriptor;
descriptor._dataType = type;
descriptor._shape.emplace_back(length);
if (length > 0)
descriptor._strides.emplace_back(1);
else {
descriptor._strides.emplace_back(0);
descriptor._empty = true;
//////////////////////////////////////////////////////////////////////////
ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const std::vector<Nd4jLong> &shape,
const std::vector<Nd4jLong> &strides, const Nd4jLong ews) : ShapeDescriptor(type,
order,
shape,
strides) {
_ews = ews;
}
descriptor._order = 'c';
descriptor._ews = 1;
descriptor._rank = 1;
ShapeDescriptor::ShapeDescriptor(const DataType type, const Nd4jLong length) : _dataType(type), _ews(1),
_order('c'), _rank(1),
_empty(false) {
_shape = {length};
_strides = {1};
}
return descriptor;
ShapeDescriptor::ShapeDescriptor(const Nd4jLong *shapeInfo, bool inheritDtype) {
_order = shape::order(shapeInfo);
_ews = shape::elementWiseStride(shapeInfo);
_rank = shape::rank(shapeInfo);
if (inheritDtype)
_dataType = ArrayOptions::dataType(shapeInfo);
_empty = shape::isEmpty(shapeInfo);
for (int e = 0; e < _rank; e++) {
_shape.emplace_back(shapeInfo[e + 1]);
if (shapeInfo[e + 1] == 0)
_empty = true;
}
for (int e = 0; e < _rank; e++)
_strides.emplace_back(shapeInfo[e + 1 + _rank]);
}
ShapeDescriptor::ShapeDescriptor(const Nd4jLong *shapeInfo, const nd4j::DataType dtypeOverride)
: ShapeDescriptor::ShapeDescriptor(shapeInfo, false) {
_dataType = dtypeOverride;
}
ShapeDescriptor::ShapeDescriptor(const Nd4jLong *shapeInfo, const Nd4jLong *dtypeOverride)
: ShapeDescriptor::ShapeDescriptor(shapeInfo, ArrayOptions::dataType(dtypeOverride)) {
//
}
ShapeDescriptor::ShapeDescriptor(const Nd4jLong *shapeInfo, const Nd4jLong *dtypeOverride,
const Nd4jLong *orderOverride) : ShapeDescriptor::ShapeDescriptor(shapeInfo,
ArrayOptions::dataType(
dtypeOverride)) {
_order = shape::order(orderOverride);
}
int ShapeDescriptor::rank() const {
return _rank;
}
Nd4jLong ShapeDescriptor::ews() const {
return _ews;
}
Nd4jLong ShapeDescriptor::arrLength() const {
Nd4jLong len = 1;
for (const auto &dim : const_cast<ShapeDescriptor *>(this)->shape())
len *= dim;
return len;
}
char ShapeDescriptor::order() const {
return _order;
}
DataType ShapeDescriptor::dataType() const {
return _dataType;
}
bool ShapeDescriptor::isEmpty() const {
return _empty;
}
std::vector<Nd4jLong> &ShapeDescriptor::shape() {
return _shape;
}
std::vector<Nd4jLong> &ShapeDescriptor::strides() {
return _strides;
}
ShapeDescriptor::ShapeDescriptor(const ShapeDescriptor &other) {
_rank = other._rank;
_ews = other._ews;
_empty = other._empty;
_dataType = other._dataType;
_order = other._order;
_shape = other._shape;
_strides = other._strides;
}
//////////////////////////////////////////////////////////////////////////
ShapeDescriptor::ShapeDescriptor(const DataType type, const char order, const std::vector<Nd4jLong> &shape,
const std::vector<Nd4jLong> &strides) : _dataType(type), _order(order),
_shape(shape) {
if (strides.empty() && !shape.empty()) {
_strides.resize(shape.size());
if (order == 'c')
shape::calcStrides(_shape.data(), shape.size(), _strides.data());
else
shape::calcStridesFortran(_shape.data(), shape.size(), _strides.data());
} else {
_strides = strides;
}
for (auto v:_shape) {
if (v == 0) {
_empty = true;
break;
}
}
}
ShapeDescriptor ShapeDescriptor::emptyDescriptor(const DataType type) {
ShapeDescriptor descriptor;
descriptor._dataType = type;
descriptor._empty = true;
descriptor._rank = 0;
descriptor._order = 'c';
descriptor._ews = 1;
return descriptor;
}
ShapeDescriptor ShapeDescriptor::scalarDescriptor(const DataType type) {
ShapeDescriptor descriptor;
descriptor._dataType = type;
descriptor._empty = false;
descriptor._rank = 0;
descriptor._order = 'c';
descriptor._ews = 1;
return descriptor;
}
ShapeDescriptor ShapeDescriptor::vectorDescriptor(const Nd4jLong length, const DataType type) {
ShapeDescriptor descriptor;
descriptor._dataType = type;
descriptor._shape.emplace_back(length);
if (length > 0)
descriptor._strides.emplace_back(1);
else {
descriptor._strides.emplace_back(0);
descriptor._empty = true;
}
descriptor._order = 'c';
descriptor._ews = 1;
descriptor._rank = 1;
return descriptor;
}
}
namespace std {
size_t hash<nd4j::ShapeDescriptor>::operator()(const nd4j::ShapeDescriptor &k) const {
auto res = std::hash<Nd4jLong>()(k.arrLength());
res ^= std::hash<char>()(k.order()) + 0x9e3779b9 + (res << 6) + (res >> 2);
res ^= k.dataType() + 0x9e3779b9 + (res << 6) + (res >> 2);
res ^= std::hash<int>()(k.rank()) + 0x9e3779b9 + (res << 6) + (res >> 2);
res ^= std::hash<Nd4jLong>()(k.ews()) + 0x9e3779b9 + (res << 6) + (res >> 2);
auto shapes = const_cast<nd4j::ShapeDescriptor&>(k).shape();
auto strides = const_cast<nd4j::ShapeDescriptor&>(k).strides();
for (auto s: shapes) {
res ^= std::hash<Nd4jLong>()(s) + 0x9e3779b9 + (res << 6) + (res >> 2);
}
for (auto s: strides) {
res ^= std::hash<Nd4jLong>()(s) + 0x9e3779b9 + (res << 6) + (res >> 2);
}
return res;
}
}

View File

@ -65,7 +65,11 @@ namespace nd4j {
return _axis;
}
ShapeDescriptor& TadDescriptor::originalShape() {
ShapeDescriptor& TadDescriptor::originalShape(){
return _originalShape;
}
ShapeDescriptor const& TadDescriptor::originalShapeConst() const{
return _originalShape;
}
@ -73,3 +77,18 @@ namespace nd4j {
return _unitiesInShape;
}
}
namespace std {
size_t hash<nd4j::TadDescriptor>::operator()(const nd4j::TadDescriptor &k) const {
// Compute individual hash values for first,
// second and third and combine them using XOR
// and bit shifting:
auto res = std::hash<int>()((int)k.areUnitiesinShape());
res ^= std::hash<nd4j::ShapeDescriptor>()(k.originalShapeConst()) + 0x9e3779b9 + (res << 6) + (res >> 2);
auto axes = const_cast<nd4j::TadDescriptor&>(k).axis();
for (auto a: axes) {
res ^= std::hash<int>()(a) + 0x9e3779b9 + (res << 6) + (res >> 2);
}
return res;
}
}

View File

@ -38,7 +38,7 @@
#include <cstdio>
#include <string>
#include <algorithm>
#include <map>
#include <unordered_map>
#include <assert.h>
#include <iostream>
#include <sstream>
@ -69,7 +69,7 @@ namespace cnpy {
}
};
struct ND4J_EXPORT npz_t : public std::map<std::string, NpyArray> {
struct ND4J_EXPORT npz_t : public std::unordered_map<std::string, NpyArray> {
void destruct() {
npz_t::iterator it = this->begin();
for(; it != this->end(); ++it) (*it).second.destruct();

View File

@ -24,6 +24,7 @@
#include <vector>
#include <initializer_list>
#include <unordered_map>
#include <map>
#include <string>
#include <flatbuffers/flatbuffers.h>
#include <graph/Variable.h>
@ -33,8 +34,8 @@ namespace nd4j {
class ExecutionResult {
private:
std::vector<Variable *> _variables;
std::map<std::string, Variable *> _stringIdMap;
std::map<std::pair<int, int>, Variable *> _pairIdMap;
MAP_IMPL<std::string, Variable *> _stringIdMap;
MAP_IMPL<std::pair<int, int>, Variable *> _pairIdMap;
// this flag is used to optionally release variables
bool _releasable = false;

View File

@ -21,6 +21,8 @@
#ifndef LIBND4J_FLOWPATH_H
#define LIBND4J_FLOWPATH_H
#include <op_boilerplate.h>
#include <unordered_map>
#include <map>
#include <pointercast.h>
#include <graph/NodeState.h>
@ -32,8 +34,8 @@ namespace nd4j {
namespace graph {
class ND4J_EXPORT FlowPath {
private:
std::map<int, NodeState> _states;
std::map<Nd4jLong, FrameState> _frames;
MAP_IMPL<int, NodeState> _states;
MAP_IMPL<Nd4jLong, FrameState> _frames;
void ensureNode(int nodeId);
void ensureFrame(int nodeId);

View File

@ -24,6 +24,7 @@
#include <list>
#include <algorithm>
#include <unordered_map>
#include <map>
//#include <NDArray.h>
#include <graph/Node.h>
#include <graph/Stash.h>
@ -50,10 +51,10 @@ namespace nd4j {
// vector holds ID's of top nodes only
std::vector<int > *_nodes;
std::map<int, nd4j::graph::Node*> *_mapped;
MAP_IMPL<int, nd4j::graph::Node*> *_mapped;
std::map<int, std::vector<nd4j::graph::Node*> *> *_onion;
std::map<int, nd4j::graph::Node*> _unmapped;
MAP_IMPL<int, std::vector<nd4j::graph::Node*> *> *_onion;
MAP_IMPL<int, nd4j::graph::Node*> _unmapped;
std::vector<int> _unmappedMap; // macOS?
std::mutex _mutexPreprocessing;
@ -63,7 +64,7 @@ namespace nd4j {
std::vector<int> _autos;
std::map<int, Scope*> _mappedScopes;
MAP_IMPL<int, Scope*> _mappedScopes;
std::vector<Scope*> _scopes;
////////////////////////////////////////
@ -124,13 +125,13 @@ namespace nd4j {
*
* @return
*/
std::map<int, std::vector<nd4j::graph::Node*> *> *getOnion();
MAP_IMPL<int, std::vector<nd4j::graph::Node*> *> *getOnion();
/**
* This method returns map of all nodes of the graph
* @return
*/
std::map<int, nd4j::graph::Node*> *getMapped();
MAP_IMPL<int, nd4j::graph::Node*>* getMapped();
/**
* This method returns outputs of this graph
@ -233,7 +234,7 @@ namespace nd4j {
return &_output;
}
FORCEINLINE std::map<int, Scope*>* scopes() {
FORCEINLINE MAP_IMPL<int, Scope*>* scopes() {
return &_mappedScopes;
}

View File

@ -21,6 +21,7 @@
#include <helpers/logger.h>
#include <pointercast.h>
#include <unordered_map>
#include <map>
#include <graph/Graph.h>
#include <helpers/SimpleReadWriteLock.h>
#include <exceptions/unknown_graph_exception.h>
@ -30,9 +31,9 @@ namespace nd4j {
class ND4J_EXPORT GraphHolder {
private:
static GraphHolder *_INSTANCE;
std::map<Nd4jLong, Graph *> _graphF;
MAP_IMPL<Nd4jLong, Graph *> _graphF;
std::map<Nd4jLong, SimpleReadWriteLock> _locks;
MAP_IMPL<Nd4jLong, SimpleReadWriteLock> _locks;
GraphHolder() = default;
~GraphHolder() = default;

View File

@ -26,6 +26,7 @@
#include <dll.h>
#include <vector>
#include <unordered_map>
#include <map>
#include <graph/Scope.h>
#include <Status.h>
#include <graph/VariableSpace.h>
@ -43,7 +44,7 @@ namespace graph {
Nd4jLong _id = 0;
// map of scopes. Scope id is used as key, since it's referred in calls later anyway
std::map<int, Scope *> _scopes;
MAP_IMPL<int, Scope *> _scopes;
// this variable space holds temp references
VariableSpace _variableSpace;

View File

@ -22,6 +22,8 @@
#define LIBND4J_SESSIONLOCALSTORAGE_H
#include <thread>
#include <unordered_map>
#include <map>
#include "VariableSpace.h"
#include "Context.h"
#include "Stash.h"
@ -32,8 +34,8 @@ namespace nd4j{
class ND4J_EXPORT SessionLocalStorage {
protected:
std::atomic<Nd4jLong> _sessionCounter;
std::map<Nd4jLong, Nd4jLong> _threadSession;
std::map<Nd4jLong, VariableSpace*> _threadVariableSpace;
MAP_IMPL<Nd4jLong, Nd4jLong> _threadSession;
MAP_IMPL<Nd4jLong, VariableSpace*> _threadVariableSpace;
VariableSpace* _variableSpace;
Stash* _stash;

View File

@ -23,9 +23,11 @@
//#include <graph/Block.h>
#include <NDArray.h>
#include <unordered_map>
#include <map>
#include <vector>
#include <string>
#include <atomic>
#include <functional>
#include <pointercast.h>
namespace nd4j {
@ -34,11 +36,34 @@ namespace nd4j {
int _node;
std::string _name;
public:
KeyPair(int node = 0, const char * name = nullptr);
KeyPair(int node = 0, const char *name = nullptr);
bool operator<(const KeyPair& other) const;
bool operator<(const KeyPair &other) const;
bool operator==(const KeyPair &other) const {
return _node == other._node;
}
int key() const { return _node; }
std::string name() const { return _name; }
};
}
}
#ifndef __JAVACPP_HACK__
namespace std {
template <>
class ND4J_EXPORT hash<nd4j::graph::KeyPair> {
public:
size_t operator()(const nd4j::graph::KeyPair& k) const;
};
};
#endif
namespace nd4j {
namespace graph {
class ND4J_EXPORT Stash {
protected:
std::map<nd4j::graph::KeyPair, nd4j::NDArray*> _stash;
@ -60,6 +85,7 @@ namespace nd4j {
void clear();
};
}
}

View File

@ -29,6 +29,31 @@
#include <graph/generated/node_generated.h>
#include <graph/generated/graph_generated.h>
#ifndef __JAVACPP_HACK__
namespace std {
template <>
class ND4J_EXPORT hash<std::pair<int, int>> {
public:
size_t operator()(const std::pair<int,int>& k) const;
};
template <>
class ND4J_EXPORT hash<bfloat16> {
public:
size_t operator()(const bfloat16& k) const;
};
template <>
class ND4J_EXPORT hash<float16> {
public:
size_t operator()(const float16& k) const;
};
};
#endif
namespace nd4j {
namespace graph {
class ND4J_EXPORT Variable {

View File

@ -51,7 +51,7 @@ namespace nd4j {
Nd4jLong lastStep = 0L;
std::vector<Nd4jLong *> shapes;
std::map<std::pair<int,int>, Nd4jLong*> shapesMap;
MAP_IMPL<std::pair<int,int>, Nd4jLong*> shapesMap;
int cntFD = 0;
@ -249,11 +249,11 @@ namespace nd4j {
return res;
}
std::map<int, Node *> * Graph::getMapped() {
MAP_IMPL<int, Node *> * Graph::getMapped() {
return _mapped;
}
std::map<int, std::vector<Node *> *>* Graph::getOnion() {
MAP_IMPL<int, std::vector<Node *> *>* Graph::getOnion() {
return _onion;
}
@ -518,7 +518,7 @@ namespace nd4j {
return ND4J_STATUS_OK;
}
typename std::map<int, Node *>::iterator fit;
typename MAP_IMPL<int, Node *>::iterator fit;
int cnts = 0;
for ( fit = _unmapped.begin(); fit != _unmapped.end(); fit++ ) {
int tK = fit->first;
@ -535,7 +535,7 @@ namespace nd4j {
std::vector<int> queue;
// first pass for unmapped nodes, we try to build tale here
typename std::map<int, Node *>::iterator it;
typename MAP_IMPL<int, Node *>::iterator it;
int cntf = 0;
nd4j_debug("-----------\n","");
for ( it = _unmapped.begin(); it != _unmapped.end(); it++ ) {
@ -866,8 +866,8 @@ namespace nd4j {
}
Graph::Graph(const FlatGraph *flatGraph, VariableSpace *variableSpace) {
this->_onion = new std::map<int, std::vector<Node *> *>();
this->_mapped = new std::map<int, Node *> ();
this->_onion = new MAP_IMPL<int, std::vector<Node *> *>();
this->_mapped = new MAP_IMPL<int, Node *> ();
this->_nodes = new std::vector<int>();
this->_variableSpace = variableSpace == nullptr ? new VariableSpace() : variableSpace;
bool trusted = flatGraph != nullptr;

View File

@ -20,6 +20,16 @@
#include <graph/Stash.h>
namespace std {
size_t hash<nd4j::graph::KeyPair>::operator()(const nd4j::graph::KeyPair& k) const {
using std::hash;
auto res = std::hash<std::string>()(k.name());
res ^= std::hash<int>()(k.key()) + 0x9e3779b9 + (res << 6) + (res >> 2);
return res;
}
}
namespace nd4j {
namespace graph {
nd4j::graph::KeyPair::KeyPair(int node, const char * name) {

View File

@ -341,3 +341,20 @@ namespace nd4j {
}
}
}
namespace std {
size_t hash<std::pair<int, int>>::operator()(const std::pair<int,int>& k) const {
auto v = std::hash<int>()(k.first);
v ^= std::hash<int>()(k.second) + 0x9e3779b9 + (v << 6) + (v >> 2);
return v;
}
size_t hash<bfloat16>::operator()(const bfloat16& k) const {
return std::hash<float>()((float)k);
}
size_t hash<float16>::operator()(const float16& k) const {
return std::hash<float>()((float)k);
}
}

View File

@ -38,7 +38,7 @@ namespace nd4j {
static ConstantHelper* _INSTANCE;
ConstantHelper();
std::vector<std::map<ConstantDescriptor, ConstantHolder*>> _cache;
std::vector<MAP_IMPL<ConstantDescriptor, ConstantHolder*>> _cache;
// tracking of per-device constant memory buffers (CUDA only atm)
std::vector<Nd4jPointer> _devicePointers;

View File

@ -38,7 +38,7 @@ namespace nd4j {
static ConstantShapeHelper *_INSTANCE;
std::mutex _mutex;
std::vector<std::map<ShapeDescriptor, ConstantDataBuffer>> _cache;
std::vector<MAP_IMPL<ShapeDescriptor, ConstantDataBuffer>> _cache;
ConstantShapeHelper();

View File

@ -38,7 +38,7 @@ namespace nd4j {
static ConstantTadHelper *_INSTANCE;
std::mutex _mutex;
std::vector<std::map<TadDescriptor, TadPack>> _cache;
std::vector<MAP_IMPL<TadDescriptor, TadPack>> _cache;
ConstantTadHelper();
public:

View File

@ -29,12 +29,13 @@
#include <cstring>
namespace nd4j {
ConstantHelper::ConstantHelper() {
int numDevices = getNumberOfDevices();
_cache.resize(numDevices);
_counters.resize(numDevices);
for (int e = 0; e < numDevices; e++) {
std::map<ConstantDescriptor, ConstantHolder*> map;
MAP_IMPL<ConstantDescriptor, ConstantHolder*> map;
_cache[e] = map;
_counters[e] = 0L;

View File

@ -29,7 +29,7 @@ namespace nd4j {
ConstantShapeHelper::ConstantShapeHelper() {
_cache.resize(32);
for (int e = 0; e < 32; e++) {
std::map<ShapeDescriptor, ConstantDataBuffer> cache;
MAP_IMPL<ShapeDescriptor, ConstantDataBuffer> cache;
_cache[e] = cache;
}
}

View File

@ -24,10 +24,11 @@
#ifndef __CUDABLAS__
namespace nd4j {
ConstantTadHelper::ConstantTadHelper() {
std::map<TadDescriptor, TadPack> pack;
MAP_IMPL<TadDescriptor, TadPack> pack;
_cache.emplace_back(pack);
}

View File

@ -70,7 +70,7 @@ namespace nd4j {
throw cuda_exception::build("cudaSetDevice failed", res);
auto constant = getConstantSpace();
std::map<ConstantDescriptor, ConstantHolder*> devCache;
MAP_IMPL<ConstantDescriptor, ConstantHolder*> devCache;
_devicePointers[e] = constant;
_deviceOffsets[e] = 0;

View File

@ -32,7 +32,7 @@ namespace nd4j {
_cache.resize(numDevices);
for (int e = 0; e < numDevices; e++) {
std::map<ShapeDescriptor, ConstantDataBuffer> cache;
MAP_IMPL<ShapeDescriptor, ConstantDataBuffer> cache;
_cache[e] = cache;
}
}

View File

@ -31,7 +31,7 @@ namespace nd4j {
auto numDevices = AffinityManager::numberOfDevices();
for (int e = 0; e < numDevices; e++) {
std::map<TadDescriptor, TadPack> pack;
MAP_IMPL<TadDescriptor, TadPack> pack;
_cache.emplace_back(pack);
}
}

View File

@ -22,6 +22,8 @@
#define LIBND4J_MEMORYREGISTRATOR_H
#include "Workspace.h"
#include <op_boilerplate.h>
#include <unordered_map>
#include <map>
#include <mutex>
#include <dll.h>
@ -32,7 +34,7 @@ namespace nd4j {
protected:
static MemoryRegistrator* _INSTANCE;
Workspace* _workspace;
std::map<Nd4jLong, Nd4jLong> _footprint;
MAP_IMPL<Nd4jLong, Nd4jLong> _footprint;
std::mutex _lock;
MemoryRegistrator();

View File

@ -23,7 +23,6 @@
#include <string>
#include <vector>
#include <map>
#include <initializer_list>
#include <helpers/helper_hash.h>
#include <ops/InputType.h>
@ -84,8 +83,8 @@ namespace nd4j {
std::vector<nd4j::DataType> _allowedOuts;
// optional per-input configuration
std::map<int, std::vector<nd4j::DataType>> _outputTypes;
std::map<int, std::vector<nd4j::DataType>> _inputTypes;
MAP_IMPL<int, std::vector<nd4j::DataType>> _outputTypes;
MAP_IMPL<int, std::vector<nd4j::DataType>> _inputTypes;
// field for ops that allow data type override at runtime

View File

@ -33,6 +33,26 @@
#include <cstdlib>
#include <csignal>
#ifndef __JAVACPP_HACK__
namespace std {
template <>
class hash<std::pair<Nd4jLong, samediff::Engine>> {
public:
size_t operator()(const std::pair<Nd4jLong, samediff::Engine>& k) const;
};
template <>
class hash<std::pair<std::string, samediff::Engine>> {
public:
size_t operator()(const std::pair<std::string, samediff::Engine>& k) const;
};
};
#endif
namespace nd4j {
namespace ops {
/**
@ -59,16 +79,16 @@ namespace nd4j {
#endif
};
std::map<Nd4jLong, std::string> _msvc;
MAP_IMPL<Nd4jLong, std::string> _msvc;
// pointers to our operations
std::map<Nd4jLong, nd4j::ops::DeclarableOp*> _declarablesLD;
std::map<std::string, nd4j::ops::DeclarableOp*> _declarablesD;
MAP_IMPL<Nd4jLong, nd4j::ops::DeclarableOp*> _declarablesLD;
MAP_IMPL<std::string, nd4j::ops::DeclarableOp*> _declarablesD;
std::vector<nd4j::ops::DeclarableOp *> _uniqueD;
// pointers to platform-specific helpers
std::map<std::pair<Nd4jLong, samediff::Engine>, nd4j::ops::platforms::PlatformHelper*> _helpersLH;
std::map<std::pair<std::string, samediff::Engine>, nd4j::ops::platforms::PlatformHelper*> _helpersH;
MAP_IMPL<std::pair<Nd4jLong, samediff::Engine>, nd4j::ops::platforms::PlatformHelper*> _helpersLH;
MAP_IMPL<std::pair<std::string, samediff::Engine>, nd4j::ops::platforms::PlatformHelper*> _helpersH;
std::vector<nd4j::ops::platforms::PlatformHelper*> _uniqueH;
std::mutex _locker;

View File

@ -350,7 +350,7 @@ namespace helpers {
// if input is a vector: (as if in doc sample)
//int idx = static_cast<int>((*indices)(0.));
std::map<Nd4jLong, std::vector<Nd4jLong>> idxs;//(indices->lengthOf());
MAP_IMPL<Nd4jLong, std::vector<Nd4jLong>> idxs;//(indices->lengthOf());
for (Nd4jLong e = 0; e < indices->lengthOf(); ++e)
idxs[indices->e<Nd4jLong>(e)].push_back(e);
@ -400,7 +400,7 @@ namespace helpers {
static void unsortedSegmentMinFunctor_(NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
// if input is a vector: (as if in doc sample)
//int idx = static_cast<int>((*indices)(0.));
std::map<Nd4jLong, std::vector<Nd4jLong>> idxs;//(indices->lengthOf());
MAP_IMPL<Nd4jLong, std::vector<Nd4jLong>> idxs;//(indices->lengthOf());
for (Nd4jLong e = 0; e < indices->lengthOf(); ++e)
idxs[indices->e<Nd4jLong>(e)].push_back(e);
@ -452,7 +452,7 @@ namespace helpers {
BUILD_SINGLE_TEMPLATE(template void unsortedSegmentMinFunctor_, (NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES);
void unsortedSegmentMeanFunctor(nd4j::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
std::map<Nd4jLong, std::vector<Nd4jLong>> idxs;//(indices->lengthOf());
MAP_IMPL<Nd4jLong, std::vector<Nd4jLong>> idxs;//(indices->lengthOf());
for (Nd4jLong e = 0; e < indices->lengthOf(); ++e)
idxs[indices->e<Nd4jLong>(e)].push_back(e);
@ -494,7 +494,7 @@ namespace helpers {
}
void unsortedSegmentSumFunctor(nd4j::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
std::map<Nd4jLong, std::vector<Nd4jLong>> idxs;//(indices->lengthOf());
MAP_IMPL<Nd4jLong, std::vector<Nd4jLong>> idxs;//(indices->lengthOf());
for (Nd4jLong e = 0; e < indices->lengthOf(); ++e)
idxs[indices->e<Nd4jLong>(e)].push_back(e);
@ -534,7 +534,7 @@ namespace helpers {
template <typename T>
void unsortedSegmentProdFunctor_(NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
std::map<Nd4jLong, std::vector<Nd4jLong>> idxs;//(indices->lengthOf());
MAP_IMPL<Nd4jLong, std::vector<Nd4jLong>> idxs;//(indices->lengthOf());
for (Nd4jLong e = 0; e < indices->lengthOf(); ++e)
idxs[indices->e<Nd4jLong>(e)].push_back(e);
@ -575,7 +575,7 @@ namespace helpers {
BUILD_SINGLE_TEMPLATE(template void unsortedSegmentProdFunctor_, (NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES);
void unsortedSegmentSqrtNFunctor(nd4j::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) {
std::map<Nd4jLong, std::vector<Nd4jLong>> idxs;//(indices->lengthOf());
MAP_IMPL<Nd4jLong, std::vector<Nd4jLong>> idxs;//(indices->lengthOf());
for (Nd4jLong e = 0; e < indices->lengthOf(); ++e)
idxs[indices->e<Nd4jLong>(e)].push_back(e);
@ -719,7 +719,7 @@ namespace helpers {
// segmen mean
int segmentMeanFunctorBP(nd4j::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) {
int numClasses = output->sizeAt(0);
std::map<Nd4jLong, Nd4jLong> classCount;//(numClasses);
MAP_IMPL<Nd4jLong, Nd4jLong> classCount;//(numClasses);
for (Nd4jLong count = 0; count < numClasses; ++count) {
classCount[count] = 0;
@ -931,7 +931,7 @@ namespace helpers {
int unsortedSegmentMeanFunctorBP(nd4j::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
std::map<Nd4jLong, Nd4jLong> classCount;//(numClasses);
MAP_IMPL<Nd4jLong, Nd4jLong> classCount;//(numClasses);
for (Nd4jLong count = 0; count < numOfClasses; ++count) {
classCount[count] = 0;
@ -1040,7 +1040,7 @@ namespace helpers {
// template <typename T>
int unsortedSegmentSqrtNFunctorBP(nd4j::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) {
std::map<Nd4jLong, Nd4jLong> classCount;//(numClasses);
MAP_IMPL<Nd4jLong, Nd4jLong> classCount;//(numClasses);
for (Nd4jLong count = 0; count < numOfClasses; ++count) {
classCount[count] = 0;

View File

@ -21,6 +21,7 @@
#include <ops/declarable/helpers/unique.h>
#include <Status.h>
#include <execution/Threads.h>
#include <graph/Variable.h>
namespace nd4j {
namespace ops {
@ -48,13 +49,12 @@ namespace helpers {
BUILD_SINGLE_TEMPLATE(template Nd4jLong uniqueCount_, (NDArray* input), LIBND4J_TYPES);
template <typename T>
static Nd4jStatus uniqueFunctor_(NDArray* input, NDArray* values, NDArray* indices, NDArray* counts) {
std::vector<T> valuesVector;
std::map<T, int> indicesMap;
std::map<T, int> countsMap;
MAP_IMPL<T, int> indicesMap;
MAP_IMPL<T, int> countsMap;
for (int e = 0; e < input->lengthOf(); e++) {
T v = input->e<T>(e);

View File

@ -130,7 +130,7 @@ namespace nd4j {
_locker.lock();
if (!isInit) {
for (std::map<std::string, nd4j::ops::DeclarableOp*>::iterator it=_declarablesD.begin(); it!=_declarablesD.end(); ++it) {
for (MAP_IMPL<std::string, nd4j::ops::DeclarableOp*>::iterator it=_declarablesD.begin(); it!=_declarablesD.end(); ++it) {
std::string op = it->first + ":"
+ local_to_string(it->second->getOpDescriptor()->getHash()) + ":"
+ local_to_string(it->second->getOpDescriptor()->getNumberOfInputs()) + ":"
@ -261,3 +261,19 @@ namespace nd4j {
}
}
namespace std {
size_t hash<std::pair<Nd4jLong, samediff::Engine>>::operator()(const std::pair<Nd4jLong, samediff::Engine>& k) const {
using std::hash;
auto res = std::hash<Nd4jLong>()(k.first);
res ^= std::hash<int>()((int) k.second) + 0x9e3779b9 + (res << 6) + (res >> 2);
return res;
}
size_t hash<std::pair<std::string, samediff::Engine>>::operator()(const std::pair<std::string, samediff::Engine>& k) const {
using std::hash;
auto res = std::hash<std::string>()(k.first);
res ^= std::hash<int>()((int) k.second) + 0x9e3779b9 + (res << 6) + (res >> 2);
return res;
}
}

View File

@ -60,5 +60,28 @@ typedef int Nd4jStatus;
#define ND4J_STATUS_MAYBE 119
#ifdef _MSC_VER
#include <map>
#define MAP_IMPL std::map
#elif __clang__
#include <unordered_map>
#define MAP_IMPL std::unordered_map
#elif __GNUC__
#include <unordered_map>
#define MAP_IMPL std::unordered_map
#else
#include <unordered_map>
#define MAP_IMPL std::unordered_map
#endif
#endif //NATIVEOPERATIONS_POINTERCAST_H

View File

@ -104,6 +104,25 @@ TEST_F(ConstantShapeHelperTests, basic_test_1) {
delete []ptr;
}
TEST_F(ConstantShapeHelperTests, stress_test_1) {
for (auto x = 0; x < 1000; x++) {
auto ptr = ShapeBuilders::createShapeInfo(nd4j::DataType::FLOAT32, 'c', {5, x + 10, x + 1});
ShapeDescriptor descriptor(ptr);
ConstantShapeHelper::getInstance()->createShapeInfo(descriptor);
delete [] ptr;
}
ShapeDescriptor aShape(nd4j::DataType::FLOAT32, 'c', {(Nd4jLong)5, (Nd4jLong)382, (Nd4jLong)373});
// nd4j_printf("%d\n", ConstantShapeHelper::getInstance()->cachedEntriesForDevice(0));
auto timeStart = std::chrono::system_clock::now();
ASSERT_TRUE(ConstantShapeHelper::getInstance()->checkBufferExistenceForShapeInfo(aShape));
auto timeEnd = std::chrono::system_clock::now();
auto outerTime = std::chrono::duration_cast<std::chrono::nanoseconds>(timeEnd - timeStart).count();
nd4j_printf("Total time (us) %lld\n", outerTime);
}
TEST_F(ConstantShapeHelperTests, basic_test_3) {
auto array = NDArrayFactory::create_<float>('c', {128});

View File

@ -71,17 +71,17 @@ TEST_F(StashTests, BasicTests_2) {
auto cappa = NDArrayFactory::create_<float>('c',{5, 5});
cappa->assign(3.0);
stash.storeArray(1, "alpha1", alpha);
stash.storeArray(1, "alpha2", beta);
stash.storeArray(1, "alpha3", cappa);
stash.storeArray(1, "alpha", alpha);
stash.storeArray(1, "beta", beta);
stash.storeArray(1, "cappa", cappa);
ASSERT_FALSE(stash.checkStash(2, "alpha1"));
ASSERT_FALSE(stash.checkStash(2, "alpha2"));
ASSERT_FALSE(stash.checkStash(2, "alpha3"));
ASSERT_FALSE(stash.checkStash(2, "alpha"));
ASSERT_FALSE(stash.checkStash(2, "beta"));
ASSERT_FALSE(stash.checkStash(2, "cappa"));
ASSERT_TRUE(alpha == stash.extractArray(1, "alpha1"));
ASSERT_TRUE(beta == stash.extractArray(1, "alpha2"));
ASSERT_TRUE(cappa == stash.extractArray(1, "alpha3"));
ASSERT_TRUE(alpha == stash.extractArray(1, "alpha"));
ASSERT_TRUE(beta == stash.extractArray(1, "beta"));
ASSERT_TRUE(cappa == stash.extractArray(1, "cappa"));
}