- string NDArray flat serde impl + tests (#163)
- string NDArray equalsTo impl Signed-off-by: raver119 <raver119@gmail.com>master
parent
a9b08cc163
commit
b091e972ef
|
@ -476,19 +476,36 @@ std::vector<Nd4jLong> NDArray::getShapeInfoAsVector() {
|
|||
////////////////////////////////////////////////////////////////////////
|
||||
std::vector<int8_t> NDArray::asByteVector() {
|
||||
|
||||
std::vector<int8_t> result((unsigned long long) this->lengthOf() * sizeOfT());
|
||||
|
||||
if (this->isView()) {
|
||||
auto tmp = this->dup(this->ordering());
|
||||
|
||||
memcpy(result.data(), tmp->getBuffer(), (unsigned long long) lengthOf() * sizeOfT());
|
||||
if (isS()) {
|
||||
// string data type requires special treatment
|
||||
syncToHost();
|
||||
auto numWords = this->lengthOf();
|
||||
auto offsetsBuffer = this->bufferAsT<Nd4jLong>();
|
||||
auto headerLength = ShapeUtils::stringBufferHeaderRequirements(numWords);
|
||||
auto dataLength = offsetsBuffer[numWords];
|
||||
std::vector<int8_t> result(headerLength + dataLength);
|
||||
|
||||
delete tmp;
|
||||
memcpy(result.data(), getBuffer(), headerLength + dataLength);
|
||||
|
||||
return result;
|
||||
} else {
|
||||
// all other types are linear
|
||||
std::vector<int8_t> result((unsigned long long) this->lengthOf() * sizeOfT());
|
||||
|
||||
if (this->isView()) {
|
||||
auto tmp = this->dup(this->ordering());
|
||||
syncToHost();
|
||||
memcpy(result.data(), tmp->getBuffer(), (unsigned long long) lengthOf() * sizeOfT());
|
||||
|
||||
delete tmp;
|
||||
} else {
|
||||
syncToHost();
|
||||
memcpy(result.data(), getBuffer(), (unsigned long long) lengthOf() * sizeOfT());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
else {
|
||||
memcpy(result.data(), getBuffer(), (unsigned long long) lengthOf() * sizeOfT());
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
|
@ -1584,9 +1601,7 @@ std::string* NDArray::bufferAsT() const {
|
|||
//////////////////////////////////////////////////////////////////////////
|
||||
template <typename T>
|
||||
T* NDArray::bufferAsT() const {
|
||||
if (isS())
|
||||
throw std::runtime_error("You can't use this method on String array");
|
||||
|
||||
// FIXME: do we REALLY want sync here?
|
||||
syncToHost();
|
||||
|
||||
return reinterpret_cast<T*>(getBuffer());
|
||||
|
@ -3202,20 +3217,39 @@ bool NDArray::equalsTo(const NDArray *other, double eps) const {
|
|||
} else if (!shape::equalsSoft(getShapeInfo(), other->getShapeInfo()))
|
||||
return false;
|
||||
|
||||
NDArray tmp(nd4j::DataType::FLOAT32, getContext()); // scalar = 0
|
||||
if (isS()) {
|
||||
// string is special case, we'll compare them one by one, considering both arrays are guaranteed to have the same length
|
||||
for (int e = 0; e < this->lengthOf(); e++) {
|
||||
auto s1 = this->e<std::string>(e);
|
||||
auto s2 = other->e<std::string>(e);
|
||||
|
||||
ExtraArguments extras({eps});
|
||||
if (s1 != s2)
|
||||
return false;
|
||||
}
|
||||
|
||||
NDArray::prepareSpecialUse({&tmp}, {this, other});
|
||||
NativeOpExecutioner::execReduce3Scalar(getContext(), reduce3::EqualsWithEps, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extras.argumentsAsT(DataType::FLOAT32), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo());
|
||||
NDArray::registerSpecialUse({&tmp}, {this, other});
|
||||
return true;
|
||||
} else {
|
||||
// regular numeric types
|
||||
NDArray tmp(nd4j::DataType::FLOAT32, getContext()); // scalar = 0
|
||||
|
||||
synchronize("NDArray::equalsTo");
|
||||
ExtraArguments extras({eps});
|
||||
|
||||
if (tmp.e<int>(0) > 0)
|
||||
return false;
|
||||
NDArray::prepareSpecialUse({&tmp}, {this, other});
|
||||
NativeOpExecutioner::execReduce3Scalar(getContext(), reduce3::EqualsWithEps, getBuffer(), getShapeInfo(),
|
||||
getSpecialBuffer(), getSpecialShapeInfo(),
|
||||
extras.argumentsAsT(DataType::FLOAT32), other->getBuffer(),
|
||||
other->getShapeInfo(), other->getSpecialBuffer(),
|
||||
other->getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(),
|
||||
tmp.specialBuffer(), tmp.specialShapeInfo());
|
||||
NDArray::registerSpecialUse({&tmp}, {this, other});
|
||||
|
||||
return true;
|
||||
synchronize("NDArray::equalsTo");
|
||||
|
||||
if (tmp.e<int>(0) > 0)
|
||||
return false;
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
@ -54,6 +54,7 @@
|
|||
#include <graph/ExecutionResult.h>
|
||||
#include <exceptions/graph_execution_exception.h>
|
||||
#include <exceptions/no_results_exception.h>
|
||||
#include <graph/FlatUtils.h>
|
||||
|
||||
namespace nd4j{
|
||||
namespace graph {
|
||||
|
@ -575,15 +576,9 @@ Nd4jStatus GraphExecutioner::execute(Graph *graph, VariableSpace* variableSpace)
|
|||
continue;
|
||||
|
||||
|
||||
NDArray* array = var->getNDArray();
|
||||
auto byteVector = array->asByteVector();
|
||||
auto array = var->getNDArray();
|
||||
|
||||
auto fBuffer = builder.CreateVector(byteVector);
|
||||
auto fShape = builder.CreateVector(array->getShapeInfoAsFlatVector());
|
||||
|
||||
auto bo = static_cast<nd4j::graph::ByteOrder>(BitwiseUtils::asByteOrder());
|
||||
|
||||
auto fArray = CreateFlatArray(builder, fShape, fBuffer, static_cast<nd4j::graph::DataType>(array->dataType()), bo);
|
||||
auto fArray = FlatUtils::toFlatArray(builder, *array);
|
||||
|
||||
auto fName = builder.CreateString(*(var->getName()));
|
||||
auto id = CreateIntPair(builder, var->id(), var->index());
|
||||
|
|
|
@ -36,6 +36,8 @@ namespace nd4j {
|
|||
static std::pair<Nd4jLong, Nd4jLong> fromLongPair(LongPair* pair);
|
||||
|
||||
static NDArray* fromFlatArray(const nd4j::graph::FlatArray* flatArray);
|
||||
|
||||
static flatbuffers::Offset<FlatArray> toFlatArray(flatbuffers::FlatBufferBuilder &builder, NDArray &array);
|
||||
};
|
||||
}
|
||||
}
|
||||
|
|
|
@ -102,5 +102,16 @@ namespace nd4j {
|
|||
delete[] newShape;
|
||||
return array;
|
||||
}
|
||||
|
||||
flatbuffers::Offset<FlatArray> FlatUtils::toFlatArray(flatbuffers::FlatBufferBuilder &builder, NDArray &array) {
|
||||
auto byteVector = array.asByteVector();
|
||||
|
||||
auto fBuffer = builder.CreateVector(byteVector);
|
||||
auto fShape = builder.CreateVector(array.getShapeInfoAsFlatVector());
|
||||
|
||||
auto bo = static_cast<nd4j::graph::ByteOrder>(BitwiseUtils::asByteOrder());
|
||||
|
||||
return CreateFlatArray(builder, fShape, fBuffer, static_cast<nd4j::graph::DataType>(array.dataType()), bo);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,100 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#include <NDArray.h>
|
||||
#include <NDArrayFactory.h>
|
||||
#include "testlayers.h"
|
||||
#include <graph/Stash.h>
|
||||
#include <FlatUtils.h>
|
||||
|
||||
using namespace nd4j;
|
||||
|
||||
class FlatUtilsTests : public testing::Test {
|
||||
public:
|
||||
|
||||
};
|
||||
|
||||
TEST_F(FlatUtilsTests, flat_float_serde_1) {
|
||||
auto array = NDArrayFactory::create<float>('c', {4}, {1.f, 2.f, 3.f, 4.f});
|
||||
|
||||
flatbuffers::FlatBufferBuilder builder(1024);
|
||||
auto flatArray = FlatUtils::toFlatArray(builder, array);
|
||||
builder.Finish(flatArray);
|
||||
|
||||
|
||||
auto pfArray = GetFlatArray(builder.GetBufferPointer());
|
||||
|
||||
auto restored = FlatUtils::fromFlatArray(pfArray);
|
||||
|
||||
ASSERT_EQ(array, *restored);
|
||||
|
||||
delete restored;
|
||||
}
|
||||
|
||||
TEST_F(FlatUtilsTests, flat_int_serde_1) {
|
||||
auto array = NDArrayFactory::create<int>('c', {4}, {1, 2, 3, 4});
|
||||
|
||||
flatbuffers::FlatBufferBuilder builder(1024);
|
||||
auto flatArray = FlatUtils::toFlatArray(builder, array);
|
||||
builder.Finish(flatArray);
|
||||
|
||||
|
||||
auto pfArray = GetFlatArray(builder.GetBufferPointer());
|
||||
|
||||
auto restored = FlatUtils::fromFlatArray(pfArray);
|
||||
|
||||
ASSERT_EQ(array, *restored);
|
||||
|
||||
delete restored;
|
||||
}
|
||||
|
||||
TEST_F(FlatUtilsTests, flat_bool_serde_1) {
|
||||
auto array = NDArrayFactory::create<bool>('c', {4}, {true, false, true, false});
|
||||
|
||||
flatbuffers::FlatBufferBuilder builder(1024);
|
||||
auto flatArray = FlatUtils::toFlatArray(builder, array);
|
||||
builder.Finish(flatArray);
|
||||
|
||||
|
||||
auto pfArray = GetFlatArray(builder.GetBufferPointer());
|
||||
|
||||
auto restored = FlatUtils::fromFlatArray(pfArray);
|
||||
|
||||
ASSERT_EQ(array, *restored);
|
||||
|
||||
delete restored;
|
||||
}
|
||||
|
||||
TEST_F(FlatUtilsTests, flat_string_serde_1) {
|
||||
auto array = NDArrayFactory::string('c', {3}, {"alpha", "beta", "gamma"});
|
||||
|
||||
flatbuffers::FlatBufferBuilder builder(1024);
|
||||
auto flatArray = FlatUtils::toFlatArray(builder, array);
|
||||
builder.Finish(flatArray);
|
||||
|
||||
|
||||
auto pfArray = GetFlatArray(builder.GetBufferPointer());
|
||||
|
||||
auto restored = FlatUtils::fromFlatArray(pfArray);
|
||||
|
||||
ASSERT_EQ(array, *restored);
|
||||
|
||||
delete restored;
|
||||
}
|
|
@ -24,7 +24,6 @@
|
|||
#include "testlayers.h"
|
||||
#include <graph/Stash.h>
|
||||
|
||||
using namespace nd4j;
|
||||
using namespace nd4j;
|
||||
|
||||
class StringTests : public testing::Test {
|
||||
|
@ -91,4 +90,4 @@ TEST_F(StringTests, Basic_dup_1) {
|
|||
ASSERT_EQ(f, z1);
|
||||
|
||||
delete dup;
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue