parent
2447af0953
commit
b75bac750d
|
@ -261,19 +261,17 @@ namespace nd4j {
|
||||||
void syncToDevice() const;
|
void syncToDevice() const;
|
||||||
void syncShape() const;
|
void syncShape() const;
|
||||||
|
|
||||||
#if !defined(__JAVACPP_HACK__) && !defined(_JNI_IMPLEMENTATION_)
|
|
||||||
/**
|
/**
|
||||||
* This method can be used on architectures that use special buffers
|
* This method can be used on architectures that use special buffers
|
||||||
* @param writeList
|
* @param writeList
|
||||||
* @param readList
|
* @param readList
|
||||||
*/
|
*/
|
||||||
static void registerSpecialUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList);
|
static void registerSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList);
|
||||||
static void prepareSpecialUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList, bool synchronizeWritables = false);
|
static void prepareSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables = false);
|
||||||
|
|
||||||
static void registerPrimaryUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList);
|
static void registerPrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList);
|
||||||
static void preparePrimaryUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList, bool synchronizeWritables = false);
|
static void preparePrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables = false);
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method returns buffer pointer offset by given number of elements, wrt own data type
|
* This method returns buffer pointer offset by given number of elements, wrt own data type
|
||||||
|
|
|
@ -182,22 +182,19 @@ void NDArray::synchronize(const char* msg) const {
|
||||||
// no-op
|
// no-op
|
||||||
}
|
}
|
||||||
|
|
||||||
#if !defined(__JAVACPP_HACK__) && !defined(_JNI_IMPLEMENTATION_)
|
void NDArray::prepareSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables) {
|
||||||
void NDArray::prepareSpecialUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList, bool synchronizeWritables) {
|
|
||||||
// no-op
|
// no-op
|
||||||
}
|
}
|
||||||
void NDArray::registerSpecialUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList) {
|
void NDArray::registerSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList) {
|
||||||
// no-op
|
// no-op
|
||||||
}
|
}
|
||||||
void NDArray::preparePrimaryUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList, bool synchronizeWritables) {
|
void NDArray::preparePrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables) {
|
||||||
// no-op
|
// no-op
|
||||||
}
|
}
|
||||||
void NDArray::registerPrimaryUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList) {
|
void NDArray::registerPrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList) {
|
||||||
// no-op
|
// no-op
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
void NDArray::syncShape() const {
|
void NDArray::syncShape() const {
|
||||||
// no-op
|
// no-op
|
||||||
}
|
}
|
||||||
|
|
|
@ -230,10 +230,9 @@ void NDArray::synchronize(const char* msg) const {
|
||||||
if (res != 0)
|
if (res != 0)
|
||||||
throw std::runtime_error(msg + std::string(": synchronization failed !"));
|
throw std::runtime_error(msg + std::string(": synchronization failed !"));
|
||||||
}
|
}
|
||||||
#if !defined(__JAVACPP_HACK__) && !defined(_JNI_IMPLEMENTATION_)
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
void NDArray::prepareSpecialUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList, bool synchronizeWritables) {
|
void NDArray::prepareSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables) {
|
||||||
|
|
||||||
for (const auto& a : readList)
|
for (const auto& a : readList)
|
||||||
if(a != nullptr)
|
if(a != nullptr)
|
||||||
|
@ -249,7 +248,7 @@ void NDArray::prepareSpecialUse(const std::vector<const NDArray*>& writeList, co
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
void NDArray::registerSpecialUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList) {
|
void NDArray::registerSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList) {
|
||||||
|
|
||||||
for (const auto& p : readList)
|
for (const auto& p : readList)
|
||||||
if(p != nullptr)
|
if(p != nullptr)
|
||||||
|
@ -261,7 +260,7 @@ void NDArray::registerSpecialUse(const std::vector<const NDArray*>& writeList, c
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
void NDArray::preparePrimaryUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList, bool synchronizeWritables) {
|
void NDArray::preparePrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables) {
|
||||||
|
|
||||||
for (const auto& a : readList)
|
for (const auto& a : readList)
|
||||||
if(a != nullptr)
|
if(a != nullptr)
|
||||||
|
@ -277,7 +276,7 @@ void NDArray::preparePrimaryUse(const std::vector<const NDArray*>& writeList, co
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
void NDArray::registerPrimaryUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList) {
|
void NDArray::registerPrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList) {
|
||||||
|
|
||||||
for (const auto& p : readList)
|
for (const auto& p : readList)
|
||||||
if(p != nullptr)
|
if(p != nullptr)
|
||||||
|
@ -288,8 +287,6 @@ void NDArray::registerPrimaryUse(const std::vector<const NDArray*>& writeList, c
|
||||||
p->tickWriteHost();
|
p->tickWriteHost();
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void NDArray::syncShape() const {
|
void NDArray::syncShape() const {
|
||||||
cudaMemcpy(getSpecialShapeInfo(), getShapeInfo(), shape::shapeInfoByteLength(getShapeInfo()), cudaMemcpyHostToDevice);
|
cudaMemcpy(getSpecialShapeInfo(), getShapeInfo(), shape::shapeInfoByteLength(getShapeInfo()), cudaMemcpyHostToDevice);
|
||||||
|
|
Loading…
Reference in New Issue