[WIP] Shugeo release fixes4 (#91)

* Fixed fake_quant_with_min_max_vars op.

* Refactored bitcast op.

* bad linspace removed

Signed-off-by: raver119 <raver119@gmail.com>

* Corrected tests for bitcast op.

* Eliminated debug prints.

* one fix

Signed-off-by: raver119 <raver119@gmail.com>

* one fix

Signed-off-by: raver119 <raver119@gmail.com>

* Added a pair of comments.
master
shugeo 2019-11-29 15:05:08 +02:00 committed by raver119
parent d19eeaec52
commit dc66a52bc7
5 changed files with 83 additions and 7 deletions

View File

@ -116,6 +116,8 @@ class ND4J_EXPORT DataBuffer {
void setToZeroBuffers(const bool both = false);
void copyBufferFrom(const DataBuffer& other, size_t sizeToCopyinBytes = 0, const Nd4jLong offsetThis = 0, const Nd4jLong offsetOther = 0);
static void memcpy(const DataBuffer &dst, const DataBuffer &src);
};

View File

@ -33,7 +33,6 @@ void DataBuffer::setCountersToZero() {
void DataBuffer::copyCounters(const DataBuffer& other) {
}
////////////////////////////////////////////////////////////////////////
void DataBuffer::allocateBuffers(const bool allocBoth) { // always allocate primary buffer only (cpu case)
@ -49,7 +48,7 @@ void DataBuffer::copyBufferFrom(const DataBuffer& other, size_t sizeToCopyinByte
return;
if(other._primaryBuffer != nullptr)
memcpy(static_cast<int8_t*>(_primaryBuffer) + offsetThis * DataTypeUtils::sizeOfElement(_dataType), static_cast<const int8_t*>(other._primaryBuffer) + offsetOther * DataTypeUtils::sizeOfElement(other._dataType), sizeToCopyinBytes);
std::memcpy(static_cast<int8_t*>(_primaryBuffer) + offsetThis * DataTypeUtils::sizeOfElement(_dataType), static_cast<const int8_t*>(other._primaryBuffer) + offsetOther * DataTypeUtils::sizeOfElement(other._dataType), sizeToCopyinBytes);
}
////////////////////////////////////////////////////////////////////////
@ -61,7 +60,7 @@ void DataBuffer::copyBufferFromHost(const void* hostBuffer, size_t sizeToCopyinB
return;
if(hostBuffer != nullptr)
memcpy(static_cast<int8_t*>(_primaryBuffer) + offsetThis * DataTypeUtils::sizeOfElement(_dataType), static_cast<const int8_t*>(hostBuffer) + offsetHostBuffer * DataTypeUtils::sizeOfElement(_dataType), sizeToCopyinBytes);
std::memcpy(static_cast<int8_t*>(_primaryBuffer) + offsetThis * DataTypeUtils::sizeOfElement(_dataType), static_cast<const int8_t*>(hostBuffer) + offsetHostBuffer * DataTypeUtils::sizeOfElement(_dataType), sizeToCopyinBytes);
}
@ -100,6 +99,13 @@ void DataBuffer::allocateSpecial() {
void DataBuffer::migrate() {
}
///////////////////////////////////////////////////////////////////////
void DataBuffer::memcpy(const DataBuffer &dst, const DataBuffer &src) {
if (src._lenInBytes < dst._lenInBytes)
throw std::runtime_error("DataBuffer::memcpy: Source data buffer is smaller than destination");
std::memcpy(dst._primaryBuffer, src._primaryBuffer, dst._lenInBytes);
}
////////////////////////////////////////////////////////////////////////
void DataBuffer::writePrimary() const { }

View File

@ -97,6 +97,19 @@ void DataBuffer::copyCounters(const DataBuffer& other) {
_readPrimary.store(other._writeSpecial);
_readSpecial.store(other._writePrimary);
}
////////////////////////////////////////////////////////////////////////
void DataBuffer::memcpy(const DataBuffer &dst, const DataBuffer &src) {
if (src._lenInBytes < dst._lenInBytes)
throw std::runtime_error("DataBuffer::memcpy: Source data buffer is smaller than destination");
if (src.isSpecialActual()) {
cudaMemcpy(dst._specialBuffer, src._specialBuffer, dst.getLenInBytes(), cudaMemcpyDeviceToDevice);
} else if (src.isPrimaryActual()) {
cudaMemcpy(dst._specialBuffer, src._primaryBuffer, dst.getLenInBytes(), cudaMemcpyHostToDevice);
}
dst.writeSpecial();
}
////////////////////////////////////////////////////////////////////////
void DataBuffer::copyBufferFrom(const DataBuffer& other, size_t sizeToCopyinBytes, const Nd4jLong offsetThis, const Nd4jLong offsetOther) { // copies only to special buffer

View File

@ -45,9 +45,10 @@ namespace nd4j {
REQUIRE_TRUE(output->isEmpty(), 0, "BITCAST: If input is empty, output array must also be empty.");
return Status::OK();
}
// buffers for both input and output should be equals
DataBuffer buf(input->buffer(), input->specialBuffer(), input->lengthOf() * input->sizeOfT(), input->dataType());
*(output->dataBuffer()) = buf;
// just memcpy data
// output->dataBuffer()->copyBufferFrom(*input->dataBuffer()); // as variant
DataBuffer::memcpy(*output->dataBuffer(), *input->dataBuffer()); // this is modern approach
return Status::OK();
}

View File

@ -282,6 +282,60 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_4) {
}
TEST_F(DeclarableOpsTests15, Test_BitCast_5) {
auto x = NDArrayFactory::create<float16>('c', {4, 4}, {
0.4922f, 0.2969f, 0.6172f, 0.8906f,
0.9297f, 0.0859f, 0.2344f, 0.3828f,
0.5781f, 0.7969f, 0.0391f, 0.1719f,
0.8359f, 0.9297f, 0.3438f, 0.0938f});
auto e = NDArrayFactory::create<Nd4jLong>('c', {4}, {4260467851820808160LL, 3900173902914993008LL, 3566895990128523424LL,
3314989625590692528LL});
nd4j::ops::bitcast op;
auto result = op.execute({&x}, {}, {nd4j::DataType::INT64}, {});
ASSERT_EQ(Status::OK(), result->status());
auto res = result->at(0);
// res->printIndexedBuffer("BITCAST5");
ASSERT_TRUE(e.equalsTo(res));
delete result;
}
TEST_F(DeclarableOpsTests15, Test_BitCast_6) {
auto x = NDArrayFactory::create<float16>('c', {4, 4}, {
1.f, 2.f, 3.f, 4.f,
5.f, 6.f, 7.f, 8.f,
9.f, 10.f, 11.f, 12.f,
13.f, 14.f, 15.f, 16.f});
auto e = NDArrayFactory::create<Nd4jLong>('c', {4}, {4899988963420290048LL, 5188224837230806272LL, 5332342774136064128LL,
5476460161268730496LL});
nd4j::ops::bitcast op;
auto result = op.execute({&x}, {}, {nd4j::DataType::INT64}, {});
ASSERT_EQ(Status::OK(), result->status());
auto res = result->at(0);
// res->printIndexedBuffer("BITCAST6");
ASSERT_TRUE(e.equalsTo(res));
delete result;
}
TEST_F(DeclarableOpsTests15, Test_BitCast_7) {
auto x = NDArrayFactory::create<float16>('c', {4, 4}, {
1.1f, 2.2f, 3.3f, 4.4f,
5.1f, 6.2f, 7.3f, 8.4f,
9.1f, 10.2f, 11.3f, 12.4f,
13.f, 14.2f, 15.3f, 16.4f});
auto e = NDArrayFactory::create<Nd4jLong>('c', {4}, {
4928700072476425318LL, 5202580391758873882LL, 5346698272827918477LL, 5483778673873668736LL});
nd4j::ops::bitcast op;
auto result = op.execute({&x}, {}, {nd4j::DataType::INT64}, {});
ASSERT_EQ(Status::OK(), result->status());
auto res = result->at(0);
// res->printIndexedBuffer("BITCAST7");
ASSERT_TRUE(e.equalsTo(res));
delete result;
}
TEST_F(DeclarableOpsTests15, Test_depthwise_bp_1) {
auto in = NDArrayFactory::create<float>('c', {4, 8, 64, 64});
auto w = NDArrayFactory::create<float>('c', {2, 2, 8, 2});