cavis/libnd4j/include/ops/declarable/generic/compat/compat_string_split.cpp

140 lines
5.2 KiB
C++

/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_split_string)
#include <ops/declarable/CustomOperations.h>
#include <helpers/StringUtils.h>
namespace nd4j {
namespace ops {
/**
 * compat_string_split: splits every string of the input tensor by a delimiter
 * and emits the result in sparse (indices, values) form.
 *
 * Inputs:
 *   0 - string tensor of rank N; every element is split independently
 *   1 - string tensor; only its first element is used as the delimiter
 *
 * Outputs:
 *   0 - indices: flat vector with N + 1 values per produced substring. These are
 *       the N coordinates of the source string, followed by the substring's ordinal
 *       within that string.
 *   1 - values: flat string vector holding all substrings in input order.
 *
 * NOTE(review): the number of index rows per string comes from
 * StringUtils::countSubarrays (same as the shape function). The values come from
 * StringUtils::split. This assumes both helpers agree on the piece count — confirm.
 */
CUSTOM_OP_IMPL(compat_string_split, 2, 2, false, 0, 0) {
    auto input = INPUT_VARIABLE(0);
    auto delim = INPUT_VARIABLE(1);

    auto indices = OUTPUT_VARIABLE(0);
    auto values = OUTPUT_VARIABLE(1);

    // make the host copies current before anything is read from them
    input->syncToHost();
    delim->syncToHost();

    const auto d = delim->e<std::string>(0);

    // coordinates of the current string within the input tensor
    std::vector<Nd4jLong> icoords(input->rankOf());

    // all substrings, flattened across the whole input tensor
    std::vector<std::string> strings;

    // write position inside the flat indices output
    Nd4jLong ic = 0L;

    // single pass over the input: each string is fetched once and feeds both outputs
    for (Nd4jLong e = 0; e < input->lengthOf(); e++) {
        auto s = input->e<std::string>(e);

        // getting base index
        shape::index2coords(e, input->shapeInfo(), icoords.data());

        // number of substrings, computed exactly like in the shape function,
        // so the writes below stay within the preallocated indices buffer
        auto cnt = StringUtils::countSubarrays(s.c_str(), s.length(), d.c_str(), d.length()) + 1;

        // filling output indices: base coordinates followed by the substring ordinal
        for (uint64_t f = 0; f < cnt; f++) {
            for (auto v : icoords)
                indices->p(ic++, v);

            // last index
            indices->p(ic++, f);
        }

        // collecting the actual substrings, preserving input order
        auto split = StringUtils::split(s, d);
        strings.insert(strings.end(), split.begin(), split.end());
    }

    // now once we have all strings in single vector time to fill
    auto tmp = NDArrayFactory::string('c', {static_cast<Nd4jLong>(strings.size())}, strings);
    auto blen = StringUtils::byteLength(tmp) + ShapeUtils::stringBufferHeaderRequirements(strings.size());

    // for CUDA mostly
    values->dataBuffer()->allocatePrimary();
    values->dataBuffer()->expand(blen);
    memcpy(values->buffer(), tmp.buffer(), blen);
    values->tickWriteHost();

    // special case, for future use
    indices->syncToDevice();
    values->syncToDevice();

    // we have to tick buffers
    values->dataBuffer()->writePrimary();
    values->dataBuffer()->readSpecial();

    return Status::OK();
};
/**
 * Shape function for compat_string_split.
 *
 * Counts how many substrings splitting every input string by the delimiter
 * (first element of input 1) will produce. It then declares two flat outputs:
 *   - indices: INT64 vector of length total * (input rank + 1)
 *   - values:  UTF8 vector of length total
 */
DECLARE_SHAPE_FN(compat_string_split) {
    auto input = INPUT_VARIABLE(0);
    auto delim = INPUT_VARIABLE(1);

    const auto delimiter = delim->e<std::string>(0);

    // Each delimiter occurrence cuts a string into one more piece, so every
    // string contributes (number of delimiter matches + 1) substrings.
    uint64_t totalSubstrings = 0;
    const auto numStrings = input->lengthOf();
    for (Nd4jLong i = 0; i < numStrings; ++i) {
        // FIXME: bad, not UTF-compatible
        const auto str = input->e<std::string>(i);
        const auto matches = StringUtils::countSubarrays(str.c_str(), str.length(), delimiter.c_str(), delimiter.length());
        totalSubstrings += matches + 1;
    }

    // The virtual (sparse) output has rank N + 1 for a rank-N input. The last
    // dimension addresses the substring inside its source string. Both physical
    // outputs are plain vectors.
    const auto outputRank = input->rankOf() + 1;

    auto valuesShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(totalSubstrings, nd4j::DataType::UTF8);
    auto indicesShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(totalSubstrings * outputRank, nd4j::DataType::INT64);

    return SHAPELIST(indicesShape, valuesShape);
}
/**
 * Type constraints for compat_string_split.
 * All inputs must be strings. Output 0 (indices) must be an index type, and
 * output 1 (values) must be a string type.
 */
DECLARE_TYPES(compat_string_split) {
    auto descriptor = getOpDescriptor();

    descriptor->setAllowedInputTypes({ALL_STRINGS});
    descriptor->setAllowedOutputTypes(0, {ALL_INDICES});
    descriptor->setAllowedOutputTypes(1, {ALL_STRINGS});
}
}
}
#endif