[WIP] Shugeo release fixes4 (#91)
* Fixed fake_quant_with_min_max_vars op. * Refactored bitcast op. * bad linspace removed Signed-off-by: raver119 <raver119@gmail.com> * Corrected tests for bitcast op. * Eliminated debug prints. * one fix Signed-off-by: raver119 <raver119@gmail.com> * one fix Signed-off-by: raver119 <raver119@gmail.com> * Added a pair of comments.master
parent
d19eeaec52
commit
dc66a52bc7
|
@ -116,6 +116,8 @@ class ND4J_EXPORT DataBuffer {
|
|||
void setToZeroBuffers(const bool both = false);
|
||||
|
||||
void copyBufferFrom(const DataBuffer& other, size_t sizeToCopyinBytes = 0, const Nd4jLong offsetThis = 0, const Nd4jLong offsetOther = 0);
|
||||
|
||||
static void memcpy(const DataBuffer &dst, const DataBuffer &src);
|
||||
};
|
||||
|
||||
|
||||
|
|
|
@ -33,7 +33,6 @@ void DataBuffer::setCountersToZero() {
|
|||
void DataBuffer::copyCounters(const DataBuffer& other) {
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
void DataBuffer::allocateBuffers(const bool allocBoth) { // always allocate primary buffer only (cpu case)
|
||||
|
||||
|
@ -49,7 +48,7 @@ void DataBuffer::copyBufferFrom(const DataBuffer& other, size_t sizeToCopyinByte
|
|||
return;
|
||||
|
||||
if(other._primaryBuffer != nullptr)
|
||||
memcpy(static_cast<int8_t*>(_primaryBuffer) + offsetThis * DataTypeUtils::sizeOfElement(_dataType), static_cast<const int8_t*>(other._primaryBuffer) + offsetOther * DataTypeUtils::sizeOfElement(other._dataType), sizeToCopyinBytes);
|
||||
std::memcpy(static_cast<int8_t*>(_primaryBuffer) + offsetThis * DataTypeUtils::sizeOfElement(_dataType), static_cast<const int8_t*>(other._primaryBuffer) + offsetOther * DataTypeUtils::sizeOfElement(other._dataType), sizeToCopyinBytes);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
|
@ -61,7 +60,7 @@ void DataBuffer::copyBufferFromHost(const void* hostBuffer, size_t sizeToCopyinB
|
|||
return;
|
||||
|
||||
if(hostBuffer != nullptr)
|
||||
memcpy(static_cast<int8_t*>(_primaryBuffer) + offsetThis * DataTypeUtils::sizeOfElement(_dataType), static_cast<const int8_t*>(hostBuffer) + offsetHostBuffer * DataTypeUtils::sizeOfElement(_dataType), sizeToCopyinBytes);
|
||||
std::memcpy(static_cast<int8_t*>(_primaryBuffer) + offsetThis * DataTypeUtils::sizeOfElement(_dataType), static_cast<const int8_t*>(hostBuffer) + offsetHostBuffer * DataTypeUtils::sizeOfElement(_dataType), sizeToCopyinBytes);
|
||||
}
|
||||
|
||||
|
||||
|
@ -100,6 +99,13 @@ void DataBuffer::allocateSpecial() {
|
|||
void DataBuffer::migrate() {
|
||||
|
||||
}
|
||||
///////////////////////////////////////////////////////////////////////
|
||||
void DataBuffer::memcpy(const DataBuffer &dst, const DataBuffer &src) {
|
||||
if (src._lenInBytes < dst._lenInBytes)
|
||||
throw std::runtime_error("DataBuffer::memcpy: Source data buffer is smaller than destination");
|
||||
|
||||
std::memcpy(dst._primaryBuffer, src._primaryBuffer, dst._lenInBytes);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
void DataBuffer::writePrimary() const { }
|
||||
|
|
|
@ -97,6 +97,19 @@ void DataBuffer::copyCounters(const DataBuffer& other) {
|
|||
_readPrimary.store(other._writeSpecial);
|
||||
_readSpecial.store(other._writePrimary);
|
||||
}
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
void DataBuffer::memcpy(const DataBuffer &dst, const DataBuffer &src) {
|
||||
if (src._lenInBytes < dst._lenInBytes)
|
||||
throw std::runtime_error("DataBuffer::memcpy: Source data buffer is smaller than destination");
|
||||
|
||||
if (src.isSpecialActual()) {
|
||||
cudaMemcpy(dst._specialBuffer, src._specialBuffer, dst.getLenInBytes(), cudaMemcpyDeviceToDevice);
|
||||
} else if (src.isPrimaryActual()) {
|
||||
cudaMemcpy(dst._specialBuffer, src._primaryBuffer, dst.getLenInBytes(), cudaMemcpyHostToDevice);
|
||||
}
|
||||
|
||||
dst.writeSpecial();
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////
|
||||
void DataBuffer::copyBufferFrom(const DataBuffer& other, size_t sizeToCopyinBytes, const Nd4jLong offsetThis, const Nd4jLong offsetOther) { // copies only to special buffer
|
||||
|
|
|
@ -45,9 +45,10 @@ namespace nd4j {
|
|||
REQUIRE_TRUE(output->isEmpty(), 0, "BITCAST: If input is empty, output array must also be empty.");
|
||||
return Status::OK();
|
||||
}
|
||||
// buffers for both input and output should be equals
|
||||
DataBuffer buf(input->buffer(), input->specialBuffer(), input->lengthOf() * input->sizeOfT(), input->dataType());
|
||||
*(output->dataBuffer()) = buf;
|
||||
|
||||
// just memcpy data
|
||||
// output->dataBuffer()->copyBufferFrom(*input->dataBuffer()); // as variant
|
||||
DataBuffer::memcpy(*output->dataBuffer(), *input->dataBuffer()); // this is modern approach
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -282,6 +282,60 @@ TEST_F(DeclarableOpsTests15, Test_BitCast_4) {
|
|||
|
||||
}
|
||||
|
||||
|
||||
TEST_F(DeclarableOpsTests15, Test_BitCast_5) {
|
||||
auto x = NDArrayFactory::create<float16>('c', {4, 4}, {
|
||||
0.4922f, 0.2969f, 0.6172f, 0.8906f,
|
||||
0.9297f, 0.0859f, 0.2344f, 0.3828f,
|
||||
0.5781f, 0.7969f, 0.0391f, 0.1719f,
|
||||
0.8359f, 0.9297f, 0.3438f, 0.0938f});
|
||||
|
||||
auto e = NDArrayFactory::create<Nd4jLong>('c', {4}, {4260467851820808160LL, 3900173902914993008LL, 3566895990128523424LL,
|
||||
3314989625590692528LL});
|
||||
nd4j::ops::bitcast op;
|
||||
auto result = op.execute({&x}, {}, {nd4j::DataType::INT64}, {});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
auto res = result->at(0);
|
||||
// res->printIndexedBuffer("BITCAST5");
|
||||
ASSERT_TRUE(e.equalsTo(res));
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests15, Test_BitCast_6) {
|
||||
auto x = NDArrayFactory::create<float16>('c', {4, 4}, {
|
||||
1.f, 2.f, 3.f, 4.f,
|
||||
5.f, 6.f, 7.f, 8.f,
|
||||
9.f, 10.f, 11.f, 12.f,
|
||||
13.f, 14.f, 15.f, 16.f});
|
||||
|
||||
auto e = NDArrayFactory::create<Nd4jLong>('c', {4}, {4899988963420290048LL, 5188224837230806272LL, 5332342774136064128LL,
|
||||
5476460161268730496LL});
|
||||
nd4j::ops::bitcast op;
|
||||
auto result = op.execute({&x}, {}, {nd4j::DataType::INT64}, {});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
auto res = result->at(0);
|
||||
// res->printIndexedBuffer("BITCAST6");
|
||||
ASSERT_TRUE(e.equalsTo(res));
|
||||
delete result;
|
||||
}
|
||||
TEST_F(DeclarableOpsTests15, Test_BitCast_7) {
|
||||
auto x = NDArrayFactory::create<float16>('c', {4, 4}, {
|
||||
1.1f, 2.2f, 3.3f, 4.4f,
|
||||
5.1f, 6.2f, 7.3f, 8.4f,
|
||||
9.1f, 10.2f, 11.3f, 12.4f,
|
||||
13.f, 14.2f, 15.3f, 16.4f});
|
||||
|
||||
auto e = NDArrayFactory::create<Nd4jLong>('c', {4}, {
|
||||
4928700072476425318LL, 5202580391758873882LL, 5346698272827918477LL, 5483778673873668736LL});
|
||||
nd4j::ops::bitcast op;
|
||||
auto result = op.execute({&x}, {}, {nd4j::DataType::INT64}, {});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
auto res = result->at(0);
|
||||
// res->printIndexedBuffer("BITCAST7");
|
||||
ASSERT_TRUE(e.equalsTo(res));
|
||||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests15, Test_depthwise_bp_1) {
|
||||
auto in = NDArrayFactory::create<float>('c', {4, 8, 64, 64});
|
||||
auto w = NDArrayFactory::create<float>('c', {2, 2, 8, 2});
|
||||
|
@ -609,4 +663,4 @@ TEST_F(DeclarableOpsTests15, test_empty_decreasing_1) {
|
|||
ASSERT_EQ(Status::OK(), status);
|
||||
|
||||
ASSERT_EQ(true, z.e<bool>(0));
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue