[WIP] cross-device migrations (#134)

* two more tests fixed

Signed-off-by: raver119 <raver119@gmail.com>

* CUDA device afinity tweaks

Signed-off-by: raver119 <raver119@gmail.com>

* minor tweaks

Signed-off-by: raver119 <raver119@gmail.com>

* prepareAction/registerAction for CustomOps

Signed-off-by: raver119 <raver119@gmail.com>

* lazy allocate host bufer before relocation

Signed-off-by: raver119 <raver119@gmail.com>

* one special test for migration in cpp

Signed-off-by: raver119 <raver119@gmail.com>

* tests update for msvc

Signed-off-by: raver119 <raver119@gmail.com>

* logging

Signed-off-by: raver119 <raver119@gmail.com>

* stick to old col2im impl

Signed-off-by: raver119 <raver119@gmail.com>

* cudaStreams reorganization

Signed-off-by: raver119 <raver119@gmail.com>

* buffer size fix

Signed-off-by: raver119 <raver119@gmail.com>

* c++ data migration

Signed-off-by: raver119 <raver119@gmail.com>

* fix CropAndResize test

Signed-off-by: raver119 <raver119@gmail.com>

* - minor improvment

Signed-off-by: Yurii <yurii@skymind.io>
master
raver119 2019-08-20 18:52:41 +03:00 committed by GitHub
parent 23c8738d4a
commit 269d508ba5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
35 changed files with 576 additions and 258 deletions

View File

@ -39,6 +39,7 @@
#include <ShapeDescriptor.h>
#include <helpers/ConstantShapeHelper.h>
#include <array/DataBuffer.h>
#include <AffinityManager.h>
namespace nd4j {
@ -143,6 +144,11 @@ namespace nd4j {
*/
nd4j::DataType _dataType = FLOAT32;
/**
* deviceID where this NDArray belongs to
*/
int _deviceId = AffinityManager::currentDeviceId();
template<typename T>
std::string toStringValue(T value);

View File

@ -55,7 +55,19 @@ void* NDArray::getPlatformBuffer() const { return getSpecialBuffer(); }
Nd4jLong* NDArray::getPlatformShapeInfo() const { return getSpecialShapeInfo(); }
Nd4jLong* NDArray::platformShapeInfo() { return specialShapeInfo(); }
void NDArray::syncToDevice() const { _buffer->syncToSpecial(); }
void NDArray::syncToDevice() const {
auto currentDeviceId = AffinityManager::currentDeviceId();
if (currentDeviceId != _deviceId) {
// first of all we update shapeInfo
const_cast<NDArray*>(this)->setShapeInfo(this->getShapeInfo());
// now we actually migrate data buffer
_buffer->migrate();
}
_buffer->syncToSpecial();
}
void NDArray::syncToHost() const { _buffer->syncToPrimary(getContext()); }
void NDArray::tickWriteHost() const { _buffer->writePrimary(); }
void NDArray::tickWriteDevice() const { _buffer->writeSpecial(); }

View File

@ -31,6 +31,7 @@
#include <AffinityManager.h>
#include <exceptions/datatype_exception.h>
#include <exceptions/cuda_exception.h>
#include <helpers/CudaLaunchHelper.h>
// FIXME: we need cuda-specific implementations
#include <GraphExecutioner.h>
@ -965,11 +966,7 @@ int registerEvent(Nd4jPointer event, Nd4jPointer stream) {
}
int setDevice(int deviceId) {
auto dZ = cudaSetDevice(deviceId);
checkCudaErrors(dZ);
if (dZ != 0)
throw std::runtime_error("cudaSetDevice(...) failed");
AffinityManager::setCurrentDevice(deviceId);
return 1;
}
@ -1024,16 +1021,15 @@ Nd4jLong getDeviceTotalMemory(int device) {
}
int memcpySync(Nd4jPointer dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4jPointer reserved) {
return memcpyAsync(dst, src, size, flags, reserved);
}
int memcpyAsync(Nd4jPointer dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4jPointer reserved) {
cudaStream_t *pStream = reinterpret_cast<cudaStream_t *>(reserved);
auto pStream = reinterpret_cast<cudaStream_t *>(reserved);
cudaMemcpyKind kind;
DEBUG_KERNEL(pStream, 0);
//nd4j::DebugHelper::checkErrorCode(pStream, "Preliminary sync failed");
switch (flags) {
case 0: {
@ -1047,25 +1043,22 @@ int memcpyAsync(Nd4jPointer dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4j
case 2: {
kind = cudaMemcpyDeviceToHost;
}
break;
case 3: {
kind = cudaMemcpyDeviceToDevice;
}
kind = cudaMemcpyDeviceToDevice;
}
break;
default: {
printf("UNDEFINED MEMCPY!\n");
break;
}
default:
throw nd4j::cuda_exception::build("UNDEFINED MEMCPY!\n", 119);
}
cudaError_t dZ = cudaMemcpyAsync(reinterpret_cast<void *>(dst), const_cast<const void *>(reinterpret_cast<void *>(src)), static_cast<size_t>(size), kind, *pStream);
auto dZ = cudaMemcpyAsync(reinterpret_cast<void *>(dst), const_cast<const void *>(reinterpret_cast<void *>(src)), static_cast<size_t>(size), kind, *pStream);
//auto dZ = cudaMemcpy(reinterpret_cast<void *>(dst), const_cast<const void *>(reinterpret_cast<void *>(src)), static_cast<size_t>(size), kind);
if (dZ != 0) {
checkCudaErrors(dZ);
printf("Failed on [%lu] -> [%lu], size: [%i], direction: [%i], dZ: [%i]\n", src, dst, size, flags, static_cast<int>(dZ));
printf("Failed on [%lu] -> [%lu], size: [%i], direction: [%i], dZ: [%i]\n", src, dst, size, flags, static_cast<int>(dZ));
fflush(stdout);
fflush(stderr);
throw std::runtime_error("cudaMemcpyAsync(...) failed");
//return 0L;
throw nd4j::cuda_exception::build("cudaMemcpyAsync(...) failed", dZ);
}
return 1;
@ -3256,7 +3249,7 @@ void tryPointer(Nd4jPointer extra, Nd4jPointer p, int len) {
auto e = cudaStreamSynchronize(stream);
if (e != 0)
throw std::runtime_error("tryPointer failed");
throw nd4j::cuda_exception::build("tryPointer failed", e);
cudaStreamDestroy(stream);
}

View File

@ -105,6 +105,8 @@ class ND4J_EXPORT DataBuffer {
bool isPrimaryActual() const;
bool isSpecialActual() const;
void migrate();
template <typename T> FORCEINLINE T* primaryAsT();
template <typename T> FORCEINLINE T* specialAsT();

View File

@ -96,6 +96,11 @@ void DataBuffer::allocateSpecial() {
}
////////////////////////////////////////////////////////////////////////
void DataBuffer::migrate() {
}
////////////////////////////////////////////////////////////////////////
void DataBuffer::writePrimary() const { }
void DataBuffer::writeSpecial() const { }

View File

@ -173,6 +173,22 @@ void DataBuffer::setToZeroBuffers(const bool both) {
}
}
////////////////////////////////////////////////////////////////////////
void DataBuffer::migrate() {
memory::Workspace* newWorkspace = nullptr;
void* newBuffer;
ALLOCATE_SPECIAL(newBuffer, newWorkspace, getLenInBytes(), int8_t);
cudaMemcpy(newBuffer, _specialBuffer, getLenInBytes(), cudaMemcpyDeviceToDevice);
if (_isOwnerSpecial) {
// now we're releasing original buffer
RELEASE_SPECIAL(_specialBuffer, _workspace);
}
_isOwnerSpecial = true;
_specialBuffer = newBuffer;
}
////////////////////////////////////////////////////////////////////////
void DataBuffer::writePrimary() const { _writePrimary = ++_counter; }
void DataBuffer::writeSpecial() const { _writeSpecial = ++_counter; }

View File

@ -27,10 +27,12 @@
namespace nd4j {
class ND4J_EXPORT ContextBuffers {
private:
void* _reductionPointer;
void* _scalarPointer;
void* _allocationPointer;
bool _allocated = true;
void* _reductionPointer = nullptr;
void* _scalarPointer = nullptr;
void* _allocationPointer = nullptr;
void* _execStream = nullptr;
void* _specialStream = nullptr;
bool _allocated = false;
int _deviceId = -1;
@ -44,6 +46,9 @@ namespace nd4j {
void* scalarBuffer();
void* allocationBuffer();
void* execStream();
void* specialStream();
void setReductionBuffer(void* pointer);
void setScalarBuffer(void* pointer);
void setAllocationBuffer(void* pointer);

View File

@ -52,8 +52,6 @@ class ND4J_EXPORT LaunchContext {
#ifndef __JAVACPP_HACK__
cudaStream_t* _cudaStream = nullptr;
cudaStream_t* _cudaSpecialStream = nullptr;
void* _cublasHandle = nullptr;
void* _cusolverHandle = nullptr;
@ -102,6 +100,9 @@ class ND4J_EXPORT LaunchContext {
static LaunchContext* defaultContext();
static void swapContextBuffers(ContextBuffers &buffers);
};
}

View File

@ -71,4 +71,12 @@ namespace nd4j {
int ContextBuffers::deviceId() {
return _deviceId;
}
}
void* ContextBuffers::execStream() {
return _execStream;
}
void* ContextBuffers::specialStream() {
return _specialStream;
}
}

View File

@ -53,4 +53,8 @@ namespace nd4j {
// return context for current device
return LaunchContext::_contexts[0].get();
}
void LaunchContext::swapContextBuffers(ContextBuffers &buffers) {
//
}
}

View File

@ -21,6 +21,7 @@
#include <logger.h>
#include <execution/AffinityManager.h>
#include <exceptions/cuda_exception.h>
#include <LaunchContext.h>
thread_local int globalThreadToDevice = -1;
@ -98,10 +99,13 @@ namespace nd4j {
if (res != 0)
throw cuda_exception::build("cudaSetDevice failed", res);
auto previousDeviceId = globalThreadToDevice;
// update thread-device affinity
globalThreadToDevice = deviceId;
// TODO: update context buffers?
ContextBuffers newBuffers;
LaunchContext::swapContextBuffers(newBuffers);
}
std::atomic<int> AffinityManager::_lastDevice;// = std::atomic<int>(initialV);

View File

@ -19,6 +19,7 @@
//
#include <execution/ContextBuffers.h>
#include <exceptions/cuda_exception.h>
#include <logger.h>
#include <AffinityManager.h>
@ -45,6 +46,18 @@ namespace nd4j {
if (_allocationPointer != nullptr)
cudaFree(_reductionPointer);
auto _cudaStream = reinterpret_cast<cudaStream_t*>(_execStream);
auto _cudaSpecialStream = reinterpret_cast<cudaStream_t*>(_specialStream);
cudaStreamSynchronize(*_cudaStream);
cudaStreamSynchronize(*_cudaSpecialStream);
cudaStreamDestroy(*_cudaStream);
cudaStreamDestroy(*_cudaSpecialStream);
delete _cudaStream;
delete _cudaSpecialStream;
}
}
@ -70,6 +83,19 @@ namespace nd4j {
if (res != 0)
throw std::runtime_error("_allocationPointer allocation failed");
_execStream = new cudaStream_t();
_specialStream = new cudaStream_t();
if (nullptr == _execStream || nullptr == _specialStream)
throw std::runtime_error("Failed to allocate memory for new CUDA stream");
res = cudaStreamCreate(reinterpret_cast<cudaStream_t*>(_execStream));
if (res != 0)
throw cuda_exception::build("Failed to create default CUDA stream with launch context", res);
res = cudaStreamCreate(reinterpret_cast<cudaStream_t*>(_specialStream));
if (res != 0)
throw cuda_exception::build("Failed to create special CUDA stream with launch context", res);
_allocated = true;
}
@ -113,4 +139,18 @@ namespace nd4j {
int ContextBuffers::deviceId() {
return _deviceId;
}
void* ContextBuffers::execStream() {
if (_execStream == nullptr)
initialize();
return _execStream;
}
void* ContextBuffers::specialStream() {
if (_specialStream == nullptr)
initialize();
return _specialStream;
}
}

View File

@ -35,8 +35,8 @@ namespace nd4j {
////////////////////////////////////////////////////////////////////////
LaunchContext::LaunchContext(cudaStream_t *cudaStream, cudaStream_t& specialCudaStream, void* reductionPointer, void* scalarPointer, int* allocationPointer) {
_cudaStream = cudaStream;
_cudaSpecialStream = &specialCudaStream; // ideal is = new cudaStream_t; *_cudaSpecialStream = specialCudaStream;
//_cudaStream = cudaStream;
//_cudaSpecialStream = &specialCudaStream; // ideal is = new cudaStream_t; *_cudaSpecialStream = specialCudaStream;
//_reductionPointer = reductionPointer;
//_scalarPointer = scalarPointer;
//_allocationPointer = allocationPointer;
@ -46,14 +46,7 @@ LaunchContext::LaunchContext(cudaStream_t *cudaStream, cudaStream_t& specialCuda
LaunchContext::~LaunchContext() {
if (_isAllocated) {
cudaStreamSynchronize(*_cudaStream);
cudaStreamSynchronize(*_cudaSpecialStream);
cudaStreamDestroy(*_cudaStream);
cudaStreamDestroy(*_cudaSpecialStream);
delete _cudaStream;
delete _cudaSpecialStream;
}
}
@ -64,32 +57,16 @@ LaunchContext::LaunchContext() {
_deviceID = 0;
_isAllocated = true;
_cudaStream = new cudaStream_t();
_cudaSpecialStream = new cudaStream_t();
if (nullptr == _cudaStream || nullptr == _cudaSpecialStream)
throw std::runtime_error("Failed to allocate memory for new CUDA stream");
cudaError_t err = cudaStreamCreate(_cudaStream);
if (err != 0)
throw cuda_exception::build("Failed to create default CUDA stream with launch context", err);
err = cudaStreamCreate(_cudaSpecialStream);
if (err != 0)
throw cuda_exception::build("Failed to create special CUDA stream with launch context", err);
_cublasHandle = CublasHelper::getInstance()->handle();
_cusolverHandle = CublasHelper::getInstance()->solver();
auto res = cudaStreamSynchronize(*_cudaStream);
if (res != 0)
throw cuda_exception::build("Initial sync failed", res);
}
LaunchContext::LaunchContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer scalarPointer, Nd4jPointer allocationPointer) {
_isAllocated = false;
_cudaStream = reinterpret_cast<cudaStream_t*>(cudaStream);
_cudaSpecialStream = reinterpret_cast<cudaStream_t*>(cudaStream);
//_cudaStream = reinterpret_cast<cudaStream_t*>(cudaStream);
// _cudaSpecialStream = reinterpret_cast<cudaStream_t*>(cudaStream);
//_reductionPointer = reductionPointer;
//_scalarPointer = scalarPointer;
//_allocationPointer = reinterpret_cast<int *>(allocationPointer);
@ -148,11 +125,11 @@ LaunchContext::LaunchContext() {
};
cudaStream_t* LaunchContext::getCudaStream() const {
return _cudaStream;
return reinterpret_cast<cudaStream_t*>(contextBuffers.execStream());
};
cudaStream_t* LaunchContext::getCudaSpecialStream() const {
return _cudaSpecialStream;
return reinterpret_cast<cudaStream_t*>(contextBuffers.specialStream());;
};
@ -169,14 +146,18 @@ LaunchContext::LaunchContext() {
};
void LaunchContext::setCudaStream(cudaStream_t* cudaStream) {
_cudaStream = cudaStream;
//_cudaStream = cudaStream;
};
void LaunchContext::setCudaSpecialStream(cudaStream_t* cudaStream) {
_cudaSpecialStream = cudaStream;
//_cudaSpecialStream = cudaStream;
};
void LaunchContext::setCublasHandle(void *handle) {
_cublasHandle = handle;
};
void LaunchContext::swapContextBuffers(ContextBuffers &buffers) {
contextBuffers = buffers;
};
}

View File

@ -22,12 +22,13 @@
#include <exceptions/cuda_exception.h>
#include <ShapeDescriptor.h>
#include <ShapeBuilders.h>
#include <AffinityManager.h>
#include <ConstantHelper.h>
namespace nd4j {
ConstantShapeHelper::ConstantShapeHelper() {
auto numDevices = ConstantHelper::getNumberOfDevices();
auto numDevices = AffinityManager::numberOfDevices();
_cache.resize(numDevices);
for (int e = 0; e < numDevices; e++) {
@ -54,7 +55,7 @@ namespace nd4j {
}
ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) {
int deviceId = ConstantHelper::getCurrentDevice();
int deviceId = AffinityManager::currentDeviceId();
_mutex.lock();
@ -83,7 +84,7 @@ namespace nd4j {
bool ConstantShapeHelper::checkBufferExistenceForShapeInfo(ShapeDescriptor &descriptor) {
bool result;
auto deviceId = ConstantHelper::getCurrentDevice();
auto deviceId = AffinityManager::currentDeviceId();
_mutex.lock();
if (_cache[deviceId].count(descriptor) == 0)

View File

@ -21,13 +21,14 @@
#include "../ConstantTadHelper.h"
#include <TAD.h>
#include <ConstantHelper.h>
#include <AffinityManager.h>
#include <exceptions/cuda_exception.h>
#include <execution/LaunchContext.h>
#include <ShapeUtils.h>
namespace nd4j {
ConstantTadHelper::ConstantTadHelper() {
auto numDevices = ConstantHelper::getNumberOfDevices();
auto numDevices = AffinityManager::numberOfDevices();
for (int e = 0; e < numDevices; e++) {
std::map<TadDescriptor, TadPack> pack;
@ -61,7 +62,7 @@ namespace nd4j {
}
TadPack& ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) {
const int deviceId = ConstantHelper::getCurrentDevice();
const int deviceId = AffinityManager::currentDeviceId();
_mutex.lock();

View File

@ -184,8 +184,8 @@ static void col2imCudaLauncher(const int blocksPerGrid, const int threadsPerBloc
void* image, const Nd4jLong* imShapeInfo,
const int sH, const int sW, const int pH, const int pW, const int dH, const int dW) {
// col2imCuda2<T><<<512, 512, 1024, *stream>>>(columns, image, colShapeInfo, imShapeInfo, sH, sW, pH, pW, dH, dW);
col2imCuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(columns, colShapeInfo, image, imShapeInfo, sH, sW, pH, pW, dH, dW);
col2imCuda2<T><<<512, 512, 1024, *stream>>>(columns, image, colShapeInfo, imShapeInfo, sH, sW, pH, pW, dH, dW);
//col2imCuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(columns, colShapeInfo, image, imShapeInfo, sH, sW, pH, pW, dH, dW);
}
//////////////////////////////////////////////////////////////////////////

View File

@ -135,8 +135,9 @@ namespace helpers {
auto step = blockDim.x * gridDim.x;
for (int e = tid; e < len; e += step) {
if (output[shape::getIndexOffset(e, outputShape, len)] != T(0.))
output[shape::getIndexOffset(e, outputShape, len)] = T(input[shape::getIndexOffset(e, gradOutShape, len)] / probValue);
const auto zOffset = shape::getIndexOffset(e, outputShape, len);
if (output[zOffset] != T(0.))
output[zOffset] = T(input[shape::getIndexOffset(e, gradOutShape, len)] / probValue);
}
}

View File

@ -1050,14 +1050,14 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test1) {
auto bias = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3});
auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, oC});
auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC},{ 0.226, 0.343, 0.46 , 0.577, 1.172, 1.46 , 1.748, 2.036, 1.892, 2.288, 2.684, 3.08 , 1.284, 1.581, 1.878, 2.175, 4.458, 5.133, 5.808, 6.483, 6.186, 7.023, 7.86 , 8.697,
3.39 , 3.93 , 4.47 , 5.01 , 9.642, 10.803, 11.964, 13.125,11.37 , 12.693, 14.016, 15.339, 5.266, 5.707, 6.148, 6.589,12.98 , 13.916, 14.852, 15.788,14.564, 15.608, 16.652, 17.696,
3.25 , 4.015, 4.78 , 5.545, 9.812, 11.396, 12.98 , 14.564,10.532, 12.224, 13.916, 15.608, 9.708, 10.977, 12.246, 13.515,25.194, 27.813, 30.432, 33.051,26.922, 29.703, 32.484, 35.265,
11.814, 13.326, 14.838, 16.35 ,30.378, 33.483, 36.588, 39.693,32.106, 35.373, 38.64 , 41.907,13.474, 14.563, 15.652, 16.741,31.988, 34.22 , 36.452, 38.684,33.572, 35.912, 38.252, 40.592});
auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC},{ 0.226f, 0.343f, 0.46f, 0.577f, 1.172f, 1.46f, 1.748f, 2.036f, 1.892f, 2.288f, 2.684f, 3.08f, 1.284f, 1.581f, 1.878f, 2.175f, 4.458f, 5.133f, 5.808f, 6.483f, 6.186f, 7.023f, 7.86f, 8.697f,
3.39f, 3.93f, 4.47f, 5.01f, 9.642f, 10.803f, 11.964f, 13.125f,11.37f, 12.693f, 14.016f, 15.339f, 5.266f, 5.707f, 6.148f, 6.589f,12.98f, 13.916f, 14.852f, 15.788f,14.564f, 15.608f, 16.652f, 17.696f,
3.25f, 4.015f, 4.78f, 5.545f, 9.812f, 11.396f, 12.98f, 14.564f,10.532f, 12.224f, 13.916f, 15.608f, 9.708f, 10.977f, 12.246f, 13.515f,25.194f, 27.813f, 30.432f, 33.051f,26.922f, 29.703f, 32.484f, 35.265f,
11.814f, 13.326f, 14.838f, 16.35f,30.378f, 33.483f, 36.588f, 39.693f,32.106f, 35.373f, 38.64f, 41.907f,13.474f, 14.563f, 15.652f, 16.741f,31.988f, 34.22f, 36.452f, 38.684f,33.572f, 35.912f, 38.252f, 40.592f});
auto expGradW = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, oC},{14.4 , 14.76, 15.12,14.4 , 14.76, 15.12,14.4 , 14.76, 15.12,14.4 , 14.76, 15.12, 9.24, 9.48, 9.72, 9.24, 9.48, 9.72, 9.24, 9.48, 9.72, 9.24, 9.48, 9.72,
17.04, 17.52, 18. ,17.04, 17.52, 18. ,17.04, 17.52, 18. ,17.04, 17.52, 18. ,10.88, 11.2 , 11.52,10.88, 11.2 , 11.52,10.88, 11.2 , 11.52,10.88, 11.2 , 11.52,
11.16, 11.52, 11.88,11.16, 11.52, 11.88,11.16, 11.52, 11.88,11.16, 11.52, 11.88, 7.08, 7.32, 7.56, 7.08, 7.32, 7.56, 7.08, 7.32, 7.56, 7.08, 7.32, 7.56});
auto expGradW = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, oC},{14.4f, 14.76f, 15.12f,14.4f, 14.76f, 15.12f,14.4f, 14.76f, 15.12f,14.4f, 14.76f, 15.12f, 9.24f, 9.48f, 9.72f, 9.24f, 9.48f, 9.72f, 9.24f, 9.48f, 9.72f, 9.24f, 9.48f, 9.72f,
17.04f, 17.52f, 18.f,17.04f, 17.52f, 18.f, 17.04f, 17.52f, 18.f, 17.04f, 17.52f, 18.f,10.88f, 11.2f, 11.52f,10.88f, 11.2f, 11.52f,10.88f, 11.2f, 11.52f,10.88f, 11.2f, 11.52f,
11.16f, 11.52f, 11.88f,11.16f, 11.52f, 11.88f,11.16f, 11.52f, 11.88f,11.16f, 11.52f, 11.88f, 7.08f, 7.32f, 7.56f, 7.08f, 7.32f, 7.56f, 7.08f, 7.32f, 7.56f, 7.08f, 7.32f, 7.56f});
// auto expGradB('c', {oC},{});
input = 2.;
@ -1093,14 +1093,14 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test2) {
auto bias = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3});
auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, oC});
auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC},{ 0.014,0.032, 0.05 , 0.068,0.118,0.181, 0.244, 0.307,0.212,0.257, 0.302, 0.347,0.208,0.298, 0.388, 0.478,1.028,1.262, 1.496, 1.73 ,1.036,1.18 , 1.324, 1.468,
0.928,1.018, 1.108, 1.198,2.9 ,3.134, 3.368, 3.602,2.188,2.332, 2.476, 2.62 ,1.202,1.274, 1.346, 1.418,3.142,3.313, 3.484, 3.655,2.048,2.147, 2.246, 2.345,
0.086,0.212, 0.338, 0.464,0.694,0.973, 1.252, 1.531,0.716,0.869, 1.022, 1.175,1.216,1.522, 1.828, 2.134,3.908,4.574, 5.24 , 5.906,2.908,3.268, 3.628, 3.988,
3.664,3.97 , 4.276, 4.582,9.236,9.902,10.568,11.234,5.788,6.148, 6.508, 6.868,3.002,3.182, 3.362, 3.542,7.174,7.561, 7.948, 8.335,4.28 ,4.487, 4.694, 4.901});
auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC},{ 0.014f, 0.032f, 0.05f, 0.068f,0.118f,0.181f, 0.244f, 0.307f,0.212f,0.257f, 0.302f, 0.347f,0.208f,0.298f, 0.388f, 0.478f,1.028f,1.262f, 1.496f, 1.73f,1.036f,1.18f, 1.324f, 1.468f,
0.928f,1.018f, 1.108f, 1.198f,2.9f,3.134f, 3.368f, 3.602f,2.188f,2.332f, 2.476f, 2.62f, 1.202f,1.274f, 1.346f, 1.418f,3.142f,3.313f, 3.484f, 3.655f,2.048f,2.147f, 2.246f, 2.345f,
0.086f,0.212f, 0.338f, 0.464f,0.694f,0.973f, 1.252f, 1.531f,0.716f,0.869f, 1.022f, 1.175f,1.216f,1.522f, 1.828f, 2.134f,3.908f,4.574f, 5.24f, 5.906f,2.908f,3.268f, 3.628f, 3.988f,
3.664f,3.97f, 4.276f, 4.582f,9.236f,9.902f,10.568f,11.234f,5.788f,6.148f, 6.508f, 6.868f,3.002f,3.182f, 3.362f, 3.542f,7.174f,7.561f, 7.948f, 8.335f,4.28f,4.487f, 4.694f, 4.901f});
auto expGradW = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, oC},{1.84, 2., 2.16,1.84, 2., 2.16,1.84, 2., 2.16,1.84, 2., 2.16,1.84, 2., 2.16,1.84, 2., 2.16,1.84, 2., 2.16,1.84, 2., 2.16,
1.84, 2., 2.16,1.84, 2., 2.16,1.84, 2., 2.16,1.84, 2., 2.16,1.84, 2., 2.16,1.84, 2., 2.16,1.84, 2., 2.16,1.84, 2., 2.16,
1.84, 2., 2.16,1.84, 2., 2.16,1.84, 2., 2.16,1.84, 2., 2.16,1.84, 2., 2.16,1.84, 2., 2.16,1.84, 2., 2.16,1.84, 2., 2.16});
auto expGradW = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, oC},{1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,
1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,
1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f});
// auto expGradB('c', {oC},{});
input = 2.;

File diff suppressed because one or more lines are too long

View File

@ -1092,15 +1092,15 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_06) {
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, NTH_Element_Test_7) {
NDArray input = NDArrayFactory::create<float>('c', {2, 3, 4}, {0.7788, 0.8012, 0.7244, 0.2309,
0.7271, 0.1804, 0.5056, 0.8925,
0.5461, 0.9234, 0.0856, 0.7938,
NDArray input = NDArrayFactory::create<float>('c', {2, 3, 4}, {0.7788f, 0.8012f, 0.7244f, 0.2309f,
0.7271f, 0.1804f, 0.5056f, 0.8925f,
0.5461f, 0.9234f, 0.0856f, 0.7938f,
0.6591, 0.5555, 0.1596, 0.3087,
0.1548, 0.4695, 0.9939, 0.6113,
0.6765, 0.1800, 0.6750, 0.2246});
0.6591f, 0.5555f, 0.1596f, 0.3087f,
0.1548f, 0.4695f, 0.9939f, 0.6113f,
0.6765f, 0.1800f, 0.6750f, 0.2246f});
NDArray n = NDArrayFactory::create<int>(2);
NDArray exp = NDArrayFactory::create<float>('c', {2,3}, {0.7788, 0.7271, 0.7938, 0.5555, 0.6113, 0.675});
NDArray exp = NDArrayFactory::create<float>('c', {2,3}, {0.7788f, 0.7271f, 0.7938f, 0.5555f, 0.6113f, 0.675f});
//input.linspace(1.f);
@ -1119,15 +1119,15 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_7) {
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, NTH_Element_Test_8) {
NDArray input = NDArrayFactory::create<float>('c', {2, 3, 4}, {0.7788, 0.8012, 0.7244, 0.2309,
0.7271, 0.1804, 0.5056, 0.8925,
0.5461, 0.9234, 0.0856, 0.7938,
NDArray input = NDArrayFactory::create<float>('c', {2, 3, 4}, {0.7788f, 0.8012f, 0.7244f, 0.2309f,
0.7271f, 0.1804f, 0.5056f, 0.8925f,
0.5461f, 0.9234f, 0.0856f, 0.7938f,
0.6591, 0.5555, 0.1596, 0.3087,
0.1548, 0.4695, 0.9939, 0.6113,
0.6765, 0.1800, 0.6750, 0.2246});
0.6591f, 0.5555f, 0.1596f, 0.3087f,
0.1548f, 0.4695f, 0.9939f, 0.6113f,
0.6765f, 0.1800f, 0.6750f, 0.2246f});
NDArray n = NDArrayFactory::create<int>(2);
NDArray exp = NDArrayFactory::create<float>('c', {2,3}, {0.7244, 0.5056, 0.5461, 0.3087, 0.4695, 0.2246});
NDArray exp = NDArrayFactory::create<float>('c', {2,3}, {0.7244f, 0.5056f, 0.5461f, 0.3087f, 0.4695f, 0.2246f});
//input.linspace(1.f);
@ -1359,10 +1359,10 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test10) {
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1) {
NDArray input = NDArrayFactory::create<float>('c', {1, 2,3,4});
NDArray input = NDArrayFactory::create<double>('c', {1, 2,3,4});
//NDArray<float> paddings('c', {3,2}, {0,0, 0,1, 0,0});
//NDArray<float> expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.});
NDArray expected = NDArrayFactory::create<float>('c', {1, 10, 10, 4}, {1., 2., 3., 4., 2.2, 3.2, 4.2, 5.2, 3.4, 4.4, 5.4, 6.4,
NDArray expected = NDArrayFactory::create<double>('c', {1, 10, 10, 4}, {1., 2., 3., 4., 2.2, 3.2, 4.2, 5.2, 3.4, 4.4, 5.4, 6.4,
4.6, 5.6, 6.6, 7.6, 5.8, 6.8, 7.8, 8.8, 7., 8., 9., 10.,
8.2, 9.2, 10.2, 11.2, 9., 10., 11., 12., 9., 10., 11., 12.,
9., 10., 11., 12., 3.4, 4.4, 5.4, 6.4, 4.6, 5.6, 6.6, 7.6,
@ -1416,10 +1416,10 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1) {
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test2) {
NDArray input = NDArrayFactory::create<float>('c', {1, 2,3,4});
NDArray input = NDArrayFactory::create<double>('c', {1, 2,3,4});
NDArray size = NDArrayFactory::create<int>({10, 10});
//NDArray<float> expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.});
NDArray expected = NDArrayFactory::create<float>('c', {1, 10, 10, 4}, {1., 2., 3., 4., 2.2, 3.2, 4.2, 5.2, 3.4, 4.4, 5.4, 6.4,
NDArray expected = NDArrayFactory::create<double>('c', {1, 10, 10, 4}, {1., 2., 3., 4., 2.2, 3.2, 4.2, 5.2, 3.4, 4.4, 5.4, 6.4,
4.6, 5.6, 6.6, 7.6, 5.8, 6.8, 7.8, 8.8, 7., 8., 9., 10.,
8.2, 9.2, 10.2, 11.2, 9., 10., 11., 12., 9., 10., 11., 12.,
9., 10., 11., 12., 3.4, 4.4, 5.4, 6.4, 4.6, 5.6, 6.6, 7.6,
@ -1472,10 +1472,10 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test2) {
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test3) {
NDArray input = NDArrayFactory::create<float>('c', {1, 2,3,4});
NDArray input = NDArrayFactory::create<double>('c', {1, 2,3,4});
//NDArray<float> paddings('c', {3,2}, {0,0, 0,1, 0,0});
//NDArray<float> expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.});
NDArray expected = NDArrayFactory::create<float>('c', {1, 10, 10, 4},
NDArray expected = NDArrayFactory::create<double>('c', {1, 10, 10, 4},
{ 1., 2., 3., 4. ,
1.8888888, 2.8888888, 3.8888888, 4.888889,
2.7777777, 3.7777777, 4.7777777, 5.7777777,
@ -1602,9 +1602,9 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test3) {
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test4) {
NDArray input = NDArrayFactory::create<float>('c', {1, 2,3,4});
NDArray input = NDArrayFactory::create<double>('c', {1, 2,3,4});
NDArray size = NDArrayFactory::create<int>({10, 10});
NDArray expected = NDArrayFactory::create<float>('c', {1, 10, 10, 4},
NDArray expected = NDArrayFactory::create<double>('c', {1, 10, 10, 4},
{ 1., 2., 3., 4. ,
1.8888888, 2.8888888, 3.8888888, 4.888889,
2.7777777, 3.7777777, 4.7777777, 5.7777777,
@ -1750,10 +1750,10 @@ TEST_F(DeclarableOpsTests10, LinSpace_Test1) {
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1) {
NDArray input = NDArrayFactory::create<float>('c', {1, 2, 3, 4});
NDArray input = NDArrayFactory::create<double>('c', {1, 2, 3, 4});
//NDArray<float> paddings('c', {3,2}, {0,0, 0,1, 0,0});
//NDArray<float> expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.});
NDArray expected = NDArrayFactory::create<float>('c', {1, 4, 5, 4}, { 1, 2, 3, 4,
NDArray expected = NDArrayFactory::create<double>('c', {1, 4, 5, 4}, { 1, 2, 3, 4,
1, 2, 3, 4,
5, 6, 7, 8,
5, 6, 7, 8,
@ -1926,8 +1926,8 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_1) {
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, Image_CropAndResize_2) {
int axis = 0;
NDArray images = NDArrayFactory::create<float>('c', {1,2,2,1}, {1,2,3,4});
NDArray boxes = NDArrayFactory::create<float>('c', {1,4}, {0,0,1,1});
NDArray images = NDArrayFactory::create<float>('c', {1,2,2,1}, {1.f, 2.f, 3.f, 4.f});
NDArray boxes = NDArrayFactory::create<float>('c', {1,4}, {0.f, 0.f, 1.f, 1.f});
NDArray boxI = NDArrayFactory::create<int>('c', {1}, {axis});
NDArray cropSize = NDArrayFactory::create<int>({1, 1});

View File

@ -3238,11 +3238,11 @@ TEST_F(DeclarableOpsTests8, Test_Moments_7) {
////////////////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_01) {
auto x = NDArrayFactory::create<TypeParam>('c', {1, 1, 2, 5}, { 1, 2., 3, 4, 5,
6., 7., 8, 9, 10}
auto x = NDArrayFactory::create<TypeParam>('c', {1, 1, 2, 5}, { 1.f, 2.f, 3.f, 4.f, 5.f,
6.f, 7.f, 8.f, 9.f, 10.f}
);
auto exp = NDArrayFactory::create<TypeParam>('c', {1, 1, 2, 5}, {0.2581989 , 0.3592106 , 0.40089184, 0.53935987, 0.70014, 0.4898979 , 0.46056613, 0.43971977, 0.5240003 , 0.6375767 }// 0.72760683, 0.4850712, 0.5848977, 0.67488194,
auto exp = NDArrayFactory::create<TypeParam>('c', {1, 1, 2, 5}, {0.2581989f, 0.3592106f, 0.40089184f, 0.53935987f, 0.70014f, 0.4898979f, 0.46056613f, 0.43971977f, 0.5240003f, 0.6375767f}// 0.72760683, 0.4850712, 0.5848977, 0.67488194,
// 0.7581754, 0.58321184, 0.86747235, 0.4048204}
);
@ -3262,10 +3262,10 @@ TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_01) {
////////////////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_02) {
auto x = NDArrayFactory::create<TypeParam>('c', {1, 1, 1, 6}, { 1, 2., 3, 4, 5, 6});
auto x = NDArrayFactory::create<TypeParam>('c', {1, 1, 1, 6}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
auto exp = NDArrayFactory::create<TypeParam>('c', {1, 1, 1, 6}, {
0.2581989 , 0.3592106 , 0.40089184, 0.4193139, 0.5360563, 0.67936623}
0.2581989f, 0.3592106f, 0.40089184f, 0.4193139f, 0.5360563f, 0.67936623f}
);
nd4j::ops::lrn op;
@ -3284,8 +3284,8 @@ TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_02) {
////////////////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_03) {
auto x = NDArrayFactory::create<TypeParam>('c', {1, 1, 1, 10}, { 1, 2., 3, 4, 5, 6, 7, 8, 9, 10});
auto exp = NDArrayFactory::create<TypeParam>('c', {1, 1, 1, 10}, {0.10425719, 0.16843036, 0.2095291 , 0.23652494, 0.25449327,0.3053919 , 0.35675305, 0.4098524 , 0.46662825, 0.52999896});
auto x = NDArrayFactory::create<TypeParam>('c', {1, 1, 1, 10}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f});
auto exp = NDArrayFactory::create<TypeParam>('c', {1, 1, 1, 10}, {0.10425719f, 0.16843036f, 0.2095291f, 0.23652494f, 0.25449327f, 0.3053919f, 0.35675305f, 0.4098524f, 0.46662825f, 0.52999896f});
nd4j::ops::lrn op;
auto results = op.execute({&x}, {1.0, 1.0, 0.5}, {5}, {}, false, nd4j::DataType::DOUBLE);
@ -3303,17 +3303,17 @@ TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_03) {
////////////////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_1) {
auto x = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 2}, { 5.5, 0., 0.3, 5.5,
8.6, 0., 0., 0.4,
1.5, 1., 1.3, 1.5,
2.6, 2., 3., 1.4}
auto x = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 2}, { 5.5f, 0.f, 0.3f, 5.5f,
8.6f, 0.f, 0.f, 0.4f,
1.5f, 1.f, 1.3f, 1.5f,
2.6f, 2.f, 3.f, 1.4f}
);
auto exp = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 2}, {
0.98386997, 0., 0.05358852, 0.9824562,
0.99330735, 0., 0., 0.37139067,
0.72760683, 0.4850712, 0.5848977, 0.67488194,
0.7581754, 0.58321184, 0.86747235, 0.4048204}
0.98386997f, 0.f, 0.05358852f, 0.9824562f,
0.99330735f, 0.f, 0.f, 0.37139067f,
0.72760683f, 0.4850712f, 0.5848977f, 0.67488194f,
0.7581754f, 0.58321184f, 0.86747235f, 0.4048204f}
);
nd4j::ops::lrn op;
@ -3336,61 +3336,61 @@ TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_2) {
x.linspace(1);
auto exp = NDArrayFactory::create<TypeParam>('c', {3, 3, 5, 5}, {
0.2581989 ,0.3592106 , 0.40089184, 0.53935987, 0.70014,
0.4898979 ,0.46056613, 0.43971977, 0.5240002 , 0.6375767,
0.5274096 ,0.47771242, 0.4443308 , 0.5163977 , 0.61701745,
0.5424508 ,0.48452914, 0.44570294, 0.5123918 , 0.6068971,
0.5505386 ,0.4881662 , 0.4462865 , 0.5099462 , 0.60088515,
0.2581989f, 0.3592106f, 0.40089184f, 0.53935987f, 0.70014f,
0.4898979f, 0.46056613f, 0.43971977f, 0.5240002f, 0.6375767f,
0.5274096f, 0.47771242f, 0.4443308f, 0.5163977f, 0.61701745f,
0.5424508f, 0.48452914f, 0.44570294f, 0.5123918f, 0.6068971f,
0.5505386f, 0.4881662f, 0.4462865f, 0.5099462f, 0.60088515f,
0.5555859 , 0.49042296, 0.44658744, 0.5083028 , 0.59690416,
0.55903524, 0.4919585 , 0.44676256, 0.5071239 , 0.59407425,
0.5615412 , 0.49307042, 0.44687328, 0.50623745, 0.5919596 ,
0.56344414, 0.49391258, 0.4469477 , 0.5055468 , 0.59031945,
0.56493837, 0.49457246, 0.4470002 , 0.5049936 , 0.5890103 ,
0.5555859f, 0.49042296f, 0.44658744f, 0.5083028f, 0.59690416f,
0.55903524f, 0.4919585f, 0.44676256f, 0.5071239f, 0.59407425f,
0.5615412f, 0.49307042f, 0.44687328f, 0.50623745f, 0.5919596f,
0.56344414f, 0.49391258f, 0.4469477f, 0.5055468f, 0.59031945f,
0.56493837f, 0.49457246f, 0.4470002f, 0.5049936f, 0.5890103f,
0.56614274, 0.49510333, 0.44703856, 0.50454074, 0.5879411 ,
0.567134 , 0.49553978, 0.4470674 , 0.504163 , 0.5870515 ,
0.5679643 , 0.4959048 , 0.44708967, 0.5038433 , 0.5862998 ,
0.56866974, 0.4962146 , 0.44710726, 0.5035692 , 0.58565617,
0.56927663, 0.49648085, 0.4471213 , 0.5033315 , 0.5850988 ,
0.56614274f, 0.49510333f, 0.44703856f, 0.50454074f, 0.5879411f,
0.567134f, 0.49553978f, 0.4470674f, 0.504163f, 0.5870515f,
0.5679643f, 0.4959048f, 0.44708967f, 0.5038433f, 0.5862998f,
0.56866974f, 0.4962146f, 0.44710726f, 0.5035692f, 0.58565617f,
0.56927663f, 0.49648085f, 0.4471213f, 0.5033315f, 0.5850988f,
0.56980413, 0.49671215, 0.44713274, 0.50312346, 0.58461165,
0.57026696, 0.49691492, 0.4471422 , 0.50293994, 0.58418214,
0.5706764 , 0.49709415, 0.44715008, 0.5027767 , 0.5838005 ,
0.571041 , 0.4972537 , 0.44715673, 0.50263065, 0.58345926,
0.57136786, 0.49739665, 0.44716236, 0.5024992 , 0.58315235,
0.56980413f, 0.49671215f, 0.44713274f, 0.50312346f, 0.58461165f,
0.57026696f, 0.49691492f, 0.4471422f, 0.50293994f, 0.58418214f,
0.5706764f, 0.49709415f, 0.44715008f, 0.5027767f, 0.5838005f,
0.571041f, 0.4972537f, 0.44715673f, 0.50263065f, 0.58345926f,
0.57136786f, 0.49739665f, 0.44716236f, 0.5024992f, 0.58315235f,
0.5716625 , 0.49752548, 0.4471672 , 0.5023803, 0.5828747 ,
0.5719295 , 0.49764213, 0.44717142, 0.5022721, 0.5826225 ,
0.57217246, 0.49774826, 0.44717506, 0.5021734, 0.58239233,
0.5723947 , 0.4978453 , 0.44717824, 0.5020829, 0.58218133,
0.57259864, 0.49793428, 0.44718108, 0.5019997, 0.5819874 ,
0.5716625f, 0.49752548f, 0.4471672f, 0.5023803f, 0.5828747f,
0.5719295f, 0.49764213f, 0.44717142f, 0.5022721f, 0.5826225f,
0.57217246f, 0.49774826f, 0.44717506f, 0.5021734f, 0.58239233f,
0.5723947f, 0.4978453f, 0.44717824f, 0.5020829f, 0.58218133f,
0.57259864f, 0.49793428f, 0.44718108f, 0.5019997f, 0.5819874f,
0.5727864 , 0.49801624, 0.44718358, 0.5019227, 0.5818083 ,
0.57296 , 0.49809194, 0.44718578, 0.5018515, 0.5816426 ,
0.5731208 , 0.49816203, 0.44718775, 0.5017854, 0.58148885,
0.57327026, 0.49822718, 0.4471895 , 0.5017239, 0.5813457 ,
0.57340944, 0.49828786, 0.44719115, 0.5016664, 0.581212 ,
0.5727864f, 0.49801624f, 0.44718358f, 0.5019227f, 0.5818083f,
0.57296f, 0.49809194f, 0.44718578f, 0.5018515f, 0.5816426f,
0.5731208f, 0.49816203f, 0.44718775f, 0.5017854f, 0.58148885f,
0.57327026f, 0.49822718f, 0.4471895f, 0.5017239f, 0.5813457f,
0.57340944f, 0.49828786f, 0.44719115f, 0.5016664f, 0.581212f,
0.57353944, 0.4983446 , 0.44719255, 0.50161266, 0.58108705,
0.5736612 , 0.49839762, 0.4471939 , 0.50156236, 0.5809699 ,
0.5737754 , 0.4984474 , 0.44719502, 0.501515 , 0.58085984,
0.5738828 , 0.49849418, 0.4471962 , 0.50147045, 0.5807564 ,
0.5739839 , 0.49853817, 0.44719717, 0.5014284 , 0.5806588 ,
0.57353944f, 0.4983446f, 0.44719255f, 0.50161266f, 0.58108705f,
0.5736612f, 0.49839762f, 0.4471939f, 0.50156236f, 0.5809699f,
0.5737754f, 0.4984474f, 0.44719502f, 0.501515f, 0.58085984f,
0.5738828f, 0.49849418f, 0.4471962f, 0.50147045f, 0.5807564f,
0.5739839f, 0.49853817f, 0.44719717f, 0.5014284f, 0.5806588f,
0.5740793 , 0.49857965, 0.4471981 , 0.5013887 , 0.5805666 ,
0.5741694 , 0.49861887, 0.44719887, 0.50135124, 0.58047944,
0.57425463, 0.49865603, 0.44719967, 0.5013157 , 0.5803969 ,
0.5743354 , 0.4986912 , 0.44720036, 0.5012819 , 0.5803186 ,
0.57441217, 0.49872455, 0.44720104, 0.5012499 , 0.58024424,
0.5740793f, 0.49857965f, 0.4471981f, 0.5013887f, 0.5805666f,
0.5741694f, 0.49861887f, 0.44719887f, 0.50135124f, 0.58047944f,
0.57425463f, 0.49865603f, 0.44719967f, 0.5013157f, 0.5803969f,
0.5743354f, 0.4986912f, 0.44720036f, 0.5012819f, 0.5803186f,
0.57441217f, 0.49872455f, 0.44720104f, 0.5012499f, 0.58024424f,
0.57448506, 0.4987563 , 0.44720164, 0.5012194 , 0.58017343,
0.57455444, 0.4987865 , 0.4472022 , 0.5011904 , 0.5801061,
0.57462054, 0.49881527, 0.44720277, 0.5011627 , 0.5800419,
0.57468355, 0.49884263, 0.44720328, 0.50113624, 0.5799805,
0.57474375, 0.49886885, 0.44720373, 0.50111103, 0.5799219 }
0.57448506f, 0.4987563f, 0.44720164f, 0.5012194f, 0.58017343f,
0.57455444f, 0.4987865f, 0.4472022f, 0.5011904f, 0.5801061f,
0.57462054f, 0.49881527f, 0.44720277f, 0.5011627f, 0.5800419f,
0.57468355f, 0.49884263f, 0.44720328f, 0.50113624f, 0.5799805f,
0.57474375f, 0.49886885f, 0.44720373f, 0.50111103f, 0.5799219f }
);
//
nd4j::ops::lrn op;
@ -3413,61 +3413,61 @@ TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_3) {
x.linspace(1);
auto exp = NDArrayFactory::create<TypeParam>('c', {3, 3, 5, 5}, {
0.2581989 ,0.3592106 , 0.40089184, 0.53935987, 0.70014,
0.4898979 ,0.46056613, 0.43971977, 0.5240002 , 0.6375767,
0.5274096 ,0.47771242, 0.4443308 , 0.5163977 , 0.61701745,
0.5424508 ,0.48452914, 0.44570294, 0.5123918 , 0.6068971,
0.5505386 ,0.4881662 , 0.4462865 , 0.5099462 , 0.60088515,
0.2581989f, 0.3592106f, 0.40089184f, 0.53935987f, 0.70014f,
0.4898979f, 0.46056613f, 0.43971977f, 0.5240002f, 0.6375767f,
0.5274096f, 0.47771242f, 0.4443308f, 0.5163977f, 0.61701745f,
0.5424508f, 0.48452914f, 0.44570294f, 0.5123918f, 0.6068971f,
0.5505386f, 0.4881662f, 0.4462865f, 0.5099462f, 0.60088515f,
0.5555859 , 0.49042296, 0.44658744, 0.5083028 , 0.59690416,
0.55903524, 0.4919585 , 0.44676256, 0.5071239 , 0.59407425,
0.5615412 , 0.49307042, 0.44687328, 0.50623745, 0.5919596 ,
0.56344414, 0.49391258, 0.4469477 , 0.5055468 , 0.59031945,
0.56493837, 0.49457246, 0.4470002 , 0.5049936 , 0.5890103 ,
0.5555859f, 0.49042296f, 0.44658744f, 0.5083028f, 0.59690416f,
0.55903524f, 0.4919585f, 0.44676256f, 0.5071239f, 0.59407425f,
0.5615412f, 0.49307042f, 0.44687328f, 0.50623745f, 0.5919596f,
0.56344414f, 0.49391258f, 0.4469477f, 0.5055468f, 0.59031945f,
0.56493837f, 0.49457246f, 0.4470002f, 0.5049936f, 0.5890103f,
0.56614274, 0.49510333, 0.44703856, 0.50454074, 0.5879411 ,
0.567134 , 0.49553978, 0.4470674 , 0.504163 , 0.5870515 ,
0.5679643 , 0.4959048 , 0.44708967, 0.5038433 , 0.5862998 ,
0.56866974, 0.4962146 , 0.44710726, 0.5035692 , 0.58565617,
0.56927663, 0.49648085, 0.4471213 , 0.5033315 , 0.5850988 ,
0.56614274f, 0.49510333f, 0.44703856f, 0.50454074f, 0.5879411f,
0.567134f, 0.49553978f, 0.4470674f, 0.504163f, 0.5870515f,
0.5679643f, 0.4959048f, 0.44708967f, 0.5038433f, 0.5862998f,
0.56866974f, 0.4962146f, 0.44710726f, 0.5035692f, 0.58565617f,
0.56927663f, 0.49648085f, 0.4471213f, 0.5033315f, 0.5850988f,
0.56980413, 0.49671215, 0.44713274, 0.50312346, 0.58461165,
0.57026696, 0.49691492, 0.4471422 , 0.50293994, 0.58418214,
0.5706764 , 0.49709415, 0.44715008, 0.5027767 , 0.5838005 ,
0.571041 , 0.4972537 , 0.44715673, 0.50263065, 0.58345926,
0.57136786, 0.49739665, 0.44716236, 0.5024992 , 0.58315235,
0.56980413f, 0.49671215f, 0.44713274f, 0.50312346f, 0.58461165f,
0.57026696f, 0.49691492f, 0.4471422f, 0.50293994f, 0.58418214f,
0.5706764f, 0.49709415f, 0.44715008f, 0.5027767f, 0.5838005f,
0.571041f, 0.4972537f, 0.44715673f, 0.50263065f, 0.58345926f,
0.57136786f, 0.49739665f, 0.44716236f, 0.5024992f, 0.58315235f,
0.5716625 , 0.49752548, 0.4471672 , 0.5023803, 0.5828747 ,
0.5719295 , 0.49764213, 0.44717142, 0.5022721, 0.5826225 ,
0.57217246, 0.49774826, 0.44717506, 0.5021734, 0.58239233,
0.5723947 , 0.4978453 , 0.44717824, 0.5020829, 0.58218133,
0.57259864, 0.49793428, 0.44718108, 0.5019997, 0.5819874 ,
0.5716625f, 0.49752548f, 0.4471672f, 0.5023803f, 0.5828747f,
0.5719295f, 0.49764213f, 0.44717142f, 0.5022721f, 0.5826225f,
0.57217246f, 0.49774826f, 0.44717506f, 0.5021734f, 0.58239233f,
0.5723947f, 0.4978453f, 0.44717824f, 0.5020829f, 0.58218133f,
0.57259864f, 0.49793428f, 0.44718108f, 0.5019997f, 0.5819874f,
0.5727864 , 0.49801624, 0.44718358, 0.5019227, 0.5818083 ,
0.57296 , 0.49809194, 0.44718578, 0.5018515, 0.5816426 ,
0.5731208 , 0.49816203, 0.44718775, 0.5017854, 0.58148885,
0.57327026, 0.49822718, 0.4471895 , 0.5017239, 0.5813457 ,
0.57340944, 0.49828786, 0.44719115, 0.5016664, 0.581212 ,
0.5727864f, 0.49801624f, 0.44718358f, 0.5019227f, 0.5818083f,
0.57296f, 0.49809194f, 0.44718578f, 0.5018515f, 0.5816426f,
0.5731208f, 0.49816203f, 0.44718775f, 0.5017854f, 0.58148885f,
0.57327026f, 0.49822718f, 0.4471895f, 0.5017239f, 0.5813457f,
0.57340944f, 0.49828786f, 0.44719115f, 0.5016664f, 0.581212f,
0.57353944, 0.4983446 , 0.44719255, 0.50161266, 0.58108705,
0.5736612 , 0.49839762, 0.4471939 , 0.50156236, 0.5809699 ,
0.5737754 , 0.4984474 , 0.44719502, 0.501515 , 0.58085984,
0.5738828 , 0.49849418, 0.4471962 , 0.50147045, 0.5807564 ,
0.5739839 , 0.49853817, 0.44719717, 0.5014284 , 0.5806588 ,
0.57353944f, 0.4983446f, 0.44719255f, 0.50161266f, 0.58108705f,
0.5736612f, 0.49839762f, 0.4471939f, 0.50156236f, 0.5809699f,
0.5737754f, 0.4984474f, 0.44719502f, 0.501515f, 0.58085984f,
0.5738828f, 0.49849418f, 0.4471962f, 0.50147045f, 0.5807564f,
0.5739839f, 0.49853817f, 0.44719717f, 0.5014284f, 0.5806588f,
0.5740793 , 0.49857965, 0.4471981 , 0.5013887 , 0.5805666 ,
0.5741694 , 0.49861887, 0.44719887, 0.50135124, 0.58047944,
0.57425463, 0.49865603, 0.44719967, 0.5013157 , 0.5803969 ,
0.5743354 , 0.4986912 , 0.44720036, 0.5012819 , 0.5803186 ,
0.57441217, 0.49872455, 0.44720104, 0.5012499 , 0.58024424,
0.5740793f, 0.49857965f, 0.4471981f, 0.5013887f, 0.5805666f,
0.5741694f, 0.49861887f, 0.44719887f, 0.50135124f, 0.58047944f,
0.57425463f, 0.49865603f, 0.44719967f, 0.5013157f, 0.5803969f,
0.5743354f, 0.4986912f, 0.44720036f, 0.5012819f, 0.5803186f,
0.57441217f, 0.49872455f, 0.44720104f, 0.5012499f, 0.58024424f,
0.57448506, 0.4987563 , 0.44720164, 0.5012194 , 0.58017343,
0.57455444, 0.4987865 , 0.4472022 , 0.5011904 , 0.5801061,
0.57462054, 0.49881527, 0.44720277, 0.5011627 , 0.5800419,
0.57468355, 0.49884263, 0.44720328, 0.50113624, 0.5799805,
0.57474375, 0.49886885, 0.44720373, 0.50111103, 0.5799219 }
0.57448506f, 0.4987563f, 0.44720164f, 0.5012194f, 0.58017343f,
0.57455444f, 0.4987865f, 0.4472022f, 0.5011904f, 0.5801061f,
0.57462054f, 0.49881527f, 0.44720277f, 0.5011627f, 0.5800419f,
0.57468355f, 0.49884263f, 0.44720328f, 0.50113624f, 0.5799805f,
0.57474375f, 0.49886885f, 0.44720373f, 0.50111103f, 0.5799219f }
);
//
nd4j::ops::lrn op;

View File

@ -0,0 +1,68 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include "testlayers.h"
#include <array/ArrayOptions.h>
#include <AffinityManager.h>
#include <NDArray.h>
#include <NDArrayFactory.h>
#include <ops/declarable/headers/broadcastable.h>
#include <MmulHelper.h>
#include <thread>
using namespace nd4j;
class MultiDeviceTests : public testing::Test {
public:
};
void createArrays(int limit, std::vector<NDArray*> &arrays) {
auto deviceId = AffinityManager::currentDeviceId();
auto numDevices = AffinityManager::numberOfDevices();
for (int e = 0; e < limit; e++) {
auto value = deviceId * limit + e;
arrays[value] = NDArrayFactory::create_<float>('c', {10});
arrays[value]->assign(value);
//nd4j_printf("device_%i; value: [%i]; mean: [%f]\n", deviceId, value, arrays[value]->meanNumber().e<float>(0));
}
}
TEST_F(MultiDeviceTests, test_multi_device_migration_1) {
auto deviceId = AffinityManager::currentDeviceId();
auto numDevices = AffinityManager::numberOfDevices();
auto numArrays = 10;
std::vector<NDArray*> arrays(numDevices * numArrays);
// filling list of arrays on multiple threads
for (int e = 0; e < numDevices; e++) {
std::thread t1(createArrays, numArrays, std::ref(arrays));
t1.join();
}
// at this moment all arrays are build, so we can test migration
for (int e = 0; e < arrays.size(); e++) {
ASSERT_NEAR((float) e, arrays[e]->meanNumber().e<float>(0), 1e-5f);
delete arrays[e];
}
}

View File

@ -438,6 +438,7 @@ public class AtomicAllocator implements Allocator {
Long allocId = objectsTracker.getAndIncrement();
point.setObjectId(allocId);
point.setConstant(true);
point.setDeviceId(Nd4j.getAffinityManager().getDeviceForCurrentThread());
allocationsMap.put(allocId, point);

View File

@ -68,11 +68,7 @@ public class CudaAffinityManager extends BasicAffinityManager {
*/
@Override
public Integer getDeviceForCurrentThread() {
val id = NativeOpsHolder.getInstance().getDeviceNativeOps().getDevice();
if (!affinityMap.containsKey(Thread.currentThread().getId()))
affinityMap.put(Thread.currentThread().getId(), id);
return id;
return NativeOpsHolder.getInstance().getDeviceNativeOps().getDevice();
}
/**
@ -205,7 +201,6 @@ public class CudaAffinityManager extends BasicAffinityManager {
int currentDeviceId = getDeviceForCurrentThread();
if (currentDeviceId != deviceId.intValue()) {
Nd4j.getMemoryManager().releaseCurrentContext();
unsafeSetDevice(deviceId);
}
@ -215,7 +210,6 @@ public class CudaAffinityManager extends BasicAffinityManager {
INDArray result = Nd4j.createArrayFromShapeBuffer(newDataBuffer, newShapeBuffer);
if (currentDeviceId != deviceId.intValue()) {
Nd4j.getMemoryManager().releaseCurrentContext();
unsafeSetDevice(currentDeviceId);
}
@ -238,7 +232,6 @@ public class CudaAffinityManager extends BasicAffinityManager {
int currentDeviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
if (currentDeviceId != deviceId) {
Nd4j.getMemoryManager().releaseCurrentContext();
Nd4j.getAffinityManager().unsafeSetDevice(deviceId);
}
@ -246,7 +239,6 @@ public class CudaAffinityManager extends BasicAffinityManager {
AtomicAllocator.getInstance().memcpy(dstBuffer, buffer);
if (currentDeviceId != deviceId) {
Nd4j.getMemoryManager().releaseCurrentContext();
Nd4j.getAffinityManager().unsafeSetDevice(currentDeviceId);
}

View File

@ -188,13 +188,15 @@ public class SynchronousFlowController implements FlowController {
}
if (pointShape.getDeviceId() != cId && pointShape.getDeviceId() >= 0) {
((JCublasNDArray) result).setShapeInfoDataBuffer(
Nd4j.getConstantHandler().relocateConstantSpace(result.shapeInfoDataBuffer()));
((JCublasNDArray) result).setShapeInfoDataBuffer(Nd4j.getExecutioner().createShapeInfo(result.shape(), result.stride(), result.elementWiseStride(), result.ordering(), result.dataType(), result.isEmpty()));
}
allocator.getAllocationPoint(result).setCurrentContext(context);
}
if (operands == null)
return context;
for (INDArray operand : operands) {
if (operand == null || operand.isEmpty())
continue;
@ -213,8 +215,7 @@ public class SynchronousFlowController implements FlowController {
}
if (pointShape.getDeviceId() != cId && pointShape.getDeviceId() >= 0) {
((JCublasNDArray) operand).setShapeInfoDataBuffer(
Nd4j.getConstantHandler().relocateConstantSpace(operand.shapeInfoDataBuffer()));
((JCublasNDArray) operand).setShapeInfoDataBuffer(Nd4j.getExecutioner().createShapeInfo(operand.shape(), operand.stride(), operand.elementWiseStride(), operand.ordering(), operand.dataType(), operand.isEmpty()));
}
prepareDelayedMemory(operand);

View File

@ -819,10 +819,10 @@ public class CudaZeroHandler implements MemoryHandler {
val ohPtr = dstPoint.getHostPointer();
// FIXME: cross-thread access, might cause problems
if (!dstPoint.isActualOnHostSide())
if (dstPoint.getHostPointer() != null && !dstPoint.isActualOnHostSide())
AtomicAllocator.getInstance().synchronizeHostData(buffer);
if (!dstPoint.isActualOnHostSide())
if (dstPoint.getHostPointer() != null && !dstPoint.isActualOnHostSide())
throw new RuntimeException("Buffer synchronization failed");
if (buffer.isAttached() || dstPoint.isAttached()) {
@ -832,10 +832,14 @@ public class CudaZeroHandler implements MemoryHandler {
if (workspace == null) {
// if we're out of workspace, we should mark our buffer as detached, so gc will pick it up eventually
val pairH = alloc(AllocationStatus.HOST, dstPoint, dstPoint.getShape(), false);
// host part is optional
if (dstPoint.getHostPointer() != null) {
val pairH = alloc(AllocationStatus.HOST, dstPoint, dstPoint.getShape(), false);
dstPoint.getPointers().setHostPointer(pairH.getHostPointer());
}
val pairD = alloc(AllocationStatus.DEVICE, dstPoint, dstPoint.getShape(), false);
dstPoint.getPointers().setDevicePointer(pairD.getDevicePointer());
dstPoint.getPointers().setHostPointer(pairH.getHostPointer());
//log.info("New host pointer: {}; Old host pointer: {}", dstPoint.getHostPointer().address(), ohPtr.address());
@ -869,7 +873,11 @@ public class CudaZeroHandler implements MemoryHandler {
Nd4j.getMemoryManager().memcpy(nBuffer, buffer);
dstPoint.getPointers().setDevicePointer(nBuffer.getAllocationPoint().getDevicePointer());
dstPoint.getPointers().setHostPointer(nBuffer.getAllocationPoint().getHostPointer());
if (dstPoint.getHostPointer() != null) {
dstPoint.getPointers().setHostPointer(nBuffer.getAllocationPoint().getHostPointer());
}
dstPoint.setDeviceId(deviceId);
dstPoint.tickDeviceRead();
@ -885,6 +893,17 @@ public class CudaZeroHandler implements MemoryHandler {
throw new RuntimeException("Can't relocateObject() for constant buffer");
} else {
// log.info("Free relocateObject: deviceId: {}, pointer: {}", deviceId, dstPoint.getPointers().getDevicePointer().address());
val context = getCudaContext();
if (dstPoint.getHostPointer() == null) {
((BaseCudaDataBuffer) buffer).lazyAllocateHostPointer();
if (nativeOps.memcpyAsync(dstPoint.getHostPointer(), dstPoint.getDevicePointer(),
buffer.length() * buffer.getElementSize(), 2, context.getSpecialStream()) == 0)
throw new ND4JIllegalStateException("memcpyAsync failed");
context.syncSpecialStream();
}
memoryProvider.free(dstPoint);
deviceMemoryTracker.subFromAllocation(Thread.currentThread().getId(), dstPoint.getDeviceId(), AllocationUtils.getRequiredMemory(dstPoint.getShape()));
@ -893,7 +912,6 @@ public class CudaZeroHandler implements MemoryHandler {
val profD = PerformanceTracker.getInstance().helperStartTransaction();
CudaContext context = getCudaContext();
if (nativeOps.memcpyAsync(dstPoint.getDevicePointer(), dstPoint.getHostPointer(),
buffer.length() * buffer.getElementSize(), 1, context.getSpecialStream()) == 0)
throw new ND4JIllegalStateException("memcpyAsync failed");

View File

@ -17,6 +17,7 @@
package org.nd4j.linalg.jcublas;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.enums.CudaConstants;
@ -54,7 +55,7 @@ import java.util.concurrent.atomic.AtomicLong;
* @author Adam Gibson
* @author raver119@gmail.com
*/
@Slf4j
public class JCublasNDArray extends BaseNDArray {
@ -737,14 +738,12 @@ public class JCublasNDArray extends BaseNDArray {
if (!this.isView()) {
Nd4j.getExecutioner().commit();
DataBuffer buffer = Nd4j.createBuffer(this.lengthLong(), false);
val buffer = Nd4j.createBuffer(this.dataType(), this.lengthLong(), false);
AllocationPoint pointDst = AtomicAllocator.getInstance().getAllocationPoint(buffer);
AllocationPoint pointSrc = AtomicAllocator.getInstance().getAllocationPoint(this.data);
val pointDst = AtomicAllocator.getInstance().getAllocationPoint(buffer);
val pointSrc = AtomicAllocator.getInstance().getAllocationPoint(this.data);
// CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();
CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(pointDst, pointSrc);
val context = AtomicAllocator.getInstance().getFlowController().prepareAction(pointDst, pointSrc);
MemcpyDirection direction = MemcpyDirection.DEVICE_TO_DEVICE;
val perfD = PerformanceTracker.getInstance().helperStartTransaction();
@ -764,14 +763,12 @@ public class JCublasNDArray extends BaseNDArray {
PerformanceTracker.getInstance().helperRegisterTransaction(pointDst.getDeviceId(), perfD, pointDst.getNumberOfBytes(), direction);
if (pointDst.getDeviceId() != Nd4j.getMemoryManager().getCurrentWorkspace().getDeviceId()) {
//log.info("Swapping [{}] -> [{}]", pointDst.getDeviceId(), Nd4j.getMemoryManager().getCurrentWorkspace().getDeviceId());
pointDst.setDeviceId(Nd4j.getMemoryManager().getCurrentWorkspace().getDeviceId());
}
copy = Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer());
// tag buffer as valid on device side
pointDst.tickHostRead();
pointDst.tickDeviceWrite();
AtomicAllocator.getInstance().getFlowController().registerAction(context, pointDst, pointSrc);

View File

@ -25,6 +25,7 @@ import org.bytedeco.javacpp.indexer.*;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.LongBuffer;
import org.nd4j.linalg.api.buffer.Utf8Buffer;
import org.nd4j.linalg.api.buffer.factory.DataBufferFactory;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
@ -92,6 +93,8 @@ public class CudaDataBufferFactory implements DataBufferFactory {
return new CudaByteDataBuffer(underlyingBuffer, length, offset);
case BOOL:
return new CudaBoolDataBuffer(underlyingBuffer, length, offset);
case UTF8:
return new Utf8Buffer(underlyingBuffer, length, offset);
default:
throw new ND4JIllegalStateException("Unknown data buffer type: " + underlyingBuffer.dataType().toString());
}

View File

@ -233,7 +233,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(dimension)
+ " contains element that higher then rank of op.X: [" + op.x().rank() + "]");
CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y());
val context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y());
if (CudaEnvironment.getInstance().getConfiguration().isDebug())
lastOp.set(op.opName());
@ -2491,6 +2491,11 @@ public class CudaExecutioner extends DefaultOpExecutioner {
nativeOps.execCustomOp2(null, op.opHash(), context.contextPointer());
for (val arr:op.outputArguments())
AtomicAllocator.getInstance().registerAction(ctx, arr);
AtomicAllocator.getInstance().registerAction(ctx, null, op.inputArguments());
profilingConfigurableHookOut(op, st);
if (context.getOutputArrays().isEmpty())

View File

@ -81,10 +81,8 @@ public class CudaOpContext extends BaseOpContext implements OpContext {
@Override
public void setInputArray(int index, @NonNull INDArray array) {
// FIXME: remove
Nd4j.getAffinityManager().ensureLocation(array, AffinityManager.Location.EVERYWHERE);
val ctx = AtomicAllocator.getInstance().getFlowController().prepareAction(null, array);
val ctx = AtomicAllocator.getInstance().getDeviceContext();
nativeOps.setGraphContextInputArray(context, index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), array.isEmpty() ? null : AtomicAllocator.getInstance().getPointer(array, ctx), AtomicAllocator.getInstance().getPointer(array.shapeInfoDataBuffer()));
super.setInputArray(index, array);
@ -92,9 +90,8 @@ public class CudaOpContext extends BaseOpContext implements OpContext {
@Override
public void setOutputArray(int index, @NonNull INDArray array) {
Nd4j.getAffinityManager().ensureLocation(array, AffinityManager.Location.EVERYWHERE);
val ctx = AtomicAllocator.getInstance().getFlowController().prepareAction(array, null);
val ctx = AtomicAllocator.getInstance().getDeviceContext();
nativeOps.setGraphContextOutputArray(context, index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), array.isEmpty() ? null : AtomicAllocator.getInstance().getPointer(array, ctx), AtomicAllocator.getInstance().getPointer(array.shapeInfoDataBuffer()));
super.setOutputArray(index, array);

View File

@ -9898,6 +9898,72 @@ public static final int PREALLOC_SIZE = 33554432;
// #endif //LIBND4J_OPREGISTRATOR_H
// Parsed from execution/ContextBuffers.h
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
// #ifndef LIBND4J_CONTEXTBUFFERS_H
// #define LIBND4J_CONTEXTBUFFERS_H
// #include <dll.h>
// #include <pointercast.h>
@Namespace("nd4j") @NoOffset public static class ContextBuffers extends Pointer {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public ContextBuffers(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public ContextBuffers(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
@Override public ContextBuffers position(long position) {
return (ContextBuffers)super.position(position);
}
public ContextBuffers() { super((Pointer)null); allocate(); }
private native void allocate();
public ContextBuffers(Pointer rPointer, Pointer sPointer, Pointer aPointer, @Cast("bool") boolean isOwner/*=false*/) { super((Pointer)null); allocate(rPointer, sPointer, aPointer, isOwner); }
private native void allocate(Pointer rPointer, Pointer sPointer, Pointer aPointer, @Cast("bool") boolean isOwner/*=false*/);
public ContextBuffers(Pointer rPointer, Pointer sPointer, Pointer aPointer) { super((Pointer)null); allocate(rPointer, sPointer, aPointer); }
private native void allocate(Pointer rPointer, Pointer sPointer, Pointer aPointer);
public native Pointer reductionBuffer();
public native Pointer scalarBuffer();
public native Pointer allocationBuffer();
public native Pointer execStream();
public native Pointer specialStream();
public native void setReductionBuffer(Pointer pointer);
public native void setScalarBuffer(Pointer pointer);
public native void setAllocationBuffer(Pointer pointer);
public native void triggerOwnership(@Cast("bool") boolean isOwner);
public native int deviceId();
}
// #endif //DEV_TESTS_CONTEXTBUFFERS_H
// Parsed from execution/LaunchContext.h
/*******************************************************************************
@ -9971,6 +10037,9 @@ public static final int PREALLOC_SIZE = 33554432;
public static native LaunchContext defaultContext();
public static native void swapContextBuffers(@ByRef ContextBuffers buffers);
}

View File

@ -76,6 +76,7 @@ import org.bytedeco.javacpp.tools.InfoMapper;
"ops/declarable/BooleanOp.h",
"ops/declarable/LogicOp.h",
"ops/declarable/OpRegistrator.h",
"execution/ContextBuffers.h",
"execution/LaunchContext.h",
"array/ShapeDescriptor.h",
"array/TadDescriptor.h",

View File

@ -22808,6 +22808,72 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
// #endif
// Parsed from execution/ContextBuffers.h
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
// #ifndef LIBND4J_CONTEXTBUFFERS_H
// #define LIBND4J_CONTEXTBUFFERS_H
// #include <dll.h>
// #include <pointercast.h>
@Namespace("nd4j") @NoOffset public static class ContextBuffers extends Pointer {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public ContextBuffers(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public ContextBuffers(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
@Override public ContextBuffers position(long position) {
return (ContextBuffers)super.position(position);
}
public ContextBuffers() { super((Pointer)null); allocate(); }
private native void allocate();
public ContextBuffers(Pointer rPointer, Pointer sPointer, Pointer aPointer, @Cast("bool") boolean isOwner/*=false*/) { super((Pointer)null); allocate(rPointer, sPointer, aPointer, isOwner); }
private native void allocate(Pointer rPointer, Pointer sPointer, Pointer aPointer, @Cast("bool") boolean isOwner/*=false*/);
public ContextBuffers(Pointer rPointer, Pointer sPointer, Pointer aPointer) { super((Pointer)null); allocate(rPointer, sPointer, aPointer); }
private native void allocate(Pointer rPointer, Pointer sPointer, Pointer aPointer);
public native Pointer reductionBuffer();
public native Pointer scalarBuffer();
public native Pointer allocationBuffer();
public native Pointer execStream();
public native Pointer specialStream();
public native void setReductionBuffer(Pointer pointer);
public native void setScalarBuffer(Pointer pointer);
public native void setAllocationBuffer(Pointer pointer);
public native void triggerOwnership(@Cast("bool") boolean isOwner);
public native int deviceId();
}
// #endif //DEV_TESTS_CONTEXTBUFFERS_H
// Parsed from execution/LaunchContext.h
/*******************************************************************************
@ -22872,6 +22938,9 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
public static native LaunchContext defaultContext();
public static native void swapContextBuffers(@ByRef ContextBuffers buffers);
}

View File

@ -100,6 +100,7 @@ import java.util.Scanner;
"ops/declarable/headers/bitwise.h",
"ops/declarable/headers/loss.h",
"ops/declarable/headers/datatypes.h",
"execution/ContextBuffers.h",
"execution/LaunchContext.h",
"array/ShapeDescriptor.h",
"array/TadDescriptor.h",

View File

@ -198,27 +198,43 @@ public class SpecialTests extends BaseNd4jTest {
val list = new CopyOnWriteArrayList<INDArray>();
val threads = new ArrayList<Thread>();
for (int e = 0; e< Nd4j.getAffinityManager().getNumberOfDevices(); e++) {
val devices = Nd4j.getAffinityManager().getNumberOfDevices();
for (int e = 0; e < devices; e++) {
val f = e;
val t = new Thread(new Runnable() {
@Override
public void run() {
val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
log.info("Current device: {}", deviceId);
for (int i = 0; i < 10; i++) {
list.add(Nd4j.create(100, 100).assign(1.0f));
val ar = Nd4j.create(100, 100).assign(1.0f);
assertEquals(deviceId, Nd4j.getAffinityManager().getDeviceForArray(ar));
list.add(ar);
Nd4j.getExecutioner().commit();
}
}
});
t.start();
t.join();
threads.add(t);
log.info("------------------------");
}
for (val t:threads)
t.join();
for (val a:list)
assertEquals(1.0f, a.meanNumber().floatValue(), 1e-5);
for (val a:list) {
val device = Nd4j.getAffinityManager().getDeviceForArray(a);
try {
assertEquals(1.0f, a.meanNumber().floatValue(), 1e-5);
} catch (Exception e) {
log.error("Failed for array from device [{}]", device);
throw e;
}
}
}
@Test