[WIP] Shugeo release fixes4 (#91)
* Fixed fake_quant_with_min_max_vars op. * Refactored bitcast op. * bad linspace removed Signed-off-by: raver119 <raver119@gmail.com> * Corrected tests for bitcast op. * Eliminated debug prints. * one fix Signed-off-by: raver119 <raver119@gmail.com> * one fix Signed-off-by: raver119 <raver119@gmail.com> * Added a pair of comments.master
parent
d19eeaec52
commit
dc66a52bc7
|
@ -116,6 +116,8 @@ class ND4J_EXPORT DataBuffer {
|
||||||
void setToZeroBuffers(const bool both = false);
|
void setToZeroBuffers(const bool both = false);
|
||||||
|
|
||||||
void copyBufferFrom(const DataBuffer& other, size_t sizeToCopyinBytes = 0, const Nd4jLong offsetThis = 0, const Nd4jLong offsetOther = 0);
|
void copyBufferFrom(const DataBuffer& other, size_t sizeToCopyinBytes = 0, const Nd4jLong offsetThis = 0, const Nd4jLong offsetOther = 0);
|
||||||
|
|
||||||
|
static void memcpy(const DataBuffer &dst, const DataBuffer &src);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -33,7 +33,6 @@ void DataBuffer::setCountersToZero() {
|
||||||
void DataBuffer::copyCounters(const DataBuffer& other) {
|
void DataBuffer::copyCounters(const DataBuffer& other) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
void DataBuffer::allocateBuffers(const bool allocBoth) { // always allocate primary buffer only (cpu case)
|
void DataBuffer::allocateBuffers(const bool allocBoth) { // always allocate primary buffer only (cpu case)
|
||||||
|
|
||||||
|
@ -49,7 +48,7 @@ void DataBuffer::copyBufferFrom(const DataBuffer& other, size_t sizeToCopyinByte
|
||||||
return;
|
return;
|
||||||
|
|
||||||
if(other._primaryBuffer != nullptr)
|
if(other._primaryBuffer != nullptr)
|
||||||
memcpy(static_cast<int8_t*>(_primaryBuffer) + offsetThis * DataTypeUtils::sizeOfElement(_dataType), static_cast<const int8_t*>(other._primaryBuffer) + offsetOther * DataTypeUtils::sizeOfElement(other._dataType), sizeToCopyinBytes);
|
std::memcpy(static_cast<int8_t*>(_primaryBuffer) + offsetThis * DataTypeUtils::sizeOfElement(_dataType), static_cast<const int8_t*>(other._primaryBuffer) + offsetOther * DataTypeUtils::sizeOfElement(other._dataType), sizeToCopyinBytes);
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -61,7 +60,7 @@ void DataBuffer::copyBufferFromHost(const void* hostBuffer, size_t sizeToCopyinB
|
||||||
return;
|
return;
|
||||||
|
|
||||||
if(hostBuffer != nullptr)
|
if(hostBuffer != nullptr)
|
||||||
memcpy(static_cast<int8_t*>(_primaryBuffer) + offsetThis * DataTypeUtils::sizeOfElement(_dataType), static_cast<const int8_t*>(hostBuffer) + offsetHostBuffer * DataTypeUtils::sizeOfElement(_dataType), sizeToCopyinBytes);
|
std::memcpy(static_cast<int8_t*>(_primaryBuffer) + offsetThis * DataTypeUtils::sizeOfElement(_dataType), static_cast<const int8_t*>(hostBuffer) + offsetHostBuffer * DataTypeUtils::sizeOfElement(_dataType), sizeToCopyinBytes);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -100,6 +99,13 @@ void DataBuffer::allocateSpecial() {
|
||||||
void DataBuffer::migrate() {
|
void DataBuffer::migrate() {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
///////////////////////////////////////////////////////////////////////
|
||||||
|
void DataBuffer::memcpy(const DataBuffer &dst, const DataBuffer &src) {
|
||||||
|
if (src._lenInBytes < dst._lenInBytes)
|
||||||
|
throw std::runtime_error("DataBuffer::memcpy: Source data buffer is smaller than destination");
|
||||||
|
|
||||||
|
std::memcpy(dst._primaryBuffer, src._primaryBuffer, dst._lenInBytes);
|
||||||
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
void DataBuffer::writePrimary() const { }
|
void DataBuffer::writePrimary() const { }
|
||||||
|
|
|
@ -97,6 +97,19 @@ void DataBuffer::copyCounters(const DataBuffer& other) {
|
||||||
_readPrimary.store(other._writeSpecial);
|
_readPrimary.store(other._writeSpecial);
|
||||||
_readSpecial.store(other._writePrimary);
|
_readSpecial.store(other._writePrimary);
|
||||||
}
|
}
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
void DataBuffer::memcpy(const DataBuffer &dst, const DataBuffer &src) {
|
||||||
|
if (src._lenInBytes < dst._lenInBytes)
|
||||||
|
throw std::runtime_error("DataBuffer::memcpy: Source data buffer is smaller than destination");
|
||||||
|
|
||||||
|
if (src.isSpecialActual()) {
|
||||||
|
cudaMemcpy(dst._specialBuffer, src._specialBuffer, dst.getLenInBytes(), cudaMemcpyDeviceToDevice);
|
||||||
|
} else if (src.isPrimaryActual()) {
|
||||||
|
cudaMemcpy(dst._specialBuffer, src._primaryBuffer, dst.getLenInBytes(), cudaMemcpyHostToDevice);
|
||||||
|
}
|
||||||
|
|
||||||
|
dst.writeSpecial();
|
||||||
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
void DataBuffer::copyBufferFrom(const DataBuffer& other, size_t sizeToCopyinBytes, const Nd4jLong offsetThis, const Nd4jLong offsetOther) { // copies only to special buffer
|
void DataBuffer::copyBufferFrom(const DataBuffer& other, size_t sizeToCopyinBytes, const Nd4jLong offsetThis, const Nd4jLong offsetOther) { // copies only to special buffer
|
||||||
|
|
|
@ -45,9 +45,10 @@ namespace nd4j {
|
||||||
REQUIRE_TRUE(output->isEmpty(), 0, "BITCAST: If input is empty, output array must also be empty.");
|
REQUIRE_TRUE(output->isEmpty(), 0, "BITCAST: If input is empty, output array must also be empty.");
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
// buffers for both input and output should be equals
|
|
||||||
DataBuffer buf(input->buffer(), input->specialBuffer(), input->lengthOf() * input->sizeOfT(), input->dataType());
|
// just memcpy data
|
||||||
*(output->dataBuffer()) = buf;
|
// output->dataBuffer()->copyBufferFrom(*input->dataBuffer()); // as variant
|
||||||
|
DataBuffer::memcpy(*output->dataBuffer(), *input->dataBuffer()); // this is modern approach
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
|
@ -282,6 +282,60 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_4) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests15, Test_BitCast_5) {
|
||||||
|
auto x = NDArrayFactory::create<float16>('c', {4, 4}, {
|
||||||
|
0.4922f, 0.2969f, 0.6172f, 0.8906f,
|
||||||
|
0.9297f, 0.0859f, 0.2344f, 0.3828f,
|
||||||
|
0.5781f, 0.7969f, 0.0391f, 0.1719f,
|
||||||
|
0.8359f, 0.9297f, 0.3438f, 0.0938f});
|
||||||
|
|
||||||
|
auto e = NDArrayFactory::create<Nd4jLong>('c', {4}, {4260467851820808160LL, 3900173902914993008LL, 3566895990128523424LL,
|
||||||
|
3314989625590692528LL});
|
||||||
|
nd4j::ops::bitcast op;
|
||||||
|
auto result = op.execute({&x}, {}, {nd4j::DataType::INT64}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
auto res = result->at(0);
|
||||||
|
// res->printIndexedBuffer("BITCAST5");
|
||||||
|
ASSERT_TRUE(e.equalsTo(res));
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests15, Test_BitCast_6) {
|
||||||
|
auto x = NDArrayFactory::create<float16>('c', {4, 4}, {
|
||||||
|
1.f, 2.f, 3.f, 4.f,
|
||||||
|
5.f, 6.f, 7.f, 8.f,
|
||||||
|
9.f, 10.f, 11.f, 12.f,
|
||||||
|
13.f, 14.f, 15.f, 16.f});
|
||||||
|
|
||||||
|
auto e = NDArrayFactory::create<Nd4jLong>('c', {4}, {4899988963420290048LL, 5188224837230806272LL, 5332342774136064128LL,
|
||||||
|
5476460161268730496LL});
|
||||||
|
nd4j::ops::bitcast op;
|
||||||
|
auto result = op.execute({&x}, {}, {nd4j::DataType::INT64}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
auto res = result->at(0);
|
||||||
|
// res->printIndexedBuffer("BITCAST6");
|
||||||
|
ASSERT_TRUE(e.equalsTo(res));
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
TEST_F(DeclarableOpsTests15, Test_BitCast_7) {
|
||||||
|
auto x = NDArrayFactory::create<float16>('c', {4, 4}, {
|
||||||
|
1.1f, 2.2f, 3.3f, 4.4f,
|
||||||
|
5.1f, 6.2f, 7.3f, 8.4f,
|
||||||
|
9.1f, 10.2f, 11.3f, 12.4f,
|
||||||
|
13.f, 14.2f, 15.3f, 16.4f});
|
||||||
|
|
||||||
|
auto e = NDArrayFactory::create<Nd4jLong>('c', {4}, {
|
||||||
|
4928700072476425318LL, 5202580391758873882LL, 5346698272827918477LL, 5483778673873668736LL});
|
||||||
|
nd4j::ops::bitcast op;
|
||||||
|
auto result = op.execute({&x}, {}, {nd4j::DataType::INT64}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
auto res = result->at(0);
|
||||||
|
// res->printIndexedBuffer("BITCAST7");
|
||||||
|
ASSERT_TRUE(e.equalsTo(res));
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests15, Test_depthwise_bp_1) {
|
TEST_F(DeclarableOpsTests15, Test_depthwise_bp_1) {
|
||||||
auto in = NDArrayFactory::create<float>('c', {4, 8, 64, 64});
|
auto in = NDArrayFactory::create<float>('c', {4, 8, 64, 64});
|
||||||
auto w = NDArrayFactory::create<float>('c', {2, 2, 8, 2});
|
auto w = NDArrayFactory::create<float>('c', {2, 2, 8, 2});
|
||||||
|
|
Loading…
Reference in New Issue