From dc66a52bc7c6fcda7d37cd9fa1b9f85abaa673db Mon Sep 17 00:00:00 2001 From: shugeo Date: Fri, 29 Nov 2019 15:05:08 +0200 Subject: [PATCH] [WIP] Shugeo release fixes4 (#91) * Fixed fake_quant_with_min_max_vars op. * Refactored bitcast op. * bad linspace removed Signed-off-by: raver119 * Corrected tests for bitcast op. * Eliminated debug prints. * one fix Signed-off-by: raver119 * one fix Signed-off-by: raver119 * Added a pair of comments. --- libnd4j/include/array/DataBuffer.h | 2 + libnd4j/include/array/cpu/DataBuffer.cpp | 12 +++- libnd4j/include/array/cuda/DataBuffer.cu | 13 +++++ .../declarable/generic/datatypes/bitcast.cpp | 7 ++- .../layers_tests/DeclarableOpsTests15.cpp | 56 ++++++++++++++++++- 5 files changed, 83 insertions(+), 7 deletions(-) diff --git a/libnd4j/include/array/DataBuffer.h b/libnd4j/include/array/DataBuffer.h index 37d575b13..034f16a25 100644 --- a/libnd4j/include/array/DataBuffer.h +++ b/libnd4j/include/array/DataBuffer.h @@ -116,6 +116,8 @@ class ND4J_EXPORT DataBuffer { void setToZeroBuffers(const bool both = false); void copyBufferFrom(const DataBuffer& other, size_t sizeToCopyinBytes = 0, const Nd4jLong offsetThis = 0, const Nd4jLong offsetOther = 0); + + static void memcpy(const DataBuffer &dst, const DataBuffer &src); }; diff --git a/libnd4j/include/array/cpu/DataBuffer.cpp b/libnd4j/include/array/cpu/DataBuffer.cpp index 5d27bf9a1..d13ca0def 100644 --- a/libnd4j/include/array/cpu/DataBuffer.cpp +++ b/libnd4j/include/array/cpu/DataBuffer.cpp @@ -33,7 +33,6 @@ void DataBuffer::setCountersToZero() { void DataBuffer::copyCounters(const DataBuffer& other) { } - //////////////////////////////////////////////////////////////////////// void DataBuffer::allocateBuffers(const bool allocBoth) { // always allocate primary buffer only (cpu case) @@ -49,7 +48,7 @@ void DataBuffer::copyBufferFrom(const DataBuffer& other, size_t sizeToCopyinByte return; if(other._primaryBuffer != nullptr) - memcpy(static_cast(_primaryBuffer) + offsetThis * DataTypeUtils::sizeOfElement(_dataType), static_cast(other._primaryBuffer) + offsetOther * DataTypeUtils::sizeOfElement(other._dataType), sizeToCopyinBytes); + std::memcpy(static_cast(_primaryBuffer) + offsetThis * DataTypeUtils::sizeOfElement(_dataType), static_cast(other._primaryBuffer) + offsetOther * DataTypeUtils::sizeOfElement(other._dataType), sizeToCopyinBytes); } //////////////////////////////////////////////////////////////////////// @@ -61,7 +60,7 @@ void DataBuffer::copyBufferFromHost(const void* hostBuffer, size_t sizeToCopyinB return; if(hostBuffer != nullptr) - memcpy(static_cast(_primaryBuffer) + offsetThis * DataTypeUtils::sizeOfElement(_dataType), static_cast(hostBuffer) + offsetHostBuffer * DataTypeUtils::sizeOfElement(_dataType), sizeToCopyinBytes); + std::memcpy(static_cast(_primaryBuffer) + offsetThis * DataTypeUtils::sizeOfElement(_dataType), static_cast(hostBuffer) + offsetHostBuffer * DataTypeUtils::sizeOfElement(_dataType), sizeToCopyinBytes); } @@ -100,6 +99,13 @@ void DataBuffer::allocateSpecial() { void DataBuffer::migrate() { } +/////////////////////////////////////////////////////////////////////// +void DataBuffer::memcpy(const DataBuffer &dst, const DataBuffer &src) { + if (src._lenInBytes < dst._lenInBytes) + throw std::runtime_error("DataBuffer::memcpy: Source data buffer is smaller than destination"); + + std::memcpy(dst._primaryBuffer, src._primaryBuffer, dst._lenInBytes); +} //////////////////////////////////////////////////////////////////////// void DataBuffer::writePrimary() const { } diff --git a/libnd4j/include/array/cuda/DataBuffer.cu b/libnd4j/include/array/cuda/DataBuffer.cu index e71ed4b49..5cb227e69 100644 --- a/libnd4j/include/array/cuda/DataBuffer.cu +++ b/libnd4j/include/array/cuda/DataBuffer.cu @@ -97,6 +97,19 @@ void DataBuffer::copyCounters(const DataBuffer& other) { _readPrimary.store(other._writeSpecial); _readSpecial.store(other._writePrimary); } +//////////////////////////////////////////////////////////////////////// +void DataBuffer::memcpy(const DataBuffer &dst, const DataBuffer &src) { + if (src._lenInBytes < dst._lenInBytes) + throw std::runtime_error("DataBuffer::memcpy: Source data buffer is smaller than destination"); + + if (src.isSpecialActual()) { + cudaMemcpy(dst._specialBuffer, src._specialBuffer, dst.getLenInBytes(), cudaMemcpyDeviceToDevice); + } else if (src.isPrimaryActual()) { + cudaMemcpy(dst._specialBuffer, src._primaryBuffer, dst.getLenInBytes(), cudaMemcpyHostToDevice); + } + + dst.writeSpecial(); +} //////////////////////////////////////////////////////////////////////// void DataBuffer::copyBufferFrom(const DataBuffer& other, size_t sizeToCopyinBytes, const Nd4jLong offsetThis, const Nd4jLong offsetOther) { // copies only to special buffer diff --git a/libnd4j/include/ops/declarable/generic/datatypes/bitcast.cpp b/libnd4j/include/ops/declarable/generic/datatypes/bitcast.cpp index 533b4e2f9..24f96f7a7 100644 --- a/libnd4j/include/ops/declarable/generic/datatypes/bitcast.cpp +++ b/libnd4j/include/ops/declarable/generic/datatypes/bitcast.cpp @@ -45,9 +45,10 @@ namespace nd4j { REQUIRE_TRUE(output->isEmpty(), 0, "BITCAST: If input is empty, output array must also be empty."); return Status::OK(); } - // buffers for both input and output should be equals - DataBuffer buf(input->buffer(), input->specialBuffer(), input->lengthOf() * input->sizeOfT(), input->dataType()); - *(output->dataBuffer()) = buf; + + // just memcpy data +// output->dataBuffer()->copyBufferFrom(*input->dataBuffer()); // as variant + DataBuffer::memcpy(*output->dataBuffer(), *input->dataBuffer()); // this is modern approach return Status::OK(); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp index 50f8de9f0..fdaa7b549 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests15.cpp @@ -282,6 +282,60 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_4) { } + +TEST_F(DeclarableOpsTests15, Test_BitCast_5) { + auto x = NDArrayFactory::create('c', {4, 4}, { + 0.4922f, 0.2969f, 0.6172f, 0.8906f, + 0.9297f, 0.0859f, 0.2344f, 0.3828f, + 0.5781f, 0.7969f, 0.0391f, 0.1719f, + 0.8359f, 0.9297f, 0.3438f, 0.0938f}); + + auto e = NDArrayFactory::create('c', {4}, {4260467851820808160LL, 3900173902914993008LL, 3566895990128523424LL, + 3314989625590692528LL}); + nd4j::ops::bitcast op; + auto result = op.execute({&x}, {}, {nd4j::DataType::INT64}, {}); + ASSERT_EQ(Status::OK(), result->status()); + auto res = result->at(0); +// res->printIndexedBuffer("BITCAST5"); + ASSERT_TRUE(e.equalsTo(res)); + delete result; +} + +TEST_F(DeclarableOpsTests15, Test_BitCast_6) { + auto x = NDArrayFactory::create('c', {4, 4}, { + 1.f, 2.f, 3.f, 4.f, + 5.f, 6.f, 7.f, 8.f, + 9.f, 10.f, 11.f, 12.f, + 13.f, 14.f, 15.f, 16.f}); + + auto e = NDArrayFactory::create('c', {4}, {4899988963420290048LL, 5188224837230806272LL, 5332342774136064128LL, + 5476460161268730496LL}); + nd4j::ops::bitcast op; + auto result = op.execute({&x}, {}, {nd4j::DataType::INT64}, {}); + ASSERT_EQ(Status::OK(), result->status()); + auto res = result->at(0); +// res->printIndexedBuffer("BITCAST6"); + ASSERT_TRUE(e.equalsTo(res)); + delete result; +} +TEST_F(DeclarableOpsTests15, Test_BitCast_7) { + auto x = NDArrayFactory::create('c', {4, 4}, { + 1.1f, 2.2f, 3.3f, 4.4f, + 5.1f, 6.2f, 7.3f, 8.4f, + 9.1f, 10.2f, 11.3f, 12.4f, + 13.f, 14.2f, 15.3f, 16.4f}); + + auto e = NDArrayFactory::create('c', {4}, { + 4928700072476425318LL, 5202580391758873882LL, 5346698272827918477LL, 5483778673873668736LL}); + nd4j::ops::bitcast op; + auto result = op.execute({&x}, {}, {nd4j::DataType::INT64}, {}); + ASSERT_EQ(Status::OK(), result->status()); + auto res = result->at(0); +// res->printIndexedBuffer("BITCAST7"); + ASSERT_TRUE(e.equalsTo(res)); + delete result; +} + TEST_F(DeclarableOpsTests15, Test_depthwise_bp_1) { auto in = NDArrayFactory::create('c', {4, 8, 64, 64}); auto w = NDArrayFactory::create('c', {2, 2, 8, 2}); @@ -609,4 +663,4 @@ TEST_F(DeclarableOpsTests15, test_empty_decreasing_1) { ASSERT_EQ(Status::OK(), status); ASSERT_EQ(true, z.e(0)); -} \ No newline at end of file +}