- string NDArray flat serde impl + tests (#163)
- string NDArray equalsTo impl Signed-off-by: raver119 <raver119@gmail.com>master
parent
a9b08cc163
commit
b091e972ef
|
@ -476,20 +476,37 @@ std::vector<Nd4jLong> NDArray::getShapeInfoAsVector() {
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
std::vector<int8_t> NDArray::asByteVector() {
|
std::vector<int8_t> NDArray::asByteVector() {
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if (isS()) {
|
||||||
|
// string data type requires special treatment
|
||||||
|
syncToHost();
|
||||||
|
auto numWords = this->lengthOf();
|
||||||
|
auto offsetsBuffer = this->bufferAsT<Nd4jLong>();
|
||||||
|
auto headerLength = ShapeUtils::stringBufferHeaderRequirements(numWords);
|
||||||
|
auto dataLength = offsetsBuffer[numWords];
|
||||||
|
std::vector<int8_t> result(headerLength + dataLength);
|
||||||
|
|
||||||
|
memcpy(result.data(), getBuffer(), headerLength + dataLength);
|
||||||
|
|
||||||
|
return result;
|
||||||
|
} else {
|
||||||
|
// all other types are linear
|
||||||
std::vector<int8_t> result((unsigned long long) this->lengthOf() * sizeOfT());
|
std::vector<int8_t> result((unsigned long long) this->lengthOf() * sizeOfT());
|
||||||
|
|
||||||
if (this->isView()) {
|
if (this->isView()) {
|
||||||
auto tmp = this->dup(this->ordering());
|
auto tmp = this->dup(this->ordering());
|
||||||
|
syncToHost();
|
||||||
memcpy(result.data(), tmp->getBuffer(), (unsigned long long) lengthOf() * sizeOfT());
|
memcpy(result.data(), tmp->getBuffer(), (unsigned long long) lengthOf() * sizeOfT());
|
||||||
|
|
||||||
delete tmp;
|
delete tmp;
|
||||||
}
|
} else {
|
||||||
else {
|
syncToHost();
|
||||||
memcpy(result.data(), getBuffer(), (unsigned long long) lengthOf() * sizeOfT());
|
memcpy(result.data(), getBuffer(), (unsigned long long) lengthOf() * sizeOfT());
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
void NDArray::linspace(const double start) {
|
void NDArray::linspace(const double start) {
|
||||||
|
@ -1584,9 +1601,7 @@ std::string* NDArray::bufferAsT() const {
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T* NDArray::bufferAsT() const {
|
T* NDArray::bufferAsT() const {
|
||||||
if (isS())
|
// FIXME: do we REALLY want sync here?
|
||||||
throw std::runtime_error("You can't use this method on String array");
|
|
||||||
|
|
||||||
syncToHost();
|
syncToHost();
|
||||||
|
|
||||||
return reinterpret_cast<T*>(getBuffer());
|
return reinterpret_cast<T*>(getBuffer());
|
||||||
|
@ -3202,12 +3217,30 @@ bool NDArray::equalsTo(const NDArray *other, double eps) const {
|
||||||
} else if (!shape::equalsSoft(getShapeInfo(), other->getShapeInfo()))
|
} else if (!shape::equalsSoft(getShapeInfo(), other->getShapeInfo()))
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
|
if (isS()) {
|
||||||
|
// string is special case, we'll compare them one by one, considering both arrays are guaranteed to have the same length
|
||||||
|
for (int e = 0; e < this->lengthOf(); e++) {
|
||||||
|
auto s1 = this->e<std::string>(e);
|
||||||
|
auto s2 = other->e<std::string>(e);
|
||||||
|
|
||||||
|
if (s1 != s2)
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
// regular numeric types
|
||||||
NDArray tmp(nd4j::DataType::FLOAT32, getContext()); // scalar = 0
|
NDArray tmp(nd4j::DataType::FLOAT32, getContext()); // scalar = 0
|
||||||
|
|
||||||
ExtraArguments extras({eps});
|
ExtraArguments extras({eps});
|
||||||
|
|
||||||
NDArray::prepareSpecialUse({&tmp}, {this, other});
|
NDArray::prepareSpecialUse({&tmp}, {this, other});
|
||||||
NativeOpExecutioner::execReduce3Scalar(getContext(), reduce3::EqualsWithEps, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extras.argumentsAsT(DataType::FLOAT32), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo());
|
NativeOpExecutioner::execReduce3Scalar(getContext(), reduce3::EqualsWithEps, getBuffer(), getShapeInfo(),
|
||||||
|
getSpecialBuffer(), getSpecialShapeInfo(),
|
||||||
|
extras.argumentsAsT(DataType::FLOAT32), other->getBuffer(),
|
||||||
|
other->getShapeInfo(), other->getSpecialBuffer(),
|
||||||
|
other->getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(),
|
||||||
|
tmp.specialBuffer(), tmp.specialShapeInfo());
|
||||||
NDArray::registerSpecialUse({&tmp}, {this, other});
|
NDArray::registerSpecialUse({&tmp}, {this, other});
|
||||||
|
|
||||||
synchronize("NDArray::equalsTo");
|
synchronize("NDArray::equalsTo");
|
||||||
|
@ -3217,6 +3250,7 @@ bool NDArray::equalsTo(const NDArray *other, double eps) const {
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
template <>
|
template <>
|
||||||
|
|
|
@ -54,6 +54,7 @@
|
||||||
#include <graph/ExecutionResult.h>
|
#include <graph/ExecutionResult.h>
|
||||||
#include <exceptions/graph_execution_exception.h>
|
#include <exceptions/graph_execution_exception.h>
|
||||||
#include <exceptions/no_results_exception.h>
|
#include <exceptions/no_results_exception.h>
|
||||||
|
#include <graph/FlatUtils.h>
|
||||||
|
|
||||||
namespace nd4j{
|
namespace nd4j{
|
||||||
namespace graph {
|
namespace graph {
|
||||||
|
@ -575,15 +576,9 @@ Nd4jStatus GraphExecutioner::execute(Graph *graph, VariableSpace* variableSpace)
|
||||||
continue;
|
continue;
|
||||||
|
|
||||||
|
|
||||||
NDArray* array = var->getNDArray();
|
auto array = var->getNDArray();
|
||||||
auto byteVector = array->asByteVector();
|
|
||||||
|
|
||||||
auto fBuffer = builder.CreateVector(byteVector);
|
auto fArray = FlatUtils::toFlatArray(builder, *array);
|
||||||
auto fShape = builder.CreateVector(array->getShapeInfoAsFlatVector());
|
|
||||||
|
|
||||||
auto bo = static_cast<nd4j::graph::ByteOrder>(BitwiseUtils::asByteOrder());
|
|
||||||
|
|
||||||
auto fArray = CreateFlatArray(builder, fShape, fBuffer, static_cast<nd4j::graph::DataType>(array->dataType()), bo);
|
|
||||||
|
|
||||||
auto fName = builder.CreateString(*(var->getName()));
|
auto fName = builder.CreateString(*(var->getName()));
|
||||||
auto id = CreateIntPair(builder, var->id(), var->index());
|
auto id = CreateIntPair(builder, var->id(), var->index());
|
||||||
|
|
|
@ -36,6 +36,8 @@ namespace nd4j {
|
||||||
static std::pair<Nd4jLong, Nd4jLong> fromLongPair(LongPair* pair);
|
static std::pair<Nd4jLong, Nd4jLong> fromLongPair(LongPair* pair);
|
||||||
|
|
||||||
static NDArray* fromFlatArray(const nd4j::graph::FlatArray* flatArray);
|
static NDArray* fromFlatArray(const nd4j::graph::FlatArray* flatArray);
|
||||||
|
|
||||||
|
static flatbuffers::Offset<FlatArray> toFlatArray(flatbuffers::FlatBufferBuilder &builder, NDArray &array);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -102,5 +102,16 @@ namespace nd4j {
|
||||||
delete[] newShape;
|
delete[] newShape;
|
||||||
return array;
|
return array;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
flatbuffers::Offset<FlatArray> FlatUtils::toFlatArray(flatbuffers::FlatBufferBuilder &builder, NDArray &array) {
|
||||||
|
auto byteVector = array.asByteVector();
|
||||||
|
|
||||||
|
auto fBuffer = builder.CreateVector(byteVector);
|
||||||
|
auto fShape = builder.CreateVector(array.getShapeInfoAsFlatVector());
|
||||||
|
|
||||||
|
auto bo = static_cast<nd4j::graph::ByteOrder>(BitwiseUtils::asByteOrder());
|
||||||
|
|
||||||
|
return CreateFlatArray(builder, fShape, fBuffer, static_cast<nd4j::graph::DataType>(array.dataType()), bo);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -0,0 +1,100 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <NDArray.h>
|
||||||
|
#include <NDArrayFactory.h>
|
||||||
|
#include "testlayers.h"
|
||||||
|
#include <graph/Stash.h>
|
||||||
|
#include <FlatUtils.h>
|
||||||
|
|
||||||
|
using namespace nd4j;
|
||||||
|
|
||||||
|
class FlatUtilsTests : public testing::Test {
|
||||||
|
public:
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_F(FlatUtilsTests, flat_float_serde_1) {
|
||||||
|
auto array = NDArrayFactory::create<float>('c', {4}, {1.f, 2.f, 3.f, 4.f});
|
||||||
|
|
||||||
|
flatbuffers::FlatBufferBuilder builder(1024);
|
||||||
|
auto flatArray = FlatUtils::toFlatArray(builder, array);
|
||||||
|
builder.Finish(flatArray);
|
||||||
|
|
||||||
|
|
||||||
|
auto pfArray = GetFlatArray(builder.GetBufferPointer());
|
||||||
|
|
||||||
|
auto restored = FlatUtils::fromFlatArray(pfArray);
|
||||||
|
|
||||||
|
ASSERT_EQ(array, *restored);
|
||||||
|
|
||||||
|
delete restored;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FlatUtilsTests, flat_int_serde_1) {
|
||||||
|
auto array = NDArrayFactory::create<int>('c', {4}, {1, 2, 3, 4});
|
||||||
|
|
||||||
|
flatbuffers::FlatBufferBuilder builder(1024);
|
||||||
|
auto flatArray = FlatUtils::toFlatArray(builder, array);
|
||||||
|
builder.Finish(flatArray);
|
||||||
|
|
||||||
|
|
||||||
|
auto pfArray = GetFlatArray(builder.GetBufferPointer());
|
||||||
|
|
||||||
|
auto restored = FlatUtils::fromFlatArray(pfArray);
|
||||||
|
|
||||||
|
ASSERT_EQ(array, *restored);
|
||||||
|
|
||||||
|
delete restored;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FlatUtilsTests, flat_bool_serde_1) {
|
||||||
|
auto array = NDArrayFactory::create<bool>('c', {4}, {true, false, true, false});
|
||||||
|
|
||||||
|
flatbuffers::FlatBufferBuilder builder(1024);
|
||||||
|
auto flatArray = FlatUtils::toFlatArray(builder, array);
|
||||||
|
builder.Finish(flatArray);
|
||||||
|
|
||||||
|
|
||||||
|
auto pfArray = GetFlatArray(builder.GetBufferPointer());
|
||||||
|
|
||||||
|
auto restored = FlatUtils::fromFlatArray(pfArray);
|
||||||
|
|
||||||
|
ASSERT_EQ(array, *restored);
|
||||||
|
|
||||||
|
delete restored;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(FlatUtilsTests, flat_string_serde_1) {
|
||||||
|
auto array = NDArrayFactory::string('c', {3}, {"alpha", "beta", "gamma"});
|
||||||
|
|
||||||
|
flatbuffers::FlatBufferBuilder builder(1024);
|
||||||
|
auto flatArray = FlatUtils::toFlatArray(builder, array);
|
||||||
|
builder.Finish(flatArray);
|
||||||
|
|
||||||
|
|
||||||
|
auto pfArray = GetFlatArray(builder.GetBufferPointer());
|
||||||
|
|
||||||
|
auto restored = FlatUtils::fromFlatArray(pfArray);
|
||||||
|
|
||||||
|
ASSERT_EQ(array, *restored);
|
||||||
|
|
||||||
|
delete restored;
|
||||||
|
}
|
|
@ -24,7 +24,6 @@
|
||||||
#include "testlayers.h"
|
#include "testlayers.h"
|
||||||
#include <graph/Stash.h>
|
#include <graph/Stash.h>
|
||||||
|
|
||||||
using namespace nd4j;
|
|
||||||
using namespace nd4j;
|
using namespace nd4j;
|
||||||
|
|
||||||
class StringTests : public testing::Test {
|
class StringTests : public testing::Test {
|
||||||
|
|
Loading…
Reference in New Issue