parent
2447af0953
commit
b75bac750d
|
@ -261,19 +261,17 @@ namespace nd4j {
|
|||
void syncToDevice() const;
|
||||
void syncShape() const;
|
||||
|
||||
#if !defined(__JAVACPP_HACK__) && !defined(_JNI_IMPLEMENTATION_)
|
||||
/**
|
||||
* This method can be used on architectures that use special buffers
|
||||
* @param writeList
|
||||
* @param readList
|
||||
*/
|
||||
static void registerSpecialUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList);
|
||||
static void prepareSpecialUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList, bool synchronizeWritables = false);
|
||||
static void registerSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList);
|
||||
static void prepareSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables = false);
|
||||
|
||||
static void registerPrimaryUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList);
|
||||
static void preparePrimaryUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList, bool synchronizeWritables = false);
|
||||
static void registerPrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList);
|
||||
static void preparePrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables = false);
|
||||
|
||||
#endif
|
||||
|
||||
/**
|
||||
* This method returns buffer pointer offset by given number of elements, wrt own data type
|
||||
|
|
|
@ -182,22 +182,19 @@ void NDArray::synchronize(const char* msg) const {
|
|||
// no-op
|
||||
}
|
||||
|
||||
#if !defined(__JAVACPP_HACK__) && !defined(_JNI_IMPLEMENTATION_)
|
||||
void NDArray::prepareSpecialUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList, bool synchronizeWritables) {
|
||||
void NDArray::prepareSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables) {
|
||||
// no-op
|
||||
}
|
||||
void NDArray::registerSpecialUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList) {
|
||||
void NDArray::registerSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList) {
|
||||
// no-op
|
||||
}
|
||||
void NDArray::preparePrimaryUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList, bool synchronizeWritables) {
|
||||
void NDArray::preparePrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables) {
|
||||
// no-op
|
||||
}
|
||||
void NDArray::registerPrimaryUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList) {
|
||||
void NDArray::registerPrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList) {
|
||||
// no-op
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
void NDArray::syncShape() const {
|
||||
// no-op
|
||||
}
|
||||
|
|
|
@ -230,10 +230,9 @@ void NDArray::synchronize(const char* msg) const {
|
|||
if (res != 0)
|
||||
throw std::runtime_error(msg + std::string(": synchronization failed !"));
|
||||
}
|
||||
#if !defined(__JAVACPP_HACK__) && !defined(_JNI_IMPLEMENTATION_)
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
void NDArray::prepareSpecialUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList, bool synchronizeWritables) {
|
||||
void NDArray::prepareSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables) {
|
||||
|
||||
for (const auto& a : readList)
|
||||
if(a != nullptr)
|
||||
|
@ -249,7 +248,7 @@ void NDArray::prepareSpecialUse(const std::vector<const NDArray*>& writeList, co
|
|||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
void NDArray::registerSpecialUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList) {
|
||||
void NDArray::registerSpecialUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList) {
|
||||
|
||||
for (const auto& p : readList)
|
||||
if(p != nullptr)
|
||||
|
@ -261,7 +260,7 @@ void NDArray::registerSpecialUse(const std::vector<const NDArray*>& writeList, c
|
|||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
void NDArray::preparePrimaryUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList, bool synchronizeWritables) {
|
||||
void NDArray::preparePrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList, bool synchronizeWritables) {
|
||||
|
||||
for (const auto& a : readList)
|
||||
if(a != nullptr)
|
||||
|
@ -277,7 +276,7 @@ void NDArray::preparePrimaryUse(const std::vector<const NDArray*>& writeList, co
|
|||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
void NDArray::registerPrimaryUse(const std::vector<const NDArray*>& writeList, const std::vector<const NDArray*>& readList) {
|
||||
void NDArray::registerPrimaryUse(const std::initializer_list<const NDArray*>& writeList, const std::initializer_list<const NDArray*>& readList) {
|
||||
|
||||
for (const auto& p : readList)
|
||||
if(p != nullptr)
|
||||
|
@ -288,8 +287,6 @@ void NDArray::registerPrimaryUse(const std::vector<const NDArray*>& writeList, c
|
|||
p->tickWriteHost();
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void NDArray::syncShape() const {
|
||||
cudaMemcpy(getSpecialShapeInfo(), getShapeInfo(), shape::shapeInfoByteLength(getShapeInfo()), cudaMemcpyHostToDevice);
|
||||
|
|
Loading…
Reference in New Issue