/*******************************************************************************
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//


#include <ops/declarable/helpers/sparse_to_dense.h>
#include <helpers/StringUtils.h>
#include <helpers/ShapeUtils.h>

#include <cstring>
#include <stdexcept>
#include <string>
#include <vector>

namespace sd {
    namespace ops {
        namespace helpers {
            template <typename X, typename I>
            static void fill_(const void *vvalues, const void *vindices, void *voutput, const Nd4jLong *zShapeInfo, uint8_t rank, uint64_t length) {
                auto values = reinterpret_cast<const X*>(vvalues);
                auto indices = reinterpret_cast<const I*>(vindices);
                auto output = reinterpret_cast<X*>(voutput);

                int coords[MAX_RANK];
                uint64_t pos = 0;
                for (uint64_t e = 0L; e < length; e++) {
                    // indices come in blocks
                    for (uint8_t p = 0; p < rank; p++) {
                        coords[p] = indices[pos++];
                    }

                    // fill output at given coords with sparse value
                    output[shape::getOffset(zShapeInfo, coords)] = values[e];
                }

            }

            void compat_sparse_to_dense(const NDArray &values, const NDArray &indices, NDArray *def, NDArray &output) {
                // make sure host buffer is updated
                values.syncToHost();
                indices.syncToHost();
                output.syncToHost();

                auto rank = output.rankOf();

                if (output.isS()) {
                    // string case is not so trivial, since elements might, and probably will, have different sizes
                    auto numValues = values.lengthOf();
                    auto numElements = output.lengthOf();

                    // first of all we calculate final buffer sizes and offsets
                    auto defaultLength = def == nullptr ? 0 : StringUtils::byteLength(*def);
                    auto valuesLength = StringUtils::byteLength(values);
                    auto bufferLength = defaultLength * (output.lengthOf() - numValues) + valuesLength;
                    auto headerLength = ShapeUtils::stringBufferHeaderRequirements(numElements);

                    // now we make sure our output buffer can hold results
                    output.dataBuffer()->expand( bufferLength + headerLength);

                    std::vector<Nd4jLong> outputCoords(rank);
                    std::vector<Nd4jLong> valueCoords(rank);

                    auto offsetsBuffer = output.bufferAsT<Nd4jLong>();
                    auto dataBuffer = reinterpret_cast<uint8_t*>(offsetsBuffer + output.lengthOf());

                    offsetsBuffer[0] = 0;

                    // getting initial value coords
                    for (int e = 0; e < rank; e++)
                        valueCoords[e] = indices.e<Nd4jLong>(e);

                    // write results individually
                    for (Nd4jLong e = 0; e < numElements; e++) {
                        auto vIndex = shape::coords2index(output.shapeInfo(), valueCoords.data());
                        auto cLength = 0L;
                        std::string str;
                        if (vIndex == e) {
                            // we're writing down sparse value here
                             str = values.e<std::string>(e);
                        } else {
                            // we're writing down default value if it exists
                            if (def != nullptr)
                                str = def->e<std::string>(0);
                            else
                                str = "";
                        }

                        // TODO: make it unicode compliant
                        memcpy(&dataBuffer[offsetsBuffer[e]], str.c_str(), str.length());

                        // writing down offset
                        offsetsBuffer[e+1] = cLength;
                    }
                } else {
                    // numeric case is trivial, since all elements have equal sizes

                    // write out default values, if they are present
                    if (def != nullptr) {
                        output.assign(def);

                        // make sure output is synced back
                        output.syncToHost();
                    }

                    // write out values
                    BUILD_DOUBLE_SELECTOR(values.dataType(), indices.dataType(), fill_, (values.buffer(), indices.buffer(), output.buffer(), output.shapeInfo(), rank, values.lengthOf()), LIBND4J_TYPES, INDEXING_TYPES);
                }
                // copy back to device, if there's any
                output.syncToDevice();
            }
        }
    }
}