diff --git a/libnd4j/blas/NDArray.hpp b/libnd4j/blas/NDArray.hpp index 1404afc96..fdbcae49f 100644 --- a/libnd4j/blas/NDArray.hpp +++ b/libnd4j/blas/NDArray.hpp @@ -476,19 +476,36 @@ std::vector NDArray::getShapeInfoAsVector() { //////////////////////////////////////////////////////////////////////// std::vector NDArray::asByteVector() { - std::vector result((unsigned long long) this->lengthOf() * sizeOfT()); - if (this->isView()) { - auto tmp = this->dup(this->ordering()); - memcpy(result.data(), tmp->getBuffer(), (unsigned long long) lengthOf() * sizeOfT()); + if (isS()) { + // string data type requires special treatment + syncToHost(); + auto numWords = this->lengthOf(); + auto offsetsBuffer = this->bufferAsT(); + auto headerLength = ShapeUtils::stringBufferHeaderRequirements(numWords); + auto dataLength = offsetsBuffer[numWords]; + std::vector result(headerLength + dataLength); - delete tmp; + memcpy(result.data(), getBuffer(), headerLength + dataLength); + + return result; + } else { + // all other types are linear + std::vector result((unsigned long long) this->lengthOf() * sizeOfT()); + + if (this->isView()) { + auto tmp = this->dup(this->ordering()); + syncToHost(); + memcpy(result.data(), tmp->getBuffer(), (unsigned long long) lengthOf() * sizeOfT()); + + delete tmp; + } else { + syncToHost(); + memcpy(result.data(), getBuffer(), (unsigned long long) lengthOf() * sizeOfT()); + } + return result; } - else { - memcpy(result.data(), getBuffer(), (unsigned long long) lengthOf() * sizeOfT()); - } - return result; } ////////////////////////////////////////////////////////////////////////// @@ -1584,9 +1601,7 @@ std::string* NDArray::bufferAsT() const { ////////////////////////////////////////////////////////////////////////// template T* NDArray::bufferAsT() const { - if (isS()) - throw std::runtime_error("You can't use this method on String array"); - + // FIXME: do we REALLY want sync here? syncToHost(); return reinterpret_cast(getBuffer()); @@ -3202,20 +3217,39 @@ bool NDArray::equalsTo(const NDArray *other, double eps) const { } else if (!shape::equalsSoft(getShapeInfo(), other->getShapeInfo())) return false; - NDArray tmp(nd4j::DataType::FLOAT32, getContext()); // scalar = 0 + if (isS()) { + // string is special case, we'll compare them one by one, considering both arrays are guaranteed to have the same length + for (int e = 0; e < this->lengthOf(); e++) { + auto s1 = this->e(e); + auto s2 = other->e(e); - ExtraArguments extras({eps}); + if (s1 != s2) + return false; + } - NDArray::prepareSpecialUse({&tmp}, {this, other}); - NativeOpExecutioner::execReduce3Scalar(getContext(), reduce3::EqualsWithEps, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), extras.argumentsAsT(DataType::FLOAT32), other->getBuffer(), other->getShapeInfo(), other->getSpecialBuffer(), other->getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo()); - NDArray::registerSpecialUse({&tmp}, {this, other}); + return true; + } else { + // regular numeric types + NDArray tmp(nd4j::DataType::FLOAT32, getContext()); // scalar = 0 - synchronize("NDArray::equalsTo"); + ExtraArguments extras({eps}); - if (tmp.e(0) > 0) - return false; + NDArray::prepareSpecialUse({&tmp}, {this, other}); + NativeOpExecutioner::execReduce3Scalar(getContext(), reduce3::EqualsWithEps, getBuffer(), getShapeInfo(), + getSpecialBuffer(), getSpecialShapeInfo(), + extras.argumentsAsT(DataType::FLOAT32), other->getBuffer(), + other->getShapeInfo(), other->getSpecialBuffer(), + other->getSpecialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), + tmp.specialBuffer(), tmp.specialShapeInfo()); + NDArray::registerSpecialUse({&tmp}, {this, other}); - return true; + synchronize("NDArray::equalsTo"); + + if (tmp.e(0) > 0) + return false; + + return true; + } } ////////////////////////////////////////////////////////////////////////// diff --git a/libnd4j/blas/cpu/GraphExecutioner.cpp b/libnd4j/blas/cpu/GraphExecutioner.cpp index b5e7d9bf2..6f97bc024 100644 --- a/libnd4j/blas/cpu/GraphExecutioner.cpp +++ b/libnd4j/blas/cpu/GraphExecutioner.cpp @@ -54,6 +54,7 @@ #include #include #include +#include namespace nd4j{ namespace graph { @@ -575,15 +576,9 @@ Nd4jStatus GraphExecutioner::execute(Graph *graph, VariableSpace* variableSpace) continue; - NDArray* array = var->getNDArray(); - auto byteVector = array->asByteVector(); + auto array = var->getNDArray(); - auto fBuffer = builder.CreateVector(byteVector); - auto fShape = builder.CreateVector(array->getShapeInfoAsFlatVector()); - - auto bo = static_cast(BitwiseUtils::asByteOrder()); - - auto fArray = CreateFlatArray(builder, fShape, fBuffer, static_cast(array->dataType()), bo); + auto fArray = FlatUtils::toFlatArray(builder, *array); auto fName = builder.CreateString(*(var->getName())); auto id = CreateIntPair(builder, var->id(), var->index()); diff --git a/libnd4j/include/graph/FlatUtils.h b/libnd4j/include/graph/FlatUtils.h index abfff5915..939db1fb7 100644 --- a/libnd4j/include/graph/FlatUtils.h +++ b/libnd4j/include/graph/FlatUtils.h @@ -36,6 +36,8 @@ namespace nd4j { static std::pair fromLongPair(LongPair* pair); static NDArray* fromFlatArray(const nd4j::graph::FlatArray* flatArray); + + static flatbuffers::Offset toFlatArray(flatbuffers::FlatBufferBuilder &builder, NDArray &array); }; } } diff --git a/libnd4j/include/graph/impl/FlatUtils.cpp b/libnd4j/include/graph/impl/FlatUtils.cpp index ad0c5112d..bc8ff7e33 100644 --- a/libnd4j/include/graph/impl/FlatUtils.cpp +++ b/libnd4j/include/graph/impl/FlatUtils.cpp @@ -102,5 +102,16 @@ namespace nd4j { delete[] newShape; return array; } + + flatbuffers::Offset FlatUtils::toFlatArray(flatbuffers::FlatBufferBuilder &builder, NDArray &array) { + auto byteVector = array.asByteVector(); + + auto fBuffer = builder.CreateVector(byteVector); + auto fShape = builder.CreateVector(array.getShapeInfoAsFlatVector()); + + auto bo = static_cast(BitwiseUtils::asByteOrder()); + + return CreateFlatArray(builder, fShape, fBuffer, static_cast(array.dataType()), bo); + } } } \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/FlatUtilsTests.cpp b/libnd4j/tests_cpu/layers_tests/FlatUtilsTests.cpp new file mode 100644 index 000000000..bf428b833 --- /dev/null +++ b/libnd4j/tests_cpu/layers_tests/FlatUtilsTests.cpp @@ -0,0 +1,100 @@ +/******************************************************************************* + * Copyright (c) 2015-2019 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include +#include "testlayers.h" +#include +#include + +using namespace nd4j; + +class FlatUtilsTests : public testing::Test { +public: + +}; + +TEST_F(FlatUtilsTests, flat_float_serde_1) { + auto array = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + + flatbuffers::FlatBufferBuilder builder(1024); + auto flatArray = FlatUtils::toFlatArray(builder, array); + builder.Finish(flatArray); + + + auto pfArray = GetFlatArray(builder.GetBufferPointer()); + + auto restored = FlatUtils::fromFlatArray(pfArray); + + ASSERT_EQ(array, *restored); + + delete restored; +} + +TEST_F(FlatUtilsTests, flat_int_serde_1) { + auto array = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + + flatbuffers::FlatBufferBuilder builder(1024); + auto flatArray = FlatUtils::toFlatArray(builder, array); + builder.Finish(flatArray); + + + auto pfArray = GetFlatArray(builder.GetBufferPointer()); + + auto restored = FlatUtils::fromFlatArray(pfArray); + + ASSERT_EQ(array, *restored); + + delete restored; +} + +TEST_F(FlatUtilsTests, flat_bool_serde_1) { + auto array = NDArrayFactory::create('c', {4}, {true, false, true, false}); + + flatbuffers::FlatBufferBuilder builder(1024); + auto flatArray = FlatUtils::toFlatArray(builder, array); + builder.Finish(flatArray); + + + auto pfArray = GetFlatArray(builder.GetBufferPointer()); + + auto restored = FlatUtils::fromFlatArray(pfArray); + + ASSERT_EQ(array, *restored); + + delete restored; +} + +TEST_F(FlatUtilsTests, flat_string_serde_1) { + auto array = NDArrayFactory::string('c', {3}, {"alpha", "beta", "gamma"}); + + flatbuffers::FlatBufferBuilder builder(1024); + auto flatArray = FlatUtils::toFlatArray(builder, array); + builder.Finish(flatArray); + + + auto pfArray = GetFlatArray(builder.GetBufferPointer()); + + auto restored = FlatUtils::fromFlatArray(pfArray); + + ASSERT_EQ(array, *restored); + + delete restored; +} \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/StringTests.cpp b/libnd4j/tests_cpu/layers_tests/StringTests.cpp index a023dcdd3..2ae236210 100644 --- a/libnd4j/tests_cpu/layers_tests/StringTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/StringTests.cpp @@ -24,7 +24,6 @@ #include "testlayers.h" #include -using namespace nd4j; using namespace nd4j; class StringTests : public testing::Test { @@ -91,4 +90,4 @@ TEST_F(StringTests, Basic_dup_1) { ASSERT_EQ(f, z1); delete dup; -} +} \ No newline at end of file