- string NDArray flat serde impl + tests (#163)

- string NDArray equalsTo impl

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-08-24 14:16:34 +03:00 committed by GitHub
parent a9b08cc163
commit b091e972ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 172 additions and 31 deletions

View File

@ -476,20 +476,37 @@ std::vector<Nd4jLong> NDArray::getShapeInfoAsVector() {
////////////////////////////////////////////////////////////////////////
std::vector<int8_t> NDArray::asByteVector() {
if (isS()) {
// string data type requires special treatment
syncToHost();
auto numWords = this->lengthOf();
auto offsetsBuffer = this->bufferAsT<Nd4jLong>();
auto headerLength = ShapeUtils::stringBufferHeaderRequirements(numWords);
auto dataLength = offsetsBuffer[numWords];
std::vector<int8_t> result(headerLength + dataLength);
memcpy(result.data(), getBuffer(), headerLength + dataLength);
return result;
} else {
// all other types are linear
std::vector<int8_t> result((unsigned long long) this->lengthOf() * sizeOfT());
if (this->isView()) {
auto tmp = this->dup(this->ordering());
syncToHost();
memcpy(result.data(), tmp->getBuffer(), (unsigned long long) lengthOf() * sizeOfT());
delete tmp;
}
else {
} else {
syncToHost();
memcpy(result.data(), getBuffer(), (unsigned long long) lengthOf() * sizeOfT());
}
return result;
}
}
//////////////////////////////////////////////////////////////////////////
void NDArray::linspace(const double start) {
@ -1584,9 +1601,7 @@ std::string* NDArray::bufferAsT() const {
//////////////////////////////////////////////////////////////////////////
template <typename T>
T* NDArray::bufferAsT() const {
if (isS())
throw std::runtime_error("You can't use this method on String array");
// FIXME: do we REALLY want sync here?
syncToHost();
return reinterpret_cast<T*>(getBuffer());
@ -3202,12 +3217,30 @@ bool NDArray::equalsTo(const NDArray *other, double eps) const {
} else if (!shape::equalsSoft(getShapeInfo(), other->getShapeInfo()))
return false;
if (isS()) {
// string is special case, we'll compare them one by one, considering both arrays are guaranteed to have the same length
for (int e = 0; e < this->lengthOf(); e++) {
auto s1 = this->e<std::string>(e);
auto s2 = other->e<std::string>(e);
if (s1 != s2)
return false;
}
return true;
} else {
// regular numeric types
NDArray tmp(nd4j::DataType::FLOAT32, getContext()); // scalar = 0
ExtraArguments extras({eps});
NDArray::prepareSpecialUse({&tmp}, {this, other});
NativeOpExecutioner::execReduce3Scalar(getContext(), reduce3::EqualsWithEps, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extras.argumentsAsT(DataType::FLOAT32), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo());
NativeOpExecutioner::execReduce3Scalar(getContext(), reduce3::EqualsWithEps, getBuffer(), getShapeInfo(),
getSpecialBuffer(), getSpecialShapeInfo(),
extras.argumentsAsT(DataType::FLOAT32), other->getBuffer(),
other->getShapeInfo(), other->getSpecialBuffer(),
other->getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(),
tmp.specialBuffer(), tmp.specialShapeInfo());
NDArray::registerSpecialUse({&tmp}, {this, other});
synchronize("NDArray::equalsTo");
@ -3217,6 +3250,7 @@ bool NDArray::equalsTo(const NDArray *other, double eps) const {
return true;
}
}
//////////////////////////////////////////////////////////////////////////
template <>

View File

@ -54,6 +54,7 @@
#include <graph/ExecutionResult.h>
#include <exceptions/graph_execution_exception.h>
#include <exceptions/no_results_exception.h>
#include <graph/FlatUtils.h>
namespace nd4j{
namespace graph {
@ -575,15 +576,9 @@ Nd4jStatus GraphExecutioner::execute(Graph *graph, VariableSpace* variableSpace)
continue;
NDArray* array = var->getNDArray();
auto byteVector = array->asByteVector();
auto array = var->getNDArray();
auto fBuffer = builder.CreateVector(byteVector);
auto fShape = builder.CreateVector(array->getShapeInfoAsFlatVector());
auto bo = static_cast<nd4j::graph::ByteOrder>(BitwiseUtils::asByteOrder());
auto fArray = CreateFlatArray(builder, fShape, fBuffer, static_cast<nd4j::graph::DataType>(array->dataType()), bo);
auto fArray = FlatUtils::toFlatArray(builder, *array);
auto fName = builder.CreateString(*(var->getName()));
auto id = CreateIntPair(builder, var->id(), var->index());

View File

@ -36,6 +36,8 @@ namespace nd4j {
static std::pair<Nd4jLong, Nd4jLong> fromLongPair(LongPair* pair);
static NDArray* fromFlatArray(const nd4j::graph::FlatArray* flatArray);
static flatbuffers::Offset<FlatArray> toFlatArray(flatbuffers::FlatBufferBuilder &builder, NDArray &array);
};
}
}

View File

@ -102,5 +102,16 @@ namespace nd4j {
delete[] newShape;
return array;
}
flatbuffers::Offset<FlatArray> FlatUtils::toFlatArray(flatbuffers::FlatBufferBuilder &builder, NDArray &array) {
auto byteVector = array.asByteVector();
auto fBuffer = builder.CreateVector(byteVector);
auto fShape = builder.CreateVector(array.getShapeInfoAsFlatVector());
auto bo = static_cast<nd4j::graph::ByteOrder>(BitwiseUtils::asByteOrder());
return CreateFlatArray(builder, fShape, fBuffer, static_cast<nd4j::graph::DataType>(array.dataType()), bo);
}
}
}

View File

@ -0,0 +1,100 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <NDArray.h>
#include <NDArrayFactory.h>
#include "testlayers.h"
#include <graph/Stash.h>
#include <FlatUtils.h>
using namespace nd4j;
class FlatUtilsTests : public testing::Test {
public:
};
TEST_F(FlatUtilsTests, flat_float_serde_1) {
auto array = NDArrayFactory::create<float>('c', {4}, {1.f, 2.f, 3.f, 4.f});
flatbuffers::FlatBufferBuilder builder(1024);
auto flatArray = FlatUtils::toFlatArray(builder, array);
builder.Finish(flatArray);
auto pfArray = GetFlatArray(builder.GetBufferPointer());
auto restored = FlatUtils::fromFlatArray(pfArray);
ASSERT_EQ(array, *restored);
delete restored;
}
TEST_F(FlatUtilsTests, flat_int_serde_1) {
auto array = NDArrayFactory::create<int>('c', {4}, {1, 2, 3, 4});
flatbuffers::FlatBufferBuilder builder(1024);
auto flatArray = FlatUtils::toFlatArray(builder, array);
builder.Finish(flatArray);
auto pfArray = GetFlatArray(builder.GetBufferPointer());
auto restored = FlatUtils::fromFlatArray(pfArray);
ASSERT_EQ(array, *restored);
delete restored;
}
TEST_F(FlatUtilsTests, flat_bool_serde_1) {
auto array = NDArrayFactory::create<bool>('c', {4}, {true, false, true, false});
flatbuffers::FlatBufferBuilder builder(1024);
auto flatArray = FlatUtils::toFlatArray(builder, array);
builder.Finish(flatArray);
auto pfArray = GetFlatArray(builder.GetBufferPointer());
auto restored = FlatUtils::fromFlatArray(pfArray);
ASSERT_EQ(array, *restored);
delete restored;
}
TEST_F(FlatUtilsTests, flat_string_serde_1) {
auto array = NDArrayFactory::string('c', {3}, {"alpha", "beta", "gamma"});
flatbuffers::FlatBufferBuilder builder(1024);
auto flatArray = FlatUtils::toFlatArray(builder, array);
builder.Finish(flatArray);
auto pfArray = GetFlatArray(builder.GetBufferPointer());
auto restored = FlatUtils::fromFlatArray(pfArray);
ASSERT_EQ(array, *restored);
delete restored;
}

View File

@ -24,7 +24,6 @@
#include "testlayers.h"
#include <graph/Stash.h>
using namespace nd4j;
using namespace nd4j;
class StringTests : public testing::Test {