[WIP] cross-device migrations (#134)
* two more tests fixed
* CUDA device affinity tweaks
* minor tweaks
* prepareAction/registerAction for CustomOps
* lazy allocate host buffer before relocation
* one special test for migration in cpp
* tests update for msvc
* logging
* stick to old col2im impl
* cudaStreams reorganization
* buffer size fix
* c++ data migration
* fix CropAndResize test
* minor improvement

Signed-off-by: raver119 <raver119@gmail.com>
Signed-off-by: Yurii <yurii@skymind.io>
parent 23c8738d4a
commit 269d508ba5
@@ -39,6 +39,7 @@
#include <ShapeDescriptor.h>
#include <helpers/ConstantShapeHelper.h>
#include <array/DataBuffer.h>
#include <AffinityManager.h>

namespace nd4j {
@@ -143,6 +144,11 @@ namespace nd4j {
    */
    nd4j::DataType _dataType = FLOAT32;

    /**
     * deviceID where this NDArray belongs to
     */
    int _deviceId = AffinityManager::currentDeviceId();

    template<typename T>
    std::string toStringValue(T value);
@@ -55,7 +55,19 @@ void* NDArray::getPlatformBuffer() const { return getSpecialBuffer(); }
Nd4jLong* NDArray::getPlatformShapeInfo() const { return getSpecialShapeInfo(); }
Nd4jLong* NDArray::platformShapeInfo() { return specialShapeInfo(); }

void NDArray::syncToDevice() const { _buffer->syncToSpecial(); }
void NDArray::syncToDevice() const {
    auto currentDeviceId = AffinityManager::currentDeviceId();
    if (currentDeviceId != _deviceId) {
        // first of all we update shapeInfo
        const_cast<NDArray*>(this)->setShapeInfo(this->getShapeInfo());

        // now we actually migrate data buffer
        _buffer->migrate();
    }

    _buffer->syncToSpecial();
}

void NDArray::syncToHost() const { _buffer->syncToPrimary(getContext()); }
void NDArray::tickWriteHost() const { _buffer->writePrimary(); }
void NDArray::tickWriteDevice() const { _buffer->writeSpecial(); }
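For readers following the migration flow: the reworked syncToDevice() checks whether the calling thread's current device matches the array's _deviceId and migrates the buffer first if it does not. A minimal standalone sketch of that migrate-on-sync pattern, with a hypothetical `Buffer` type standing in for the real DataBuffer API:

```cpp
#include <cstdio>

// Hypothetical stand-in for DataBuffer; only the parts relevant to the
// migrate-on-sync pattern are modeled here.
struct Buffer {
    int deviceId = 0;

    void migrateTo(int device) {
        // the real implementation copies device memory and releases the old block
        std::printf("migrating buffer: device %d -> device %d\n", deviceId, device);
        deviceId = device;
    }

    void syncToSpecial() {
        std::printf("syncing host buffer to device %d\n", deviceId);
    }
};

// mirrors the logic of NDArray::syncToDevice() above
void syncToDevice(Buffer &buf, int currentDeviceId) {
    if (currentDeviceId != buf.deviceId)  // thread is bound to a different GPU
        buf.migrateTo(currentDeviceId);   // relocate the data before touching it

    buf.syncToSpecial();
}

int main() {
    Buffer b;
    syncToDevice(b, 1);  // triggers migration, then the usual sync
    return 0;
}
```

The real method also re-applies the shapeInfo before migrating, presumably so the shape buffer is resolved against the new device's constant caches (see the ConstantShapeHelper/ConstantTadHelper hunks below).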
@@ -31,6 +31,7 @@
#include <AffinityManager.h>

#include <exceptions/datatype_exception.h>
#include <exceptions/cuda_exception.h>
#include <helpers/CudaLaunchHelper.h>
// FIXME: we need cuda-specific implementations
#include <GraphExecutioner.h>
@@ -965,11 +966,7 @@ int registerEvent(Nd4jPointer event, Nd4jPointer stream) {
}

int setDevice(int deviceId) {
    auto dZ = cudaSetDevice(deviceId);
    checkCudaErrors(dZ);
    if (dZ != 0)
        throw std::runtime_error("cudaSetDevice(...) failed");

    AffinityManager::setCurrentDevice(deviceId);
    return 1;
}
@@ -1024,16 +1021,15 @@ Nd4jLong getDeviceTotalMemory(int device) {
}

int memcpySync(Nd4jPointer dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4jPointer reserved) {
    return memcpyAsync(dst, src, size, flags, reserved);
}

int memcpyAsync(Nd4jPointer dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4jPointer reserved) {
    cudaStream_t *pStream = reinterpret_cast<cudaStream_t *>(reserved);
    auto pStream = reinterpret_cast<cudaStream_t *>(reserved);

    cudaMemcpyKind kind;

    DEBUG_KERNEL(pStream, 0);
    //nd4j::DebugHelper::checkErrorCode(pStream, "Preliminary sync failed");

    switch (flags) {
        case 0: {
@@ -1047,25 +1043,22 @@ int memcpyAsync(Nd4jPointer dst, Nd4jPointer src, Nd4jLong size, int flags, Nd4j
        case 2: {
            kind = cudaMemcpyDeviceToHost;
        }
        break;
        case 3: {
            kind = cudaMemcpyDeviceToDevice;
        }
            kind = cudaMemcpyDeviceToDevice;
        }
        break;
        default: {
            printf("UNDEFINED MEMCPY!\n");
            break;
        }
        default:
            throw nd4j::cuda_exception::build("UNDEFINED MEMCPY!\n", 119);
    }

    cudaError_t dZ = cudaMemcpyAsync(reinterpret_cast<void *>(dst), const_cast<const void *>(reinterpret_cast<void *>(src)), static_cast<size_t>(size), kind, *pStream);
    auto dZ = cudaMemcpyAsync(reinterpret_cast<void *>(dst), const_cast<const void *>(reinterpret_cast<void *>(src)), static_cast<size_t>(size), kind, *pStream);
    //auto dZ = cudaMemcpy(reinterpret_cast<void *>(dst), const_cast<const void *>(reinterpret_cast<void *>(src)), static_cast<size_t>(size), kind);
    if (dZ != 0) {
        checkCudaErrors(dZ);
        printf("Failed on [%lu] -> [%lu], size: [%i], direction: [%i], dZ: [%i]\n", src, dst, size, flags, static_cast<int>(dZ));
        printf("Failed on [%lu] -> [%lu], size: [%i], direction: [%i], dZ: [%i]\n", src, dst, size, flags, static_cast<int>(dZ));
        fflush(stdout);
        fflush(stderr);
        throw std::runtime_error("cudaMemcpyAsync(...) failed");
        //return 0L;
        throw nd4j::cuda_exception::build("cudaMemcpyAsync(...) failed", dZ);
    }

    return 1;
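For reference, the integer `flags` argument selects the cudaMemcpyKind. Directions 2 (device-to-host) and 3 (device-to-device) are visible in this hunk; cases 0 and 1 below are assumptions inferred from how the Java side calls memcpyAsync later in this diff (1 for host-to-device, 2 for device-to-host). A hedged sketch of the mapping as a standalone helper:

```cpp
#include <cuda_runtime.h>
#include <stdexcept>

// Illustrative helper, not part of the library: translate the integer
// direction convention used by memcpyAsync() into a cudaMemcpyKind.
static cudaMemcpyKind toMemcpyKind(int flags) {
    switch (flags) {
        case 0: return cudaMemcpyHostToHost;      // assumed
        case 1: return cudaMemcpyHostToDevice;    // inferred from callers in this diff
        case 2: return cudaMemcpyDeviceToHost;    // shown in this hunk
        case 3: return cudaMemcpyDeviceToDevice;  // shown in this hunk
        default: throw std::runtime_error("undefined memcpy direction");
    }
}
```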
@@ -3256,7 +3249,7 @@ void tryPointer(Nd4jPointer extra, Nd4jPointer p, int len) {
    auto e = cudaStreamSynchronize(stream);

    if (e != 0)
        throw std::runtime_error("tryPointer failed");
        throw nd4j::cuda_exception::build("tryPointer failed", e);

    cudaStreamDestroy(stream);
}
@@ -105,6 +105,8 @@ class ND4J_EXPORT DataBuffer {
    bool isPrimaryActual() const;
    bool isSpecialActual() const;

    void migrate();

    template <typename T> FORCEINLINE T* primaryAsT();
    template <typename T> FORCEINLINE T* specialAsT();
@@ -96,6 +96,11 @@ void DataBuffer::allocateSpecial() {

}

////////////////////////////////////////////////////////////////////////
void DataBuffer::migrate() {

}

////////////////////////////////////////////////////////////////////////
void DataBuffer::writePrimary() const { }
void DataBuffer::writeSpecial() const { }
@@ -173,6 +173,22 @@ void DataBuffer::setToZeroBuffers(const bool both) {
    }
}

////////////////////////////////////////////////////////////////////////
void DataBuffer::migrate() {
    memory::Workspace* newWorkspace = nullptr;
    void* newBuffer;
    ALLOCATE_SPECIAL(newBuffer, newWorkspace, getLenInBytes(), int8_t);
    cudaMemcpy(newBuffer, _specialBuffer, getLenInBytes(), cudaMemcpyDeviceToDevice);

    if (_isOwnerSpecial) {
        // now we're releasing original buffer
        RELEASE_SPECIAL(_specialBuffer, _workspace);
    }

    _isOwnerSpecial = true;
    _specialBuffer = newBuffer;
}

////////////////////////////////////////////////////////////////////////
void DataBuffer::writePrimary() const { _writePrimary = ++_counter; }
void DataBuffer::writeSpecial() const { _writeSpecial = ++_counter; }
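In plain CUDA terms, migrate() allocates a fresh special buffer on whichever device the thread is currently bound to, copies the old contents across, and releases the original if owned. A rough equivalent without the ALLOCATE_SPECIAL/RELEASE_SPECIAL machinery; error handling and workspace bookkeeping are simplified:

```cpp
#include <cuda_runtime.h>
#include <stdexcept>

// Sketch of the migration step with plain CUDA calls, under the assumption
// that the calling thread is already bound (via cudaSetDevice) to the target device.
void* migrateBuffer(void* oldBuffer, size_t lenInBytes) {
    void* newBuffer = nullptr;
    if (cudaMalloc(&newBuffer, lenInBytes) != cudaSuccess)
        throw std::runtime_error("allocation on target device failed");

    // with unified addressing, a DeviceToDevice copy also covers peer copies
    // between two different GPUs
    if (cudaMemcpy(newBuffer, oldBuffer, lenInBytes, cudaMemcpyDeviceToDevice) != cudaSuccess)
        throw std::runtime_error("cross-device copy failed");

    cudaFree(oldBuffer);
    return newBuffer;
}
```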
@@ -27,10 +27,12 @@
namespace nd4j {
    class ND4J_EXPORT ContextBuffers {
    private:
        void* _reductionPointer;
        void* _scalarPointer;
        void* _allocationPointer;
        bool _allocated = true;
        void* _reductionPointer = nullptr;
        void* _scalarPointer = nullptr;
        void* _allocationPointer = nullptr;
        void* _execStream = nullptr;
        void* _specialStream = nullptr;
        bool _allocated = false;

        int _deviceId = -1;
@@ -44,6 +46,9 @@ namespace nd4j {
        void* scalarBuffer();
        void* allocationBuffer();

        void* execStream();
        void* specialStream();

        void setReductionBuffer(void* pointer);
        void setScalarBuffer(void* pointer);
        void setAllocationBuffer(void* pointer);
@@ -52,8 +52,6 @@ class ND4J_EXPORT LaunchContext {

#ifndef __JAVACPP_HACK__

    cudaStream_t* _cudaStream = nullptr;
    cudaStream_t* _cudaSpecialStream = nullptr;
    void* _cublasHandle = nullptr;
    void* _cusolverHandle = nullptr;
@@ -102,6 +100,9 @@ class ND4J_EXPORT LaunchContext {

    static LaunchContext* defaultContext();

    static void swapContextBuffers(ContextBuffers &buffers);

};

}
@@ -71,4 +71,12 @@ namespace nd4j {
    int ContextBuffers::deviceId() {
        return _deviceId;
    }
}

    void* ContextBuffers::execStream() {
        return _execStream;
    }

    void* ContextBuffers::specialStream() {
        return _specialStream;
    }
}
@@ -53,4 +53,8 @@ namespace nd4j {
        // return context for current device
        return LaunchContext::_contexts[0].get();
    }

    void LaunchContext::swapContextBuffers(ContextBuffers &buffers) {
        //
    }
}
@@ -21,6 +21,7 @@
#include <logger.h>
#include <execution/AffinityManager.h>
#include <exceptions/cuda_exception.h>
#include <LaunchContext.h>

thread_local int globalThreadToDevice = -1;
@@ -98,10 +99,13 @@ namespace nd4j {
        if (res != 0)
            throw cuda_exception::build("cudaSetDevice failed", res);

        auto previousDeviceId = globalThreadToDevice;

        // update thread-device affinity
        globalThreadToDevice = deviceId;

        // TODO: update context buffers?
        ContextBuffers newBuffers;
        LaunchContext::swapContextBuffers(newBuffers);
    }

    std::atomic<int> AffinityManager::_lastDevice;// = std::atomic<int>(initialV);
@@ -19,6 +19,7 @@
//

#include <execution/ContextBuffers.h>
#include <exceptions/cuda_exception.h>
#include <logger.h>
#include <AffinityManager.h>
@@ -45,6 +46,18 @@ namespace nd4j {

        if (_allocationPointer != nullptr)
            cudaFree(_reductionPointer);

        auto _cudaStream = reinterpret_cast<cudaStream_t*>(_execStream);
        auto _cudaSpecialStream = reinterpret_cast<cudaStream_t*>(_specialStream);

        cudaStreamSynchronize(*_cudaStream);
        cudaStreamSynchronize(*_cudaSpecialStream);

        cudaStreamDestroy(*_cudaStream);
        cudaStreamDestroy(*_cudaSpecialStream);

        delete _cudaStream;
        delete _cudaSpecialStream;
    }
}
@@ -70,6 +83,19 @@ namespace nd4j {
        if (res != 0)
            throw std::runtime_error("_allocationPointer allocation failed");

        _execStream = new cudaStream_t();
        _specialStream = new cudaStream_t();
        if (nullptr == _execStream || nullptr == _specialStream)
            throw std::runtime_error("Failed to allocate memory for new CUDA stream");

        res = cudaStreamCreate(reinterpret_cast<cudaStream_t*>(_execStream));
        if (res != 0)
            throw cuda_exception::build("Failed to create default CUDA stream with launch context", res);

        res = cudaStreamCreate(reinterpret_cast<cudaStream_t*>(_specialStream));
        if (res != 0)
            throw cuda_exception::build("Failed to create special CUDA stream with launch context", res);

        _allocated = true;
    }
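The streams now live in per-thread ContextBuffers and are created lazily, so each thread gets streams on whichever device it is currently bound to. A compact sketch of that lazy-init idiom; the class and member names here are illustrative, not the real API:

```cpp
#include <cuda_runtime.h>
#include <stdexcept>

// Sketch of the lazy, per-thread stream initialization used by ContextBuffers:
// the stream is created on first use, on the calling thread's current device.
class ThreadStreams {
    void* _execStream = nullptr;  // stored as void* to keep CUDA types out of headers

public:
    void* execStream() {
        if (_execStream == nullptr) {
            auto stream = new cudaStream_t();
            if (cudaStreamCreate(stream) != cudaSuccess)
                throw std::runtime_error("failed to create CUDA stream");
            _execStream = stream;
        }
        return _execStream;
    }

    ~ThreadStreams() {
        if (_execStream != nullptr) {
            auto stream = reinterpret_cast<cudaStream_t*>(_execStream);
            cudaStreamSynchronize(*stream);
            cudaStreamDestroy(*stream);
            delete stream;
        }
    }
};

// one instance per thread, mirroring the library's thread-local ContextBuffers
thread_local ThreadStreams contextStreams;
```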
@@ -113,4 +139,18 @@ namespace nd4j {
    int ContextBuffers::deviceId() {
        return _deviceId;
    }

    void* ContextBuffers::execStream() {
        if (_execStream == nullptr)
            initialize();

        return _execStream;
    }

    void* ContextBuffers::specialStream() {
        if (_specialStream == nullptr)
            initialize();

        return _specialStream;
    }
}
@@ -35,8 +35,8 @@ namespace nd4j {
////////////////////////////////////////////////////////////////////////
LaunchContext::LaunchContext(cudaStream_t *cudaStream, cudaStream_t& specialCudaStream, void* reductionPointer, void* scalarPointer, int* allocationPointer) {

    _cudaStream = cudaStream;
    _cudaSpecialStream = &specialCudaStream; // ideal is = new cudaStream_t; *_cudaSpecialStream = specialCudaStream;
    //_cudaStream = cudaStream;
    //_cudaSpecialStream = &specialCudaStream; // ideal is = new cudaStream_t; *_cudaSpecialStream = specialCudaStream;
    //_reductionPointer = reductionPointer;
    //_scalarPointer = scalarPointer;
    //_allocationPointer = allocationPointer;
@@ -46,14 +46,7 @@ LaunchContext::LaunchContext(cudaStream_t *cudaStream, cudaStream_t& specialCuda

LaunchContext::~LaunchContext() {
    if (_isAllocated) {
        cudaStreamSynchronize(*_cudaStream);
        cudaStreamSynchronize(*_cudaSpecialStream);

        cudaStreamDestroy(*_cudaStream);
        cudaStreamDestroy(*_cudaSpecialStream);

        delete _cudaStream;
        delete _cudaSpecialStream;
    }
}
@@ -64,32 +57,16 @@ LaunchContext::LaunchContext() {
    _deviceID = 0;

    _isAllocated = true;
    _cudaStream = new cudaStream_t();
    _cudaSpecialStream = new cudaStream_t();
    if (nullptr == _cudaStream || nullptr == _cudaSpecialStream)
        throw std::runtime_error("Failed to allocate memory for new CUDA stream");

    cudaError_t err = cudaStreamCreate(_cudaStream);
    if (err != 0)
        throw cuda_exception::build("Failed to create default CUDA stream with launch context", err);

    err = cudaStreamCreate(_cudaSpecialStream);
    if (err != 0)
        throw cuda_exception::build("Failed to create special CUDA stream with launch context", err);

    _cublasHandle = CublasHelper::getInstance()->handle();

    _cusolverHandle = CublasHelper::getInstance()->solver();

    auto res = cudaStreamSynchronize(*_cudaStream);
    if (res != 0)
        throw cuda_exception::build("Initial sync failed", res);
}

LaunchContext::LaunchContext(Nd4jPointer cudaStream, Nd4jPointer reductionPointer, Nd4jPointer scalarPointer, Nd4jPointer allocationPointer) {
    _isAllocated = false;
    _cudaStream = reinterpret_cast<cudaStream_t*>(cudaStream);
    _cudaSpecialStream = reinterpret_cast<cudaStream_t*>(cudaStream);
    //_cudaStream = reinterpret_cast<cudaStream_t*>(cudaStream);
    //_cudaSpecialStream = reinterpret_cast<cudaStream_t*>(cudaStream);
    //_reductionPointer = reductionPointer;
    //_scalarPointer = scalarPointer;
    //_allocationPointer = reinterpret_cast<int *>(allocationPointer);
@@ -148,11 +125,11 @@ LaunchContext::LaunchContext() {
};

cudaStream_t* LaunchContext::getCudaStream() const {
    return _cudaStream;
    return reinterpret_cast<cudaStream_t*>(contextBuffers.execStream());
};

cudaStream_t* LaunchContext::getCudaSpecialStream() const {
    return _cudaSpecialStream;
    return reinterpret_cast<cudaStream_t*>(contextBuffers.specialStream());
};
@@ -169,14 +146,18 @@ LaunchContext::LaunchContext() {
};

void LaunchContext::setCudaStream(cudaStream_t* cudaStream) {
    _cudaStream = cudaStream;
    //_cudaStream = cudaStream;
};

void LaunchContext::setCudaSpecialStream(cudaStream_t* cudaStream) {
    _cudaSpecialStream = cudaStream;
    //_cudaSpecialStream = cudaStream;
};

void LaunchContext::setCublasHandle(void *handle) {
    _cublasHandle = handle;
};

void LaunchContext::swapContextBuffers(ContextBuffers &buffers) {
    contextBuffers = buffers;
};
}
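Putting the pieces together: getCudaStream()/getCudaSpecialStream() no longer return per-context members but read from a thread_local ContextBuffers instance, and AffinityManager::setCurrentDevice() swaps that instance when a thread changes device. A condensed sketch of the flow, with simplified stand-in types:

```cpp
#include <cuda_runtime.h>

// Simplified stand-in for the thread_local ContextBuffers that
// LaunchContext::getCudaStream() reads from.
struct ContextBuffersSketch {
    int deviceId = -1;
    void* execStream = nullptr;  // created lazily on first use
};

thread_local ContextBuffersSketch contextBuffers;

void setCurrentDeviceSketch(int deviceId) {
    cudaSetDevice(deviceId);

    // a fresh, empty ContextBuffers replaces the old one; streams for the
    // new device are then created lazily on first use. The real code also
    // tears down the previous buffers' streams in their destructor.
    ContextBuffersSketch newBuffers;
    newBuffers.deviceId = deviceId;
    contextBuffers = newBuffers;
}
```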
@@ -22,12 +22,13 @@
#include <exceptions/cuda_exception.h>
#include <ShapeDescriptor.h>
#include <ShapeBuilders.h>
#include <AffinityManager.h>
#include <ConstantHelper.h>

namespace nd4j {

    ConstantShapeHelper::ConstantShapeHelper() {
        auto numDevices = ConstantHelper::getNumberOfDevices();
        auto numDevices = AffinityManager::numberOfDevices();

        _cache.resize(numDevices);
        for (int e = 0; e < numDevices; e++) {
@@ -54,7 +55,7 @@ namespace nd4j {
    }

    ConstantDataBuffer& ConstantShapeHelper::bufferForShapeInfo(const ShapeDescriptor &descriptor) {
        int deviceId = ConstantHelper::getCurrentDevice();
        int deviceId = AffinityManager::currentDeviceId();

        _mutex.lock();
@@ -83,7 +84,7 @@ namespace nd4j {

    bool ConstantShapeHelper::checkBufferExistenceForShapeInfo(ShapeDescriptor &descriptor) {
        bool result;
        auto deviceId = ConstantHelper::getCurrentDevice();
        auto deviceId = AffinityManager::currentDeviceId();
        _mutex.lock();

        if (_cache[deviceId].count(descriptor) == 0)
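Both constant helpers follow the same per-device cache layout: one map per device, selected by the device id the calling thread is bound to, guarded by a single mutex. A generic sketch of that structure; `Key`/`Value` are placeholders for ShapeDescriptor/ConstantDataBuffer, and the class itself is illustrative:

```cpp
#include <map>
#include <mutex>
#include <vector>

// Sketch of a per-device cache: one map per device, indexed by the
// device id of the calling thread.
template <typename Key, typename Value>
class PerDeviceCache {
    std::vector<std::map<Key, Value>> _cache;
    std::mutex _mutex;

public:
    explicit PerDeviceCache(int numDevices) : _cache(numDevices) {}

    bool contains(int deviceId, const Key& key) {
        std::lock_guard<std::mutex> lock(_mutex);
        return _cache[deviceId].count(key) != 0;
    }

    Value& get(int deviceId, const Key& key) {
        std::lock_guard<std::mutex> lock(_mutex);
        return _cache[deviceId][key];  // default-constructs on first access
    }
};
```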
@@ -21,13 +21,14 @@
#include "../ConstantTadHelper.h"
#include <TAD.h>
#include <ConstantHelper.h>
#include <AffinityManager.h>
#include <exceptions/cuda_exception.h>
#include <execution/LaunchContext.h>
#include <ShapeUtils.h>

namespace nd4j {
    ConstantTadHelper::ConstantTadHelper() {
        auto numDevices = ConstantHelper::getNumberOfDevices();
        auto numDevices = AffinityManager::numberOfDevices();

        for (int e = 0; e < numDevices; e++) {
            std::map<TadDescriptor, TadPack> pack;
@@ -61,7 +62,7 @@ namespace nd4j {
    }

    TadPack& ConstantTadHelper::tadForDimensions(TadDescriptor &descriptor) {
        const int deviceId = ConstantHelper::getCurrentDevice();
        const int deviceId = AffinityManager::currentDeviceId();

        _mutex.lock();
@@ -184,8 +184,8 @@ static void col2imCudaLauncher(const int blocksPerGrid, const int threadsPerBloc
    void* image, const Nd4jLong* imShapeInfo,
    const int sH, const int sW, const int pH, const int pW, const int dH, const int dW) {

    // col2imCuda2<T><<<512, 512, 1024, *stream>>>(columns, image, colShapeInfo, imShapeInfo, sH, sW, pH, pW, dH, dW);
    col2imCuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(columns, colShapeInfo, image, imShapeInfo, sH, sW, pH, pW, dH, dW);
    col2imCuda2<T><<<512, 512, 1024, *stream>>>(columns, image, colShapeInfo, imShapeInfo, sH, sW, pH, pW, dH, dW);
    //col2imCuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(columns, colShapeInfo, image, imShapeInfo, sH, sW, pH, pW, dH, dW);
}

//////////////////////////////////////////////////////////////////////////
@@ -135,8 +135,9 @@ namespace helpers {
    auto step = blockDim.x * gridDim.x;

    for (int e = tid; e < len; e += step) {
        if (output[shape::getIndexOffset(e, outputShape, len)] != T(0.))
            output[shape::getIndexOffset(e, outputShape, len)] = T(input[shape::getIndexOffset(e, gradOutShape, len)] / probValue);
        const auto zOffset = shape::getIndexOffset(e, outputShape, len);
        if (output[zOffset] != T(0.))
            output[zOffset] = T(input[shape::getIndexOffset(e, gradOutShape, len)] / probValue);
    }
}
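The change here is a small micro-optimization: the output offset was computed twice per element via shape::getIndexOffset and is now hoisted into a local. The general pattern, in a simplified kernel where a plain linear index stands in for the shape-aware offset computation:

```cpp
// Simplified CUDA kernel illustrating the hoisting pattern; in the real
// kernel `zOffset` comes from shape::getIndexOffset(e, outputShape, len).
__global__ void dropoutBpSketch(float* output, const float* input, long len, float probValue) {
    auto tid = blockIdx.x * blockDim.x + threadIdx.x;
    auto step = blockDim.x * gridDim.x;

    for (long e = tid; e < len; e += step) {
        const auto zOffset = e;  // computed once, reused for read and write
        if (output[zOffset] != 0.f)
            output[zOffset] = input[e] / probValue;
    }
}
```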
@@ -1050,14 +1050,14 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test1) {
    auto bias = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3});
    auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, oC});

    auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC},{ 0.226, 0.343, 0.46, 0.577, 1.172, 1.46, 1.748, 2.036, 1.892, 2.288, 2.684, 3.08, 1.284, 1.581, 1.878, 2.175, 4.458, 5.133, 5.808, 6.483, 6.186, 7.023, 7.86, 8.697,
        3.39, 3.93, 4.47, 5.01, 9.642, 10.803, 11.964, 13.125,11.37, 12.693, 14.016, 15.339, 5.266, 5.707, 6.148, 6.589,12.98, 13.916, 14.852, 15.788,14.564, 15.608, 16.652, 17.696,
        3.25, 4.015, 4.78, 5.545, 9.812, 11.396, 12.98, 14.564,10.532, 12.224, 13.916, 15.608, 9.708, 10.977, 12.246, 13.515,25.194, 27.813, 30.432, 33.051,26.922, 29.703, 32.484, 35.265,
        11.814, 13.326, 14.838, 16.35,30.378, 33.483, 36.588, 39.693,32.106, 35.373, 38.64, 41.907,13.474, 14.563, 15.652, 16.741,31.988, 34.22, 36.452, 38.684,33.572, 35.912, 38.252, 40.592});
    auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC},{ 0.226f, 0.343f, 0.46f, 0.577f, 1.172f, 1.46f, 1.748f, 2.036f, 1.892f, 2.288f, 2.684f, 3.08f, 1.284f, 1.581f, 1.878f, 2.175f, 4.458f, 5.133f, 5.808f, 6.483f, 6.186f, 7.023f, 7.86f, 8.697f,
        3.39f, 3.93f, 4.47f, 5.01f, 9.642f, 10.803f, 11.964f, 13.125f,11.37f, 12.693f, 14.016f, 15.339f, 5.266f, 5.707f, 6.148f, 6.589f,12.98f, 13.916f, 14.852f, 15.788f,14.564f, 15.608f, 16.652f, 17.696f,
        3.25f, 4.015f, 4.78f, 5.545f, 9.812f, 11.396f, 12.98f, 14.564f,10.532f, 12.224f, 13.916f, 15.608f, 9.708f, 10.977f, 12.246f, 13.515f,25.194f, 27.813f, 30.432f, 33.051f,26.922f, 29.703f, 32.484f, 35.265f,
        11.814f, 13.326f, 14.838f, 16.35f,30.378f, 33.483f, 36.588f, 39.693f,32.106f, 35.373f, 38.64f, 41.907f,13.474f, 14.563f, 15.652f, 16.741f,31.988f, 34.22f, 36.452f, 38.684f,33.572f, 35.912f, 38.252f, 40.592f});

    auto expGradW = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, oC},{14.4, 14.76, 15.12,14.4, 14.76, 15.12,14.4, 14.76, 15.12,14.4, 14.76, 15.12, 9.24, 9.48, 9.72, 9.24, 9.48, 9.72, 9.24, 9.48, 9.72, 9.24, 9.48, 9.72,
        17.04, 17.52, 18.,17.04, 17.52, 18.,17.04, 17.52, 18.,17.04, 17.52, 18.,10.88, 11.2, 11.52,10.88, 11.2, 11.52,10.88, 11.2, 11.52,10.88, 11.2, 11.52,
        11.16, 11.52, 11.88,11.16, 11.52, 11.88,11.16, 11.52, 11.88,11.16, 11.52, 11.88, 7.08, 7.32, 7.56, 7.08, 7.32, 7.56, 7.08, 7.32, 7.56, 7.08, 7.32, 7.56});
    auto expGradW = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, oC},{14.4f, 14.76f, 15.12f,14.4f, 14.76f, 15.12f,14.4f, 14.76f, 15.12f,14.4f, 14.76f, 15.12f, 9.24f, 9.48f, 9.72f, 9.24f, 9.48f, 9.72f, 9.24f, 9.48f, 9.72f, 9.24f, 9.48f, 9.72f,
        17.04f, 17.52f, 18.f,17.04f, 17.52f, 18.f, 17.04f, 17.52f, 18.f, 17.04f, 17.52f, 18.f,10.88f, 11.2f, 11.52f,10.88f, 11.2f, 11.52f,10.88f, 11.2f, 11.52f,10.88f, 11.2f, 11.52f,
        11.16f, 11.52f, 11.88f,11.16f, 11.52f, 11.88f,11.16f, 11.52f, 11.88f,11.16f, 11.52f, 11.88f, 7.08f, 7.32f, 7.56f, 7.08f, 7.32f, 7.56f, 7.08f, 7.32f, 7.56f, 7.08f, 7.32f, 7.56f});
    // auto expGradB('c', {oC},{});

    input = 2.;
@@ -1093,14 +1093,14 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test2) {
    auto bias = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3});
    auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, oC});

    auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC},{ 0.014,0.032, 0.05, 0.068,0.118,0.181, 0.244, 0.307,0.212,0.257, 0.302, 0.347,0.208,0.298, 0.388, 0.478,1.028,1.262, 1.496, 1.73,1.036,1.18, 1.324, 1.468,
        0.928,1.018, 1.108, 1.198,2.9,3.134, 3.368, 3.602,2.188,2.332, 2.476, 2.62,1.202,1.274, 1.346, 1.418,3.142,3.313, 3.484, 3.655,2.048,2.147, 2.246, 2.345,
        0.086,0.212, 0.338, 0.464,0.694,0.973, 1.252, 1.531,0.716,0.869, 1.022, 1.175,1.216,1.522, 1.828, 2.134,3.908,4.574, 5.24, 5.906,2.908,3.268, 3.628, 3.988,
        3.664,3.97, 4.276, 4.582,9.236,9.902,10.568,11.234,5.788,6.148, 6.508, 6.868,3.002,3.182, 3.362, 3.542,7.174,7.561, 7.948, 8.335,4.28,4.487, 4.694, 4.901});
    auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC},{ 0.014f, 0.032f, 0.05f, 0.068f,0.118f,0.181f, 0.244f, 0.307f,0.212f,0.257f, 0.302f, 0.347f,0.208f,0.298f, 0.388f, 0.478f,1.028f,1.262f, 1.496f, 1.73f,1.036f,1.18f, 1.324f, 1.468f,
        0.928f,1.018f, 1.108f, 1.198f,2.9f,3.134f, 3.368f, 3.602f,2.188f,2.332f, 2.476f, 2.62f, 1.202f,1.274f, 1.346f, 1.418f,3.142f,3.313f, 3.484f, 3.655f,2.048f,2.147f, 2.246f, 2.345f,
        0.086f,0.212f, 0.338f, 0.464f,0.694f,0.973f, 1.252f, 1.531f,0.716f,0.869f, 1.022f, 1.175f,1.216f,1.522f, 1.828f, 2.134f,3.908f,4.574f, 5.24f, 5.906f,2.908f,3.268f, 3.628f, 3.988f,
        3.664f,3.97f, 4.276f, 4.582f,9.236f,9.902f,10.568f,11.234f,5.788f,6.148f, 6.508f, 6.868f,3.002f,3.182f, 3.362f, 3.542f,7.174f,7.561f, 7.948f, 8.335f,4.28f,4.487f, 4.694f, 4.901f});

    auto expGradW = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, oC},{1.84, 2., 2.16,1.84, 2., 2.16,1.84, 2., 2.16,1.84, 2., 2.16,1.84, 2., 2.16,1.84, 2., 2.16,1.84, 2., 2.16,1.84, 2., 2.16,
        1.84, 2., 2.16,1.84, 2., 2.16,1.84, 2., 2.16,1.84, 2., 2.16,1.84, 2., 2.16,1.84, 2., 2.16,1.84, 2., 2.16,1.84, 2., 2.16,
        1.84, 2., 2.16,1.84, 2., 2.16,1.84, 2., 2.16,1.84, 2., 2.16,1.84, 2., 2.16,1.84, 2., 2.16,1.84, 2., 2.16,1.84, 2., 2.16});
    auto expGradW = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, oC},{1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,
        1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,
        1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f});
    // auto expGradB('c', {oC},{});

    input = 2.;
File diff suppressed because one or more lines are too long
@@ -1092,15 +1092,15 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_06) {
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, NTH_Element_Test_7) {

    NDArray input = NDArrayFactory::create<float>('c', {2, 3, 4}, {0.7788, 0.8012, 0.7244, 0.2309,
        0.7271, 0.1804, 0.5056, 0.8925,
        0.5461, 0.9234, 0.0856, 0.7938,
    NDArray input = NDArrayFactory::create<float>('c', {2, 3, 4}, {0.7788f, 0.8012f, 0.7244f, 0.2309f,
        0.7271f, 0.1804f, 0.5056f, 0.8925f,
        0.5461f, 0.9234f, 0.0856f, 0.7938f,

        0.6591, 0.5555, 0.1596, 0.3087,
        0.1548, 0.4695, 0.9939, 0.6113,
        0.6765, 0.1800, 0.6750, 0.2246});
        0.6591f, 0.5555f, 0.1596f, 0.3087f,
        0.1548f, 0.4695f, 0.9939f, 0.6113f,
        0.6765f, 0.1800f, 0.6750f, 0.2246f});
    NDArray n = NDArrayFactory::create<int>(2);
    NDArray exp = NDArrayFactory::create<float>('c', {2,3}, {0.7788, 0.7271, 0.7938, 0.5555, 0.6113, 0.675});
    NDArray exp = NDArrayFactory::create<float>('c', {2,3}, {0.7788f, 0.7271f, 0.7938f, 0.5555f, 0.6113f, 0.675f});

    //input.linspace(1.f);
@@ -1119,15 +1119,15 @@ TEST_F(DeclarableOpsTests10, NTH_Element_Test_7) {
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, NTH_Element_Test_8) {

    NDArray input = NDArrayFactory::create<float>('c', {2, 3, 4}, {0.7788, 0.8012, 0.7244, 0.2309,
        0.7271, 0.1804, 0.5056, 0.8925,
        0.5461, 0.9234, 0.0856, 0.7938,
    NDArray input = NDArrayFactory::create<float>('c', {2, 3, 4}, {0.7788f, 0.8012f, 0.7244f, 0.2309f,
        0.7271f, 0.1804f, 0.5056f, 0.8925f,
        0.5461f, 0.9234f, 0.0856f, 0.7938f,

        0.6591, 0.5555, 0.1596, 0.3087,
        0.1548, 0.4695, 0.9939, 0.6113,
        0.6765, 0.1800, 0.6750, 0.2246});
        0.6591f, 0.5555f, 0.1596f, 0.3087f,
        0.1548f, 0.4695f, 0.9939f, 0.6113f,
        0.6765f, 0.1800f, 0.6750f, 0.2246f});
    NDArray n = NDArrayFactory::create<int>(2);
    NDArray exp = NDArrayFactory::create<float>('c', {2,3}, {0.7244, 0.5056, 0.5461, 0.3087, 0.4695, 0.2246});
    NDArray exp = NDArrayFactory::create<float>('c', {2,3}, {0.7244f, 0.5056f, 0.5461f, 0.3087f, 0.4695f, 0.2246f});

    //input.linspace(1.f);
@@ -1359,10 +1359,10 @@ TEST_F(DeclarableOpsTests10, broadcast_to_test10) {
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1) {

    NDArray input = NDArrayFactory::create<float>('c', {1, 2,3,4});
    NDArray input = NDArrayFactory::create<double>('c', {1, 2,3,4});
    //NDArray<float> paddings('c', {3,2}, {0,0, 0,1, 0,0});
    //NDArray<float> expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.});
    NDArray expected = NDArrayFactory::create<float>('c', {1, 10, 10, 4}, {1., 2., 3., 4., 2.2, 3.2, 4.2, 5.2, 3.4, 4.4, 5.4, 6.4,
    NDArray expected = NDArrayFactory::create<double>('c', {1, 10, 10, 4}, {1., 2., 3., 4., 2.2, 3.2, 4.2, 5.2, 3.4, 4.4, 5.4, 6.4,
        4.6, 5.6, 6.6, 7.6, 5.8, 6.8, 7.8, 8.8, 7., 8., 9., 10.,
        8.2, 9.2, 10.2, 11.2, 9., 10., 11., 12., 9., 10., 11., 12.,
        9., 10., 11., 12., 3.4, 4.4, 5.4, 6.4, 4.6, 5.6, 6.6, 7.6,
@@ -1416,10 +1416,10 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1) {
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test2) {

    NDArray input = NDArrayFactory::create<float>('c', {1, 2,3,4});
    NDArray input = NDArrayFactory::create<double>('c', {1, 2,3,4});
    NDArray size = NDArrayFactory::create<int>({10, 10});
    //NDArray<float> expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.});
    NDArray expected = NDArrayFactory::create<float>('c', {1, 10, 10, 4}, {1., 2., 3., 4., 2.2, 3.2, 4.2, 5.2, 3.4, 4.4, 5.4, 6.4,
    NDArray expected = NDArrayFactory::create<double>('c', {1, 10, 10, 4}, {1., 2., 3., 4., 2.2, 3.2, 4.2, 5.2, 3.4, 4.4, 5.4, 6.4,
        4.6, 5.6, 6.6, 7.6, 5.8, 6.8, 7.8, 8.8, 7., 8., 9., 10.,
        8.2, 9.2, 10.2, 11.2, 9., 10., 11., 12., 9., 10., 11., 12.,
        9., 10., 11., 12., 3.4, 4.4, 5.4, 6.4, 4.6, 5.6, 6.6, 7.6,
@@ -1472,10 +1472,10 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test2) {
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test3) {

    NDArray input = NDArrayFactory::create<float>('c', {1, 2,3,4});
    NDArray input = NDArrayFactory::create<double>('c', {1, 2,3,4});
    //NDArray<float> paddings('c', {3,2}, {0,0, 0,1, 0,0});
    //NDArray<float> expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.});
    NDArray expected = NDArrayFactory::create<float>('c', {1, 10, 10, 4},
    NDArray expected = NDArrayFactory::create<double>('c', {1, 10, 10, 4},
        { 1., 2., 3., 4.,
        1.8888888, 2.8888888, 3.8888888, 4.888889,
        2.7777777, 3.7777777, 4.7777777, 5.7777777,
@@ -1602,9 +1602,9 @@ TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test3) {
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test4) {

    NDArray input = NDArrayFactory::create<float>('c', {1, 2,3,4});
    NDArray input = NDArrayFactory::create<double>('c', {1, 2,3,4});
    NDArray size = NDArrayFactory::create<int>({10, 10});
    NDArray expected = NDArrayFactory::create<float>('c', {1, 10, 10, 4},
    NDArray expected = NDArrayFactory::create<double>('c', {1, 10, 10, 4},
        { 1., 2., 3., 4.,
        1.8888888, 2.8888888, 3.8888888, 4.888889,
        2.7777777, 3.7777777, 4.7777777, 5.7777777,
@@ -1750,10 +1750,10 @@ TEST_F(DeclarableOpsTests10, LinSpace_Test1) {
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1) {

    NDArray input = NDArrayFactory::create<float>('c', {1, 2, 3, 4});
    NDArray input = NDArrayFactory::create<double>('c', {1, 2, 3, 4});
    //NDArray<float> paddings('c', {3,2}, {0,0, 0,1, 0,0});
    //NDArray<float> expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.});
    NDArray expected = NDArrayFactory::create<float>('c', {1, 4, 5, 4}, { 1, 2, 3, 4,
    NDArray expected = NDArrayFactory::create<double>('c', {1, 4, 5, 4}, { 1, 2, 3, 4,
        1, 2, 3, 4,
        5, 6, 7, 8,
        5, 6, 7, 8,
@@ -1926,8 +1926,8 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_1) {
////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, Image_CropAndResize_2) {
    int axis = 0;
    NDArray images = NDArrayFactory::create<float>('c', {1,2,2,1}, {1,2,3,4});
    NDArray boxes = NDArrayFactory::create<float>('c', {1,4}, {0,0,1,1});
    NDArray images = NDArrayFactory::create<float>('c', {1,2,2,1}, {1.f, 2.f, 3.f, 4.f});
    NDArray boxes = NDArrayFactory::create<float>('c', {1,4}, {0.f, 0.f, 1.f, 1.f});
    NDArray boxI = NDArrayFactory::create<int>('c', {1}, {axis});
    NDArray cropSize = NDArrayFactory::create<int>({1, 1});
@@ -3238,11 +3238,11 @@ TEST_F(DeclarableOpsTests8, Test_Moments_7) {
////////////////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_01) {

    auto x = NDArrayFactory::create<TypeParam>('c', {1, 1, 2, 5}, { 1, 2., 3, 4, 5,
        6., 7., 8, 9, 10}
    auto x = NDArrayFactory::create<TypeParam>('c', {1, 1, 2, 5}, { 1.f, 2.f, 3.f, 4.f, 5.f,
        6.f, 7.f, 8.f, 9.f, 10.f}
    );

    auto exp = NDArrayFactory::create<TypeParam>('c', {1, 1, 2, 5}, {0.2581989, 0.3592106, 0.40089184, 0.53935987, 0.70014, 0.4898979, 0.46056613, 0.43971977, 0.5240003, 0.6375767}// 0.72760683, 0.4850712, 0.5848977, 0.67488194,
    auto exp = NDArrayFactory::create<TypeParam>('c', {1, 1, 2, 5}, {0.2581989f, 0.3592106f, 0.40089184f, 0.53935987f, 0.70014f, 0.4898979f, 0.46056613f, 0.43971977f, 0.5240003f, 0.6375767f}// 0.72760683, 0.4850712, 0.5848977, 0.67488194,
    // 0.7581754, 0.58321184, 0.86747235, 0.4048204}
    );
@@ -3262,10 +3262,10 @@ TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_01) {
////////////////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_02) {

    auto x = NDArrayFactory::create<TypeParam>('c', {1, 1, 1, 6}, { 1, 2., 3, 4, 5, 6});
    auto x = NDArrayFactory::create<TypeParam>('c', {1, 1, 1, 6}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f});

    auto exp = NDArrayFactory::create<TypeParam>('c', {1, 1, 1, 6}, {
        0.2581989, 0.3592106, 0.40089184, 0.4193139, 0.5360563, 0.67936623}
        0.2581989f, 0.3592106f, 0.40089184f, 0.4193139f, 0.5360563f, 0.67936623f}
    );

    nd4j::ops::lrn op;
@@ -3284,8 +3284,8 @@ TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_02) {
////////////////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_03) {

    auto x = NDArrayFactory::create<TypeParam>('c', {1, 1, 1, 10}, { 1, 2., 3, 4, 5, 6, 7, 8, 9, 10});
    auto exp = NDArrayFactory::create<TypeParam>('c', {1, 1, 1, 10}, {0.10425719, 0.16843036, 0.2095291, 0.23652494, 0.25449327,0.3053919, 0.35675305, 0.4098524, 0.46662825, 0.52999896});
    auto x = NDArrayFactory::create<TypeParam>('c', {1, 1, 1, 10}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f});
    auto exp = NDArrayFactory::create<TypeParam>('c', {1, 1, 1, 10}, {0.10425719f, 0.16843036f, 0.2095291f, 0.23652494f, 0.25449327f, 0.3053919f, 0.35675305f, 0.4098524f, 0.46662825f, 0.52999896f});

    nd4j::ops::lrn op;
    auto results = op.execute({&x}, {1.0, 1.0, 0.5}, {5}, {}, false, nd4j::DataType::DOUBLE);
@@ -3303,17 +3303,17 @@ TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_03) {
////////////////////////////////////////////////////////////////////////////////
TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_1) {

    auto x = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 2}, { 5.5, 0., 0.3, 5.5,
        8.6, 0., 0., 0.4,
        1.5, 1., 1.3, 1.5,
        2.6, 2., 3., 1.4}
    auto x = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 2}, { 5.5f, 0.f, 0.3f, 5.5f,
        8.6f, 0.f, 0.f, 0.4f,
        1.5f, 1.f, 1.3f, 1.5f,
        2.6f, 2.f, 3.f, 1.4f}
    );

    auto exp = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 2}, {
        0.98386997, 0., 0.05358852, 0.9824562,
        0.99330735, 0., 0., 0.37139067,
        0.72760683, 0.4850712, 0.5848977, 0.67488194,
        0.7581754, 0.58321184, 0.86747235, 0.4048204}
        0.98386997f, 0.f, 0.05358852f, 0.9824562f,
        0.99330735f, 0.f, 0.f, 0.37139067f,
        0.72760683f, 0.4850712f, 0.5848977f, 0.67488194f,
        0.7581754f, 0.58321184f, 0.86747235f, 0.4048204f}
    );

    nd4j::ops::lrn op;
@@ -3336,61 +3336,61 @@ TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_2) {
    x.linspace(1);

    auto exp = NDArrayFactory::create<TypeParam>('c', {3, 3, 5, 5}, {
        0.2581989, 0.3592106, 0.40089184, 0.53935987, 0.70014,
        0.4898979, 0.46056613, 0.43971977, 0.5240002, 0.6375767,
        0.5274096, 0.47771242, 0.4443308, 0.5163977, 0.61701745,
        0.5424508, 0.48452914, 0.44570294, 0.5123918, 0.6068971,
        0.5505386, 0.4881662, 0.4462865, 0.5099462, 0.60088515,
        0.2581989f, 0.3592106f, 0.40089184f, 0.53935987f, 0.70014f,
        0.4898979f, 0.46056613f, 0.43971977f, 0.5240002f, 0.6375767f,
        0.5274096f, 0.47771242f, 0.4443308f, 0.5163977f, 0.61701745f,
        0.5424508f, 0.48452914f, 0.44570294f, 0.5123918f, 0.6068971f,
        0.5505386f, 0.4881662f, 0.4462865f, 0.5099462f, 0.60088515f,

        0.5555859, 0.49042296, 0.44658744, 0.5083028, 0.59690416,
        0.55903524, 0.4919585, 0.44676256, 0.5071239, 0.59407425,
        0.5615412, 0.49307042, 0.44687328, 0.50623745, 0.5919596,
        0.56344414, 0.49391258, 0.4469477, 0.5055468, 0.59031945,
        0.56493837, 0.49457246, 0.4470002, 0.5049936, 0.5890103,
        0.5555859f, 0.49042296f, 0.44658744f, 0.5083028f, 0.59690416f,
        0.55903524f, 0.4919585f, 0.44676256f, 0.5071239f, 0.59407425f,
        0.5615412f, 0.49307042f, 0.44687328f, 0.50623745f, 0.5919596f,
        0.56344414f, 0.49391258f, 0.4469477f, 0.5055468f, 0.59031945f,
        0.56493837f, 0.49457246f, 0.4470002f, 0.5049936f, 0.5890103f,

        0.56614274, 0.49510333, 0.44703856, 0.50454074, 0.5879411,
        0.567134, 0.49553978, 0.4470674, 0.504163, 0.5870515,
        0.5679643, 0.4959048, 0.44708967, 0.5038433, 0.5862998,
        0.56866974, 0.4962146, 0.44710726, 0.5035692, 0.58565617,
        0.56927663, 0.49648085, 0.4471213, 0.5033315, 0.5850988,
        0.56614274f, 0.49510333f, 0.44703856f, 0.50454074f, 0.5879411f,
        0.567134f, 0.49553978f, 0.4470674f, 0.504163f, 0.5870515f,
        0.5679643f, 0.4959048f, 0.44708967f, 0.5038433f, 0.5862998f,
        0.56866974f, 0.4962146f, 0.44710726f, 0.5035692f, 0.58565617f,
        0.56927663f, 0.49648085f, 0.4471213f, 0.5033315f, 0.5850988f,


        0.56980413, 0.49671215, 0.44713274, 0.50312346, 0.58461165,
        0.57026696, 0.49691492, 0.4471422, 0.50293994, 0.58418214,
        0.5706764, 0.49709415, 0.44715008, 0.5027767, 0.5838005,
        0.571041, 0.4972537, 0.44715673, 0.50263065, 0.58345926,
        0.57136786, 0.49739665, 0.44716236, 0.5024992, 0.58315235,
        0.56980413f, 0.49671215f, 0.44713274f, 0.50312346f, 0.58461165f,
        0.57026696f, 0.49691492f, 0.4471422f, 0.50293994f, 0.58418214f,
        0.5706764f, 0.49709415f, 0.44715008f, 0.5027767f, 0.5838005f,
        0.571041f, 0.4972537f, 0.44715673f, 0.50263065f, 0.58345926f,
        0.57136786f, 0.49739665f, 0.44716236f, 0.5024992f, 0.58315235f,

        0.5716625, 0.49752548, 0.4471672, 0.5023803, 0.5828747,
        0.5719295, 0.49764213, 0.44717142, 0.5022721, 0.5826225,
        0.57217246, 0.49774826, 0.44717506, 0.5021734, 0.58239233,
        0.5723947, 0.4978453, 0.44717824, 0.5020829, 0.58218133,
        0.57259864, 0.49793428, 0.44718108, 0.5019997, 0.5819874,
        0.5716625f, 0.49752548f, 0.4471672f, 0.5023803f, 0.5828747f,
        0.5719295f, 0.49764213f, 0.44717142f, 0.5022721f, 0.5826225f,
        0.57217246f, 0.49774826f, 0.44717506f, 0.5021734f, 0.58239233f,
        0.5723947f, 0.4978453f, 0.44717824f, 0.5020829f, 0.58218133f,
        0.57259864f, 0.49793428f, 0.44718108f, 0.5019997f, 0.5819874f,

        0.5727864, 0.49801624, 0.44718358, 0.5019227, 0.5818083,
        0.57296, 0.49809194, 0.44718578, 0.5018515, 0.5816426,
        0.5731208, 0.49816203, 0.44718775, 0.5017854, 0.58148885,
        0.57327026, 0.49822718, 0.4471895, 0.5017239, 0.5813457,
        0.57340944, 0.49828786, 0.44719115, 0.5016664, 0.581212,
        0.5727864f, 0.49801624f, 0.44718358f, 0.5019227f, 0.5818083f,
        0.57296f, 0.49809194f, 0.44718578f, 0.5018515f, 0.5816426f,
        0.5731208f, 0.49816203f, 0.44718775f, 0.5017854f, 0.58148885f,
        0.57327026f, 0.49822718f, 0.4471895f, 0.5017239f, 0.5813457f,
        0.57340944f, 0.49828786f, 0.44719115f, 0.5016664f, 0.581212f,


        0.57353944, 0.4983446, 0.44719255, 0.50161266, 0.58108705,
        0.5736612, 0.49839762, 0.4471939, 0.50156236, 0.5809699,
        0.5737754, 0.4984474, 0.44719502, 0.501515, 0.58085984,
        0.5738828, 0.49849418, 0.4471962, 0.50147045, 0.5807564,
        0.5739839, 0.49853817, 0.44719717, 0.5014284, 0.5806588,
        0.57353944f, 0.4983446f, 0.44719255f, 0.50161266f, 0.58108705f,
        0.5736612f, 0.49839762f, 0.4471939f, 0.50156236f, 0.5809699f,
        0.5737754f, 0.4984474f, 0.44719502f, 0.501515f, 0.58085984f,
        0.5738828f, 0.49849418f, 0.4471962f, 0.50147045f, 0.5807564f,
        0.5739839f, 0.49853817f, 0.44719717f, 0.5014284f, 0.5806588f,

        0.5740793, 0.49857965, 0.4471981, 0.5013887, 0.5805666,
        0.5741694, 0.49861887, 0.44719887, 0.50135124, 0.58047944,
        0.57425463, 0.49865603, 0.44719967, 0.5013157, 0.5803969,
        0.5743354, 0.4986912, 0.44720036, 0.5012819, 0.5803186,
        0.57441217, 0.49872455, 0.44720104, 0.5012499, 0.58024424,
        0.5740793f, 0.49857965f, 0.4471981f, 0.5013887f, 0.5805666f,
        0.5741694f, 0.49861887f, 0.44719887f, 0.50135124f, 0.58047944f,
        0.57425463f, 0.49865603f, 0.44719967f, 0.5013157f, 0.5803969f,
        0.5743354f, 0.4986912f, 0.44720036f, 0.5012819f, 0.5803186f,
        0.57441217f, 0.49872455f, 0.44720104f, 0.5012499f, 0.58024424f,

        0.57448506, 0.4987563, 0.44720164, 0.5012194, 0.58017343,
        0.57455444, 0.4987865, 0.4472022, 0.5011904, 0.5801061,
        0.57462054, 0.49881527, 0.44720277, 0.5011627, 0.5800419,
        0.57468355, 0.49884263, 0.44720328, 0.50113624, 0.5799805,
        0.57474375, 0.49886885, 0.44720373, 0.50111103, 0.5799219 }
        0.57448506f, 0.4987563f, 0.44720164f, 0.5012194f, 0.58017343f,
        0.57455444f, 0.4987865f, 0.4472022f, 0.5011904f, 0.5801061f,
        0.57462054f, 0.49881527f, 0.44720277f, 0.5011627f, 0.5800419f,
        0.57468355f, 0.49884263f, 0.44720328f, 0.50113624f, 0.5799805f,
        0.57474375f, 0.49886885f, 0.44720373f, 0.50111103f, 0.5799219f }
    );
    //
    nd4j::ops::lrn op;
@@ -3413,61 +3413,61 @@ TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_3) {
    x.linspace(1);

    auto exp = NDArrayFactory::create<TypeParam>('c', {3, 3, 5, 5}, {
        0.2581989, 0.3592106, 0.40089184, 0.53935987, 0.70014,
        0.4898979, 0.46056613, 0.43971977, 0.5240002, 0.6375767,
        0.5274096, 0.47771242, 0.4443308, 0.5163977, 0.61701745,
        0.5424508, 0.48452914, 0.44570294, 0.5123918, 0.6068971,
        0.5505386, 0.4881662, 0.4462865, 0.5099462, 0.60088515,
        0.2581989f, 0.3592106f, 0.40089184f, 0.53935987f, 0.70014f,
        0.4898979f, 0.46056613f, 0.43971977f, 0.5240002f, 0.6375767f,
        0.5274096f, 0.47771242f, 0.4443308f, 0.5163977f, 0.61701745f,
        0.5424508f, 0.48452914f, 0.44570294f, 0.5123918f, 0.6068971f,
        0.5505386f, 0.4881662f, 0.4462865f, 0.5099462f, 0.60088515f,

        0.5555859, 0.49042296, 0.44658744, 0.5083028, 0.59690416,
        0.55903524, 0.4919585, 0.44676256, 0.5071239, 0.59407425,
        0.5615412, 0.49307042, 0.44687328, 0.50623745, 0.5919596,
        0.56344414, 0.49391258, 0.4469477, 0.5055468, 0.59031945,
        0.56493837, 0.49457246, 0.4470002, 0.5049936, 0.5890103,
        0.5555859f, 0.49042296f, 0.44658744f, 0.5083028f, 0.59690416f,
        0.55903524f, 0.4919585f, 0.44676256f, 0.5071239f, 0.59407425f,
        0.5615412f, 0.49307042f, 0.44687328f, 0.50623745f, 0.5919596f,
        0.56344414f, 0.49391258f, 0.4469477f, 0.5055468f, 0.59031945f,
        0.56493837f, 0.49457246f, 0.4470002f, 0.5049936f, 0.5890103f,

        0.56614274, 0.49510333, 0.44703856, 0.50454074, 0.5879411,
        0.567134, 0.49553978, 0.4470674, 0.504163, 0.5870515,
        0.5679643, 0.4959048, 0.44708967, 0.5038433, 0.5862998,
        0.56866974, 0.4962146, 0.44710726, 0.5035692, 0.58565617,
        0.56927663, 0.49648085, 0.4471213, 0.5033315, 0.5850988,
        0.56614274f, 0.49510333f, 0.44703856f, 0.50454074f, 0.5879411f,
        0.567134f, 0.49553978f, 0.4470674f, 0.504163f, 0.5870515f,
        0.5679643f, 0.4959048f, 0.44708967f, 0.5038433f, 0.5862998f,
        0.56866974f, 0.4962146f, 0.44710726f, 0.5035692f, 0.58565617f,
        0.56927663f, 0.49648085f, 0.4471213f, 0.5033315f, 0.5850988f,


        0.56980413, 0.49671215, 0.44713274, 0.50312346, 0.58461165,
        0.57026696, 0.49691492, 0.4471422, 0.50293994, 0.58418214,
        0.5706764, 0.49709415, 0.44715008, 0.5027767, 0.5838005,
        0.571041, 0.4972537, 0.44715673, 0.50263065, 0.58345926,
        0.57136786, 0.49739665, 0.44716236, 0.5024992, 0.58315235,
        0.56980413f, 0.49671215f, 0.44713274f, 0.50312346f, 0.58461165f,
        0.57026696f, 0.49691492f, 0.4471422f, 0.50293994f, 0.58418214f,
        0.5706764f, 0.49709415f, 0.44715008f, 0.5027767f, 0.5838005f,
        0.571041f, 0.4972537f, 0.44715673f, 0.50263065f, 0.58345926f,
        0.57136786f, 0.49739665f, 0.44716236f, 0.5024992f, 0.58315235f,

        0.5716625, 0.49752548, 0.4471672, 0.5023803, 0.5828747,
        0.5719295, 0.49764213, 0.44717142, 0.5022721, 0.5826225,
        0.57217246, 0.49774826, 0.44717506, 0.5021734, 0.58239233,
        0.5723947, 0.4978453, 0.44717824, 0.5020829, 0.58218133,
        0.57259864, 0.49793428, 0.44718108, 0.5019997, 0.5819874,
        0.5716625f, 0.49752548f, 0.4471672f, 0.5023803f, 0.5828747f,
        0.5719295f, 0.49764213f, 0.44717142f, 0.5022721f, 0.5826225f,
        0.57217246f, 0.49774826f, 0.44717506f, 0.5021734f, 0.58239233f,
        0.5723947f, 0.4978453f, 0.44717824f, 0.5020829f, 0.58218133f,
        0.57259864f, 0.49793428f, 0.44718108f, 0.5019997f, 0.5819874f,

        0.5727864, 0.49801624, 0.44718358, 0.5019227, 0.5818083,
        0.57296, 0.49809194, 0.44718578, 0.5018515, 0.5816426,
        0.5731208, 0.49816203, 0.44718775, 0.5017854, 0.58148885,
        0.57327026, 0.49822718, 0.4471895, 0.5017239, 0.5813457,
        0.57340944, 0.49828786, 0.44719115, 0.5016664, 0.581212,
        0.5727864f, 0.49801624f, 0.44718358f, 0.5019227f, 0.5818083f,
        0.57296f, 0.49809194f, 0.44718578f, 0.5018515f, 0.5816426f,
        0.5731208f, 0.49816203f, 0.44718775f, 0.5017854f, 0.58148885f,
        0.57327026f, 0.49822718f, 0.4471895f, 0.5017239f, 0.5813457f,
        0.57340944f, 0.49828786f, 0.44719115f, 0.5016664f, 0.581212f,


        0.57353944, 0.4983446, 0.44719255, 0.50161266, 0.58108705,
        0.5736612, 0.49839762, 0.4471939, 0.50156236, 0.5809699,
        0.5737754, 0.4984474, 0.44719502, 0.501515, 0.58085984,
        0.5738828, 0.49849418, 0.4471962, 0.50147045, 0.5807564,
        0.5739839, 0.49853817, 0.44719717, 0.5014284, 0.5806588,
        0.57353944f, 0.4983446f, 0.44719255f, 0.50161266f, 0.58108705f,
        0.5736612f, 0.49839762f, 0.4471939f, 0.50156236f, 0.5809699f,
        0.5737754f, 0.4984474f, 0.44719502f, 0.501515f, 0.58085984f,
        0.5738828f, 0.49849418f, 0.4471962f, 0.50147045f, 0.5807564f,
        0.5739839f, 0.49853817f, 0.44719717f, 0.5014284f, 0.5806588f,

        0.5740793, 0.49857965, 0.4471981, 0.5013887, 0.5805666,
        0.5741694, 0.49861887, 0.44719887, 0.50135124, 0.58047944,
        0.57425463, 0.49865603, 0.44719967, 0.5013157, 0.5803969,
        0.5743354, 0.4986912, 0.44720036, 0.5012819, 0.5803186,
        0.57441217, 0.49872455, 0.44720104, 0.5012499, 0.58024424,
        0.5740793f, 0.49857965f, 0.4471981f, 0.5013887f, 0.5805666f,
        0.5741694f, 0.49861887f, 0.44719887f, 0.50135124f, 0.58047944f,
        0.57425463f, 0.49865603f, 0.44719967f, 0.5013157f, 0.5803969f,
        0.5743354f, 0.4986912f, 0.44720036f, 0.5012819f, 0.5803186f,
        0.57441217f, 0.49872455f, 0.44720104f, 0.5012499f, 0.58024424f,

        0.57448506, 0.4987563, 0.44720164, 0.5012194, 0.58017343,
        0.57455444, 0.4987865, 0.4472022, 0.5011904, 0.5801061,
        0.57462054, 0.49881527, 0.44720277, 0.5011627, 0.5800419,
        0.57468355, 0.49884263, 0.44720328, 0.50113624, 0.5799805,
        0.57474375, 0.49886885, 0.44720373, 0.50111103, 0.5799219 }
        0.57448506f, 0.4987563f, 0.44720164f, 0.5012194f, 0.58017343f,
        0.57455444f, 0.4987865f, 0.4472022f, 0.5011904f, 0.5801061f,
        0.57462054f, 0.49881527f, 0.44720277f, 0.5011627f, 0.5800419f,
        0.57468355f, 0.49884263f, 0.44720328f, 0.50113624f, 0.5799805f,
        0.57474375f, 0.49886885f, 0.44720373f, 0.50111103f, 0.5799219f }
    );
    //
    nd4j::ops::lrn op;
@@ -0,0 +1,68 @@
/*******************************************************************************
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

#include "testlayers.h"
#include <array/ArrayOptions.h>
#include <AffinityManager.h>
#include <NDArray.h>
#include <NDArrayFactory.h>
#include <ops/declarable/headers/broadcastable.h>
#include <MmulHelper.h>
#include <thread>

using namespace nd4j;

class MultiDeviceTests : public testing::Test {
public:
};

void createArrays(int limit, std::vector<NDArray*> &arrays) {
    auto deviceId = AffinityManager::currentDeviceId();
    auto numDevices = AffinityManager::numberOfDevices();

    for (int e = 0; e < limit; e++) {
        auto value = deviceId * limit + e;
        arrays[value] = NDArrayFactory::create_<float>('c', {10});
        arrays[value]->assign(value);
        //nd4j_printf("device_%i; value: [%i]; mean: [%f]\n", deviceId, value, arrays[value]->meanNumber().e<float>(0));
    }
}

TEST_F(MultiDeviceTests, test_multi_device_migration_1) {
    auto deviceId = AffinityManager::currentDeviceId();
    auto numDevices = AffinityManager::numberOfDevices();
    auto numArrays = 10;
    std::vector<NDArray*> arrays(numDevices * numArrays);

    // filling list of arrays on multiple threads
    for (int e = 0; e < numDevices; e++) {
        std::thread t1(createArrays, numArrays, std::ref(arrays));

        t1.join();
    }

    // at this moment all arrays are built, so we can test migration
    for (int e = 0; e < arrays.size(); e++) {
        ASSERT_NEAR((float) e, arrays[e]->meanNumber().e<float>(0), 1e-5f);
        delete arrays[e];
    }
}
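A note on the test's intent: each spawned thread is expected to land on its own device via per-thread affinity, fill its slice of the vector, and the final loop then reads every array from the main thread, which forces cross-device migration. A hypothetical direct-migration exercise built only on APIs visible in this diff (setCurrentDevice, numberOfDevices, the factory and meanNumber calls); this is a sketch, not part of the PR:

```cpp
#include <NDArrayFactory.h>
#include <AffinityManager.h>

// Hypothetical smoke test: create an array while bound to device 0, switch
// the thread to device 1, then touch the array so syncToDevice() migrates it.
void migrationSmokeTest() {
    using namespace nd4j;

    auto array = NDArrayFactory::create_<float>('c', {10});
    array->assign(119.0f);

    if (AffinityManager::numberOfDevices() > 1) {
        AffinityManager::setCurrentDevice(1);  // swaps thread-local context buffers

        // any device-side access from this thread now runs on device 1;
        // the buffer is migrated on first sync
        auto mean = array->meanNumber().e<float>(0);
        (void) mean;
    }

    delete array;
}
```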
@@ -438,6 +438,7 @@ public class AtomicAllocator implements Allocator {
    Long allocId = objectsTracker.getAndIncrement();
    point.setObjectId(allocId);
    point.setConstant(true);
    point.setDeviceId(Nd4j.getAffinityManager().getDeviceForCurrentThread());

    allocationsMap.put(allocId, point);
@@ -68,11 +68,7 @@ public class CudaAffinityManager extends BasicAffinityManager {
 */
@Override
public Integer getDeviceForCurrentThread() {
    val id = NativeOpsHolder.getInstance().getDeviceNativeOps().getDevice();
    if (!affinityMap.containsKey(Thread.currentThread().getId()))
        affinityMap.put(Thread.currentThread().getId(), id);

    return id;
    return NativeOpsHolder.getInstance().getDeviceNativeOps().getDevice();
}

/**
@@ -205,7 +201,6 @@ public class CudaAffinityManager extends BasicAffinityManager {
    int currentDeviceId = getDeviceForCurrentThread();

    if (currentDeviceId != deviceId.intValue()) {
        Nd4j.getMemoryManager().releaseCurrentContext();
        unsafeSetDevice(deviceId);
    }
@@ -215,7 +210,6 @@ public class CudaAffinityManager extends BasicAffinityManager {
    INDArray result = Nd4j.createArrayFromShapeBuffer(newDataBuffer, newShapeBuffer);

    if (currentDeviceId != deviceId.intValue()) {
        Nd4j.getMemoryManager().releaseCurrentContext();
        unsafeSetDevice(currentDeviceId);
    }
@@ -238,7 +232,6 @@ public class CudaAffinityManager extends BasicAffinityManager {
    int currentDeviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();

    if (currentDeviceId != deviceId) {
        Nd4j.getMemoryManager().releaseCurrentContext();
        Nd4j.getAffinityManager().unsafeSetDevice(deviceId);
    }
@@ -246,7 +239,6 @@ public class CudaAffinityManager extends BasicAffinityManager {
    AtomicAllocator.getInstance().memcpy(dstBuffer, buffer);

    if (currentDeviceId != deviceId) {
        Nd4j.getMemoryManager().releaseCurrentContext();
        Nd4j.getAffinityManager().unsafeSetDevice(currentDeviceId);
    }
@@ -188,13 +188,15 @@ public class SynchronousFlowController implements FlowController {
        }

        if (pointShape.getDeviceId() != cId && pointShape.getDeviceId() >= 0) {
            ((JCublasNDArray) result).setShapeInfoDataBuffer(
                Nd4j.getConstantHandler().relocateConstantSpace(result.shapeInfoDataBuffer()));
            ((JCublasNDArray) result).setShapeInfoDataBuffer(Nd4j.getExecutioner().createShapeInfo(result.shape(), result.stride(), result.elementWiseStride(), result.ordering(), result.dataType(), result.isEmpty()));
        }

        allocator.getAllocationPoint(result).setCurrentContext(context);
    }

    if (operands == null)
        return context;

    for (INDArray operand : operands) {
        if (operand == null || operand.isEmpty())
            continue;
@@ -213,8 +215,7 @@ public class SynchronousFlowController implements FlowController {
        }

        if (pointShape.getDeviceId() != cId && pointShape.getDeviceId() >= 0) {
            ((JCublasNDArray) operand).setShapeInfoDataBuffer(
                Nd4j.getConstantHandler().relocateConstantSpace(operand.shapeInfoDataBuffer()));
            ((JCublasNDArray) operand).setShapeInfoDataBuffer(Nd4j.getExecutioner().createShapeInfo(operand.shape(), operand.stride(), operand.elementWiseStride(), operand.ordering(), operand.dataType(), operand.isEmpty()));
        }

        prepareDelayedMemory(operand);
@ -819,10 +819,10 @@ public class CudaZeroHandler implements MemoryHandler {
val ohPtr = dstPoint.getHostPointer();

// FIXME: cross-thread access, might cause problems
if (!dstPoint.isActualOnHostSide())
if (dstPoint.getHostPointer() != null && !dstPoint.isActualOnHostSide())
    AtomicAllocator.getInstance().synchronizeHostData(buffer);

if (!dstPoint.isActualOnHostSide())
if (dstPoint.getHostPointer() != null && !dstPoint.isActualOnHostSide())
    throw new RuntimeException("Buffer synchronization failed");

if (buffer.isAttached() || dstPoint.isAttached()) {
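
With host buffers now allocated lazily, getHostPointer() can legitimately be null, so both the synchronization and its failure check gain a null guard. A compact sketch of that logic; the fields are illustrative stand-ins for AllocationPoint state, not the real class:

// A missing host pointer is a valid state, not a synchronization failure.
public class HostSyncSketch {
    Long hostPointer;          // null until the host buffer is lazily allocated
    boolean actualOnHostSide;  // true when the host copy is up to date

    void synchronizeHostData() { actualOnHostSide = true; } // stand-in for the real device-to-host sync

    void syncToHostIfPresent() {
        if (hostPointer != null && !actualOnHostSide)
            synchronizeHostData();

        if (hostPointer != null && !actualOnHostSide)
            throw new RuntimeException("Buffer synchronization failed");
    }
}
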
@ -832,10 +832,14 @@ public class CudaZeroHandler implements MemoryHandler {

if (workspace == null) {
    // if we're out of workspace, we should mark our buffer as detached, so gc will pick it up eventually
    val pairH = alloc(AllocationStatus.HOST, dstPoint, dstPoint.getShape(), false);
    // host part is optional
    if (dstPoint.getHostPointer() != null) {
        val pairH = alloc(AllocationStatus.HOST, dstPoint, dstPoint.getShape(), false);
        dstPoint.getPointers().setHostPointer(pairH.getHostPointer());
    }

    val pairD = alloc(AllocationStatus.DEVICE, dstPoint, dstPoint.getShape(), false);
    dstPoint.getPointers().setDevicePointer(pairD.getDevicePointer());
    dstPoint.getPointers().setHostPointer(pairH.getHostPointer());

    //log.info("New host pointer: {}; Old host pointer: {}", dstPoint.getHostPointer().address(), ohPtr.address());
@ -869,7 +873,11 @@ public class CudaZeroHandler implements MemoryHandler {
Nd4j.getMemoryManager().memcpy(nBuffer, buffer);

dstPoint.getPointers().setDevicePointer(nBuffer.getAllocationPoint().getDevicePointer());
dstPoint.getPointers().setHostPointer(nBuffer.getAllocationPoint().getHostPointer());

if (dstPoint.getHostPointer() != null) {
    dstPoint.getPointers().setHostPointer(nBuffer.getAllocationPoint().getHostPointer());
}

dstPoint.setDeviceId(deviceId);

dstPoint.tickDeviceRead();
@ -885,6 +893,17 @@ public class CudaZeroHandler implements MemoryHandler {
throw new RuntimeException("Can't relocateObject() for constant buffer");
} else {
    // log.info("Free relocateObject: deviceId: {}, pointer: {}", deviceId, dstPoint.getPointers().getDevicePointer().address());
    val context = getCudaContext();
    if (dstPoint.getHostPointer() == null) {
        ((BaseCudaDataBuffer) buffer).lazyAllocateHostPointer();

        if (nativeOps.memcpyAsync(dstPoint.getHostPointer(), dstPoint.getDevicePointer(),
                buffer.length() * buffer.getElementSize(), 2, context.getSpecialStream()) == 0)
            throw new ND4JIllegalStateException("memcpyAsync failed");

        context.syncSpecialStream();
    }

    memoryProvider.free(dstPoint);
    deviceMemoryTracker.subFromAllocation(Thread.currentThread().getId(), dstPoint.getDeviceId(), AllocationUtils.getRequiredMemory(dstPoint.getShape()));
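
The new branch guarantees the bytes survive before the device allocation is released: lazily allocate a host pointer, run the device-to-host memcpyAsync (the copy whose destination is the host pointer above), synchronize the special stream, and only then free. A sketch of that ordering; DeviceBuffer is an assumed interface, not BaseCudaDataBuffer itself:

// Illustrative ordering of the relocation path: evacuate to host, then free device memory.
public class RelocationSketch {
    public interface DeviceBuffer {
        boolean hasHostPointer();
        void lazyAllocateHostPointer();  // mirrors BaseCudaDataBuffer.lazyAllocateHostPointer()
        boolean copyDeviceToHostAsync(); // mirrors the memcpyAsync call in the diff
        void syncSpecialStream();        // wait for the async copy to land
        void freeDeviceMemory();
    }

    public static void evacuateAndFree(DeviceBuffer buffer) {
        if (!buffer.hasHostPointer()) {
            buffer.lazyAllocateHostPointer();
            if (!buffer.copyDeviceToHostAsync())
                throw new IllegalStateException("memcpyAsync failed");
            buffer.syncSpecialStream();  // the copy must complete before the free below
        }
        buffer.freeDeviceMemory();
    }
}
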
@ -893,7 +912,6 @@ public class CudaZeroHandler implements MemoryHandler {

val profD = PerformanceTracker.getInstance().helperStartTransaction();

CudaContext context = getCudaContext();
if (nativeOps.memcpyAsync(dstPoint.getDevicePointer(), dstPoint.getHostPointer(),
        buffer.length() * buffer.getElementSize(), 1, context.getSpecialStream()) == 0)
    throw new ND4JIllegalStateException("memcpyAsync failed");
@ -17,6 +17,7 @@
package org.nd4j.linalg.jcublas;

import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.nd4j.jita.allocator.enums.AllocationStatus;
import org.nd4j.jita.allocator.enums.CudaConstants;
@ -54,7 +55,7 @@ import java.util.concurrent.atomic.AtomicLong;
 * @author Adam Gibson
 * @author raver119@gmail.com
 */

@Slf4j
public class JCublasNDArray extends BaseNDArray {
@ -737,14 +738,12 @@ public class JCublasNDArray extends BaseNDArray {
if (!this.isView()) {
    Nd4j.getExecutioner().commit();

    DataBuffer buffer = Nd4j.createBuffer(this.lengthLong(), false);
    val buffer = Nd4j.createBuffer(this.dataType(), this.lengthLong(), false);

    AllocationPoint pointDst = AtomicAllocator.getInstance().getAllocationPoint(buffer);
    AllocationPoint pointSrc = AtomicAllocator.getInstance().getAllocationPoint(this.data);
    val pointDst = AtomicAllocator.getInstance().getAllocationPoint(buffer);
    val pointSrc = AtomicAllocator.getInstance().getAllocationPoint(this.data);

    // CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext();

    CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(pointDst, pointSrc);
    val context = AtomicAllocator.getInstance().getFlowController().prepareAction(pointDst, pointSrc);

    MemcpyDirection direction = MemcpyDirection.DEVICE_TO_DEVICE;
    val perfD = PerformanceTracker.getInstance().helperStartTransaction();
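
The old createBuffer call allocated the destination with the default data type, so duplicating a non-float array could silently change its type; the replacement threads the source type through. A minimal usage sketch built on the same Nd4j.createBuffer overload the diff uses; TypedDupSketch and allocateLike are illustrative names:

import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

// Allocate the destination buffer with the source's own data type.
public class TypedDupSketch {
    public static DataBuffer allocateLike(INDArray source) {
        // the boolean mirrors the diff's third argument; assumed to skip eager initialization
        return Nd4j.createBuffer(source.dataType(), source.length(), false);
    }
}
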
@ -764,14 +763,12 @@ public class JCublasNDArray extends BaseNDArray {
PerformanceTracker.getInstance().helperRegisterTransaction(pointDst.getDeviceId(), perfD, pointDst.getNumberOfBytes(), direction);

if (pointDst.getDeviceId() != Nd4j.getMemoryManager().getCurrentWorkspace().getDeviceId()) {
    //log.info("Swapping [{}] -> [{}]", pointDst.getDeviceId(), Nd4j.getMemoryManager().getCurrentWorkspace().getDeviceId());
    pointDst.setDeviceId(Nd4j.getMemoryManager().getCurrentWorkspace().getDeviceId());
}

copy = Nd4j.createArrayFromShapeBuffer(buffer, this.shapeInfoDataBuffer());

// tag buffer as valid on device side
pointDst.tickHostRead();
pointDst.tickDeviceWrite();

AtomicAllocator.getInstance().getFlowController().registerAction(context, pointDst, pointSrc);
@ -25,6 +25,7 @@ import org.bytedeco.javacpp.indexer.*;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.buffer.LongBuffer;
import org.nd4j.linalg.api.buffer.Utf8Buffer;
import org.nd4j.linalg.api.buffer.factory.DataBufferFactory;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
@ -92,6 +93,8 @@ public class CudaDataBufferFactory implements DataBufferFactory {
        return new CudaByteDataBuffer(underlyingBuffer, length, offset);
    case BOOL:
        return new CudaBoolDataBuffer(underlyingBuffer, length, offset);
    case UTF8:
        return new Utf8Buffer(underlyingBuffer, length, offset);
    default:
        throw new ND4JIllegalStateException("Unknown data buffer type: " + underlyingBuffer.dataType().toString());
}
@ -233,7 +233,7 @@ public class CudaExecutioner extends DefaultOpExecutioner {
throw new ND4JIllegalStateException("Op target dimension " + Arrays.toString(dimension)
        + " contains element that higher then rank of op.X: [" + op.x().rank() + "]");

CudaContext context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y());
val context = AtomicAllocator.getInstance().getFlowController().prepareAction(op.z(), op.x(), op.y());

if (CudaEnvironment.getInstance().getConfiguration().isDebug())
    lastOp.set(op.opName());
@ -2491,6 +2491,11 @@ public class CudaExecutioner extends DefaultOpExecutioner {

nativeOps.execCustomOp2(null, op.opHash(), context.contextPointer());

for (val arr:op.outputArguments())
    AtomicAllocator.getInstance().registerAction(ctx, arr);

AtomicAllocator.getInstance().registerAction(ctx, null, op.inputArguments());

profilingConfigurableHookOut(op, st);

if (context.getOutputArrays().isEmpty())
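
Once execCustomOp2 returns, each output array is registered against the context as a write and the inputs are registered as reads, so later operations can synchronize on them. A hedged sketch of that bookkeeping; Tracker and its Object parameters are illustrative, not the AtomicAllocator signatures:

import java.util.List;

// Outputs registered individually as writes, inputs in one call as reads.
public class OpRegistration {
    public interface Tracker {
        void registerWrite(Object ctx, Object array);
        void registerReads(Object ctx, List<?> arrays);
    }

    public static void registerActions(Tracker tracker, Object ctx,
                                       List<?> outputs, List<?> inputs) {
        for (Object out : outputs)
            tracker.registerWrite(ctx, out); // one write registration per output, as in the loop above
        tracker.registerReads(ctx, inputs);  // inputs registered in a single call, as in the diff
    }
}
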
@ -81,10 +81,8 @@ public class CudaOpContext extends BaseOpContext implements OpContext {

@Override
public void setInputArray(int index, @NonNull INDArray array) {
    // FIXME: remove
    Nd4j.getAffinityManager().ensureLocation(array, AffinityManager.Location.EVERYWHERE);
    val ctx = AtomicAllocator.getInstance().getFlowController().prepareAction(null, array);

    val ctx = AtomicAllocator.getInstance().getDeviceContext();
    nativeOps.setGraphContextInputArray(context, index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), array.isEmpty() ? null : AtomicAllocator.getInstance().getPointer(array, ctx), AtomicAllocator.getInstance().getPointer(array.shapeInfoDataBuffer()));

    super.setInputArray(index, array);
@ -92,9 +90,8 @@ public class CudaOpContext extends BaseOpContext implements OpContext {

@Override
public void setOutputArray(int index, @NonNull INDArray array) {
    Nd4j.getAffinityManager().ensureLocation(array, AffinityManager.Location.EVERYWHERE);
    val ctx = AtomicAllocator.getInstance().getFlowController().prepareAction(array, null);

    val ctx = AtomicAllocator.getInstance().getDeviceContext();
    nativeOps.setGraphContextOutputArray(context, index, array.isEmpty() ? null : array.data().addressPointer(), array.shapeInfoDataBuffer().addressPointer(), array.isEmpty() ? null : AtomicAllocator.getInstance().getPointer(array, ctx), AtomicAllocator.getInstance().getPointer(array.shapeInfoDataBuffer()));

    super.setOutputArray(index, array);
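
Both setters now skip the prepareAction round trip, take the plain device context, and hand the native graph context four pointers: host data, host shape, device data, device shape, with the data pointers nulled for empty arrays. A sketch of that selection rule; PointerSource is an assumed abstraction over the allocator lookups:

// Empty arrays contribute shape information only, never data pointers.
public class GraphContextPointers {
    public interface PointerSource {
        long hostData();
        long hostShape();
        long deviceData();
        long deviceShape();
    }

    public static long[] select(PointerSource src, boolean isEmpty) {
        return new long[] {
                isEmpty ? 0L : src.hostData(),   // null host data for empty arrays
                src.hostShape(),                 // shape info is always passed
                isEmpty ? 0L : src.deviceData(), // null device data for empty arrays
                src.deviceShape()
        };
    }
}
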
@ -9898,6 +9898,72 @@ public static final int PREALLOC_SIZE = 33554432;
// #endif //LIBND4J_OPREGISTRATOR_H


// Parsed from execution/ContextBuffers.h

/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

// #ifndef LIBND4J_CONTEXTBUFFERS_H
// #define LIBND4J_CONTEXTBUFFERS_H

// #include <dll.h>
// #include <pointercast.h>
@Namespace("nd4j") @NoOffset public static class ContextBuffers extends Pointer {
    static { Loader.load(); }
    /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
    public ContextBuffers(Pointer p) { super(p); }
    /** Native array allocator. Access with {@link Pointer#position(long)}. */
    public ContextBuffers(long size) { super((Pointer)null); allocateArray(size); }
    private native void allocateArray(long size);
    @Override public ContextBuffers position(long position) {
        return (ContextBuffers)super.position(position);
    }

    public ContextBuffers() { super((Pointer)null); allocate(); }
    private native void allocate();
    public ContextBuffers(Pointer rPointer, Pointer sPointer, Pointer aPointer, @Cast("bool") boolean isOwner/*=false*/) { super((Pointer)null); allocate(rPointer, sPointer, aPointer, isOwner); }
    private native void allocate(Pointer rPointer, Pointer sPointer, Pointer aPointer, @Cast("bool") boolean isOwner/*=false*/);
    public ContextBuffers(Pointer rPointer, Pointer sPointer, Pointer aPointer) { super((Pointer)null); allocate(rPointer, sPointer, aPointer); }
    private native void allocate(Pointer rPointer, Pointer sPointer, Pointer aPointer);

    public native Pointer reductionBuffer();
    public native Pointer scalarBuffer();
    public native Pointer allocationBuffer();

    public native Pointer execStream();
    public native Pointer specialStream();

    public native void setReductionBuffer(Pointer pointer);
    public native void setScalarBuffer(Pointer pointer);
    public native void setAllocationBuffer(Pointer pointer);

    public native void triggerOwnership(@Cast("bool") boolean isOwner);

    public native int deviceId();
}


// #endif //DEV_TESTS_CONTEXTBUFFERS_H


// Parsed from execution/LaunchContext.h

/*******************************************************************************
@ -9971,6 +10037,9 @@ public static final int PREALLOC_SIZE = 33554432;

public static native LaunchContext defaultContext();


public static native void swapContextBuffers(@ByRef ContextBuffers buffers);

}
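
The generated mapping gives each thread a ContextBuffers bundle (reduction, scalar, and allocation scratch plus the two streams) and adds a swapContextBuffers entry point on LaunchContext, which is what lets a migrated thread rebind its scratch space in place. A hedged usage sketch; the enclosing wrapper class name Nd4jCuda is an assumption, while the member calls come from the mapping above:

// Hedged usage sketch against the generated bindings.
public class ContextBuffersDemo {
    public static void main(String[] args) {
        Nd4jCuda.ContextBuffers buffers = new Nd4jCuda.ContextBuffers(); // fresh per-thread scratch
        System.out.println("bound to device: " + buffers.deviceId());

        // after this thread migrates to another device, swap its scratch buffers in place:
        Nd4jCuda.LaunchContext.swapContextBuffers(buffers);
    }
}
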
@ -76,6 +76,7 @@ import org.bytedeco.javacpp.tools.InfoMapper;
"ops/declarable/BooleanOp.h",
"ops/declarable/LogicOp.h",
"ops/declarable/OpRegistrator.h",
"execution/ContextBuffers.h",
"execution/LaunchContext.h",
"array/ShapeDescriptor.h",
"array/TadDescriptor.h",
@ -22808,6 +22808,72 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();

// #endif


// Parsed from execution/ContextBuffers.h

/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

// #ifndef LIBND4J_CONTEXTBUFFERS_H
// #define LIBND4J_CONTEXTBUFFERS_H

// #include <dll.h>
// #include <pointercast.h>
@Namespace("nd4j") @NoOffset public static class ContextBuffers extends Pointer {
    static { Loader.load(); }
    /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
    public ContextBuffers(Pointer p) { super(p); }
    /** Native array allocator. Access with {@link Pointer#position(long)}. */
    public ContextBuffers(long size) { super((Pointer)null); allocateArray(size); }
    private native void allocateArray(long size);
    @Override public ContextBuffers position(long position) {
        return (ContextBuffers)super.position(position);
    }

    public ContextBuffers() { super((Pointer)null); allocate(); }
    private native void allocate();
    public ContextBuffers(Pointer rPointer, Pointer sPointer, Pointer aPointer, @Cast("bool") boolean isOwner/*=false*/) { super((Pointer)null); allocate(rPointer, sPointer, aPointer, isOwner); }
    private native void allocate(Pointer rPointer, Pointer sPointer, Pointer aPointer, @Cast("bool") boolean isOwner/*=false*/);
    public ContextBuffers(Pointer rPointer, Pointer sPointer, Pointer aPointer) { super((Pointer)null); allocate(rPointer, sPointer, aPointer); }
    private native void allocate(Pointer rPointer, Pointer sPointer, Pointer aPointer);

    public native Pointer reductionBuffer();
    public native Pointer scalarBuffer();
    public native Pointer allocationBuffer();

    public native Pointer execStream();
    public native Pointer specialStream();

    public native void setReductionBuffer(Pointer pointer);
    public native void setScalarBuffer(Pointer pointer);
    public native void setAllocationBuffer(Pointer pointer);

    public native void triggerOwnership(@Cast("bool") boolean isOwner);

    public native int deviceId();
}


// #endif //DEV_TESTS_CONTEXTBUFFERS_H


// Parsed from execution/LaunchContext.h

/*******************************************************************************
@ -22872,6 +22938,9 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();

public static native LaunchContext defaultContext();


public static native void swapContextBuffers(@ByRef ContextBuffers buffers);

}
@ -100,6 +100,7 @@ import java.util.Scanner;
"ops/declarable/headers/bitwise.h",
"ops/declarable/headers/loss.h",
"ops/declarable/headers/datatypes.h",
"execution/ContextBuffers.h",
"execution/LaunchContext.h",
"array/ShapeDescriptor.h",
"array/TadDescriptor.h",
@ -198,27 +198,43 @@ public class SpecialTests extends BaseNd4jTest {

val list = new CopyOnWriteArrayList<INDArray>();
val threads = new ArrayList<Thread>();
for (int e = 0; e< Nd4j.getAffinityManager().getNumberOfDevices(); e++) {
val devices = Nd4j.getAffinityManager().getNumberOfDevices();
for (int e = 0; e < devices; e++) {
    val f = e;
    val t = new Thread(new Runnable() {
        @Override
        public void run() {
            val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
            log.info("Current device: {}", deviceId);
            for (int i = 0; i < 10; i++) {
                list.add(Nd4j.create(100, 100).assign(1.0f));
                val ar = Nd4j.create(100, 100).assign(1.0f);

                assertEquals(deviceId, Nd4j.getAffinityManager().getDeviceForArray(ar));
                list.add(ar);
                Nd4j.getExecutioner().commit();
            }
        }
    });

    t.start();
    t.join();
    threads.add(t);

    log.info("------------------------");
}

for (val t:threads)
    t.join();

for (val a:list)
    assertEquals(1.0f, a.meanNumber().floatValue(), 1e-5);
for (val a:list) {
    val device = Nd4j.getAffinityManager().getDeviceForArray(a);
    try {
        assertEquals(1.0f, a.meanNumber().floatValue(), 1e-5);
    } catch (Exception e) {
        log.error("Failed for array from device [{}]", device);
        throw e;
    }
}
}

@Test
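
The reworked test starts one worker per device, keeps the thread handles, joins them all, and only then validates, logging the owning device on failure; previously each worker was started and joined inside the loop, so the threads never actually ran concurrently. A plain-Java sketch of that start-all/join-all shape; StartAllJoinAll and its worker body are illustrative:

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;

// Workers run concurrently, then are joined together before validation.
public class StartAllJoinAll {
    public static List<Integer> runWorkers(int workers) throws InterruptedException {
        List<Integer> results = new CopyOnWriteArrayList<>();
        List<Thread> threads = new ArrayList<>();

        for (int e = 0; e < workers; e++) {
            final int worker = e;
            Thread t = new Thread(() -> results.add(worker)); // stand-in for per-device work
            t.start();
            threads.add(t); // keep the handle; join after the loop, not inside it
        }

        for (Thread t : threads)
            t.join(); // all workers have finished before any validation runs

        return results;
    }
}
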