/******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ // // @author raver119@gmail.com // #include #if NOT_EXCLUDED(OP_split_string) #include #include namespace nd4j { namespace ops { CUSTOM_OP_IMPL(compat_string_split, 2, 2, false, 0, 0) { auto input = INPUT_VARIABLE(0); auto delim = INPUT_VARIABLE(1); auto indices = OUTPUT_VARIABLE(0); auto values = OUTPUT_VARIABLE(1); auto d = delim->e(0); input->syncToHost(); delim->syncToHost(); // output rank N+1 wrt input rank std::vector ocoords(input->rankOf() + 1); std::vector icoords(input->rankOf()); // getting buffer lengths // FIXME: it'll be bigger, since it'll include delimiters, auto outputLength = StringUtils::byteLength(*input); uint64_t ss = 0L; Nd4jLong ic = 0L; // loop through each string within tensor for (auto e = 0L; e < input->lengthOf(); e++) { // now we should map substring to indices auto s = input->e(e); // getting base index shape::index2coords(e, input->shapeInfo(), icoords.data()); // getting number of substrings auto cnt = StringUtils::countSubarrays(s.c_str(), s.length(), d.c_str(), d.length()) + 1; // filling output indices for (uint64_t f = 0; f < cnt; f++) { for (auto v: icoords) indices->p(ic++, v); // last index indices->p(ic++, f); } ss += cnt; } // process strings now std::vector strings; for (auto e = 0L; e < input->lengthOf(); e++) { auto split = StringUtils::split(input->e(e), d); for (const auto &s:split) strings.emplace_back(s); } // now once we have all strings in single vector time to fill auto tmp = NDArrayFactory::string('c', {(Nd4jLong) strings.size()}, strings); auto blen = StringUtils::byteLength(tmp) + ShapeUtils::stringBufferHeaderRequirements(strings.size()); // for CUDA mostly values->dataBuffer()->allocatePrimary(); values->dataBuffer()->expand(blen); memcpy(values->buffer(), tmp.buffer(), blen); values->tickWriteHost(); // special case, for future use indices->syncToDevice(); values->syncToDevice(); // we have to tick buffers values->dataBuffer()->writePrimary(); values->dataBuffer()->readSpecial(); return Status::OK(); }; DECLARE_SHAPE_FN(compat_string_split) { auto input = INPUT_VARIABLE(0); auto delim = INPUT_VARIABLE(1); auto d = delim->e(0); // count number of delimiter substrings in all strings within input tensor uint64_t cnt = 0; for (auto e = 0L; e < input->lengthOf(); e++) { // FIXME: bad, not UTF-compatible auto s = input->e(e); // each substring we see in haystack, splits string in two parts. so we should add 1 to the number of subarrays cnt += StringUtils::countSubarrays(s.c_str(), s.length(), d.c_str(), d.length()) + 1; } // shape calculations // virtual tensor rank will be N+1, for N rank input array, where data will be located at the biggest dimension // values tensor is going to be vector always // indices tensor is going to be vector with length equal to values.length * output rank auto valuesShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(cnt, nd4j::DataType::UTF8); auto indicesShape = ConstantShapeHelper::getInstance()->vectorShapeInfo(cnt * (input->rankOf() + 1), nd4j::DataType::INT64); return SHAPELIST(indicesShape, valuesShape); } DECLARE_TYPES(compat_string_split) { getOpDescriptor() ->setAllowedInputTypes({ALL_STRINGS}) ->setAllowedOutputTypes(0, {ALL_INDICES}) ->setAllowedOutputTypes(1, {ALL_STRINGS}); } } } #endif