diff --git a/libnd4j/blas/NDArray.h b/libnd4j/blas/NDArray.h index ed1962279..2f035f31b 100644 --- a/libnd4j/blas/NDArray.h +++ b/libnd4j/blas/NDArray.h @@ -261,19 +261,17 @@ namespace nd4j { void syncToDevice() const; void syncShape() const; -#if !defined(__JAVACPP_HACK__) && !defined(_JNI_IMPLEMENTATION_) /** * This method can be used on architectures that use special buffers * @param writeList * @param readList */ - static void registerSpecialUse(const std::vector& writeList, const std::vector& readList); - static void prepareSpecialUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables = false); + static void registerSpecialUse(const std::initializer_list& writeList, const std::initializer_list& readList); + static void prepareSpecialUse(const std::initializer_list& writeList, const std::initializer_list& readList, bool synchronizeWritables = false); - static void registerPrimaryUse(const std::vector& writeList, const std::vector& readList); - static void preparePrimaryUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables = false); + static void registerPrimaryUse(const std::initializer_list& writeList, const std::initializer_list& readList); + static void preparePrimaryUse(const std::initializer_list& writeList, const std::initializer_list& readList, bool synchronizeWritables = false); -#endif /** * This method returns buffer pointer offset by given number of elements, wrt own data type diff --git a/libnd4j/blas/cpu/NDArray.cpp b/libnd4j/blas/cpu/NDArray.cpp index 1607c93aa..9a7271b28 100644 --- a/libnd4j/blas/cpu/NDArray.cpp +++ b/libnd4j/blas/cpu/NDArray.cpp @@ -182,22 +182,19 @@ void NDArray::synchronize(const char* msg) const { // no-op } -#if !defined(__JAVACPP_HACK__) && !defined(_JNI_IMPLEMENTATION_) -void NDArray::prepareSpecialUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables) { +void NDArray::prepareSpecialUse(const std::initializer_list& writeList, const std::initializer_list& readList, bool synchronizeWritables) { // no-op } -void NDArray::registerSpecialUse(const std::vector& writeList, const std::vector& readList) { +void NDArray::registerSpecialUse(const std::initializer_list& writeList, const std::initializer_list& readList) { // no-op } -void NDArray::preparePrimaryUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables) { +void NDArray::preparePrimaryUse(const std::initializer_list& writeList, const std::initializer_list& readList, bool synchronizeWritables) { // no-op } -void NDArray::registerPrimaryUse(const std::vector& writeList, const std::vector& readList) { +void NDArray::registerPrimaryUse(const std::initializer_list& writeList, const std::initializer_list& readList) { // no-op } -#endif - void NDArray::syncShape() const { // no-op } diff --git a/libnd4j/blas/cuda/NDArray.cu b/libnd4j/blas/cuda/NDArray.cu index 4760cabd8..7d58803b3 100644 --- a/libnd4j/blas/cuda/NDArray.cu +++ b/libnd4j/blas/cuda/NDArray.cu @@ -230,10 +230,9 @@ void NDArray::synchronize(const char* msg) const { if (res != 0) throw std::runtime_error(msg + std::string(": synchronization failed !")); } -#if !defined(__JAVACPP_HACK__) && !defined(_JNI_IMPLEMENTATION_) //////////////////////////////////////////////////////////////////////// -void NDArray::prepareSpecialUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables) { +void NDArray::prepareSpecialUse(const std::initializer_list& writeList, const std::initializer_list& readList, bool synchronizeWritables) { for (const auto& a : readList) if(a != nullptr) @@ -249,7 +248,7 @@ void NDArray::prepareSpecialUse(const std::vector& writeList, co } //////////////////////////////////////////////////////////////////////// -void NDArray::registerSpecialUse(const std::vector& writeList, const std::vector& readList) { +void NDArray::registerSpecialUse(const std::initializer_list& writeList, const std::initializer_list& readList) { for (const auto& p : readList) if(p != nullptr) @@ -261,7 +260,7 @@ void NDArray::registerSpecialUse(const std::vector& writeList, c } //////////////////////////////////////////////////////////////////////// -void NDArray::preparePrimaryUse(const std::vector& writeList, const std::vector& readList, bool synchronizeWritables) { +void NDArray::preparePrimaryUse(const std::initializer_list& writeList, const std::initializer_list& readList, bool synchronizeWritables) { for (const auto& a : readList) if(a != nullptr) @@ -277,7 +276,7 @@ void NDArray::preparePrimaryUse(const std::vector& writeList, co } //////////////////////////////////////////////////////////////////////// -void NDArray::registerPrimaryUse(const std::vector& writeList, const std::vector& readList) { +void NDArray::registerPrimaryUse(const std::initializer_list& writeList, const std::initializer_list& readList) { for (const auto& p : readList) if(p != nullptr) @@ -288,8 +287,6 @@ void NDArray::registerPrimaryUse(const std::vector& writeList, c p->tickWriteHost(); } -#endif - ////////////////////////////////////////////////////////////////////////// void NDArray::syncShape() const { cudaMemcpy(getSpecialShapeInfo(), getShapeInfo(), shape::shapeInfoByteLength(getShapeInfo()), cudaMemcpyHostToDevice);