/* * ****************************************************************************** * * * * * * This program and the accompanying materials are made available under the * * terms of the Apache License, Version 2.0 which is available at * * https://www.apache.org/licenses/LICENSE-2.0. * * * * See the NOTICE file distributed with this work for additional * * information regarding copyright ownership. * * Unless required by applicable law or agreed to in writing, software * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * * License for the specific language governing permissions and limitations * * under the License. * * * * SPDX-License-Identifier: Apache-2.0 * ***************************************************************************** */ // // @author raver119@gmail.com // #include #include #include namespace sd { namespace ops { namespace helpers { template static void fill_(const void *vvalues, const void *vindices, void *voutput, const Nd4jLong *zShapeInfo, uint8_t rank, uint64_t length) { auto values = reinterpret_cast(vvalues); auto indices = reinterpret_cast(vindices); auto output = reinterpret_cast(voutput); int coords[MAX_RANK]; uint64_t pos = 0; for (uint64_t e = 0L; e < length; e++) { // indices come in blocks for (uint8_t p = 0; p < rank; p++) { coords[p] = indices[pos++]; } // fill output at given coords with sparse value output[shape::getOffset(zShapeInfo, coords)] = values[e]; } } void compat_sparse_to_dense(const NDArray &values, const NDArray &indices, NDArray *def, NDArray &output) { // make sure host buffer is updated values.syncToHost(); indices.syncToHost(); output.syncToHost(); auto rank = output.rankOf(); if (output.isS()) { // string case is not so trivial, since elements might, and probably will, have different sizes auto numValues = values.lengthOf(); auto numElements = output.lengthOf(); // first of all we calculate final buffer sizes and offsets auto defaultLength = def == nullptr ? 0 : StringUtils::byteLength(*def); auto valuesLength = StringUtils::byteLength(values); auto bufferLength = defaultLength * (output.lengthOf() - numValues) + valuesLength; auto headerLength = ShapeUtils::stringBufferHeaderRequirements(numElements); // now we make sure our output buffer can hold results output.dataBuffer()->expand( bufferLength + headerLength); std::vector outputCoords(rank); std::vector valueCoords(rank); auto offsetsBuffer = output.bufferAsT(); auto dataBuffer = reinterpret_cast(offsetsBuffer + output.lengthOf()); offsetsBuffer[0] = 0; // getting initial value coords for (int e = 0; e < rank; e++) valueCoords[e] = indices.e(e); // write results individually for (Nd4jLong e = 0; e < numElements; e++) { auto vIndex = shape::coords2index(output.shapeInfo(), valueCoords.data()); auto cLength = 0L; std::string str; if (vIndex == e) { // we're writing down sparse value here str = values.e(e); } else { // we're writing down default value if it exists if (def != nullptr) str = def->e(0); else str = ""; } // TODO: make it unicode compliant memcpy(&dataBuffer[offsetsBuffer[e]], str.c_str(), str.length()); // writing down offset offsetsBuffer[e+1] = cLength; } } else { // numeric case is trivial, since all elements have equal sizes // write out default values, if they are present if (def != nullptr) { output.assign(def); // make sure output is synced back output.syncToHost(); } // write out values BUILD_DOUBLE_SELECTOR(values.dataType(), indices.dataType(), fill_, (values.buffer(), indices.buffer(), output.buffer(), output.shapeInfo(), rank, values.lengthOf()), LIBND4J_TYPES, INDEXING_TYPES); } // copy back to device, if there's any output.syncToDevice(); } } } }