diff --git a/contrib/codegen-tools/libnd4j-gen/src/main/java/org/nd4j/descriptor/proposal/impl/JavaSourceArgDescriptorSource.java b/contrib/codegen-tools/libnd4j-gen/src/main/java/org/nd4j/descriptor/proposal/impl/JavaSourceArgDescriptorSource.java
index 5d67599a9..d4216e925 100644
--- a/contrib/codegen-tools/libnd4j-gen/src/main/java/org/nd4j/descriptor/proposal/impl/JavaSourceArgDescriptorSource.java
+++ b/contrib/codegen-tools/libnd4j-gen/src/main/java/org/nd4j/descriptor/proposal/impl/JavaSourceArgDescriptorSource.java
@@ -551,6 +551,31 @@ public class JavaSourceArgDescriptorSource implements ArgDescriptorSource {
            }
        }

+
+
+        if(name.contains("fill")) {
+            argDescriptorProposals.add(ArgDescriptorProposal.builder()
+                    .sourceOfProposal("java")
+                    .proposalWeight(Double.MAX_VALUE)
+                    .descriptor(OpNamespace.ArgDescriptor.newBuilder()
+                            .setArgType(OpNamespace.ArgDescriptor.ArgType.INPUT_TENSOR)
+                            .setName("shape")
+                            .setIsArray(false)
+                            .setArgIndex(0)
+                            .build()).build());
+
+            argDescriptorProposals.add(ArgDescriptorProposal.builder()
+                    .sourceOfProposal("java")
+                    .proposalWeight(Double.MAX_VALUE)
+                    .descriptor(OpNamespace.ArgDescriptor.newBuilder()
+                            .setArgType(OpNamespace.ArgDescriptor.ArgType.INPUT_TENSOR)
+                            .setName("result")
+                            .setIsArray(false)
+                            .setArgIndex(1)
+                            .build()).build());
+
+        }
+
        if(name.contains("loop_cond")) {
            argDescriptorProposals.add(ArgDescriptorProposal.builder()
                    .sourceOfProposal("java")
diff --git a/contrib/codegen-tools/libnd4j-gen/src/main/java/org/nd4j/descriptor/proposal/impl/Libnd4jArgDescriptorSource.java b/contrib/codegen-tools/libnd4j-gen/src/main/java/org/nd4j/descriptor/proposal/impl/Libnd4jArgDescriptorSource.java
index fdb9ba46e..fb234c1fa 100644
--- a/contrib/codegen-tools/libnd4j-gen/src/main/java/org/nd4j/descriptor/proposal/impl/Libnd4jArgDescriptorSource.java
+++ b/contrib/codegen-tools/libnd4j-gen/src/main/java/org/nd4j/descriptor/proposal/impl/Libnd4jArgDescriptorSource.java
@@ -855,6 +855,63 @@ public class Libnd4jArgDescriptorSource implements ArgDescriptorSource {
                    .build()).build());
        }

+        if(name.contains("fill")) {
+            argDescriptorProposals.add(ArgDescriptorProposal.builder()
+                    .sourceOfProposal("java")
+                    .proposalWeight(Double.MAX_VALUE)
+                    .descriptor(OpNamespace.ArgDescriptor.newBuilder()
+                            .setArgType(OpNamespace.ArgDescriptor.ArgType.INPUT_TENSOR)
+                            .setName("shape")
+                            .setIsArray(false)
+                            .setArgIndex(0)
+                            .build()).build());
+
+            argDescriptorProposals.add(ArgDescriptorProposal.builder()
+                    .sourceOfProposal("java")
+                    .proposalWeight(Double.MAX_VALUE)
+                    .descriptor(OpNamespace.ArgDescriptor.newBuilder()
+                            .setArgType(OpNamespace.ArgDescriptor.ArgType.INPUT_TENSOR)
+                            .setName("result")
+                            .setIsArray(false)
+                            .setArgIndex(1)
+                            .build()).build());
+
+        }
+
+        if(name.contains("unsorted_")) {
+            argDescriptorProposals.add(ArgDescriptorProposal.builder()
+                    .sourceOfProposal("c++")
+                    .proposalWeight(Double.MAX_VALUE)
+                    .descriptor(OpNamespace.ArgDescriptor.newBuilder()
+                            .setArgType(OpNamespace.ArgDescriptor.ArgType.INPUT_TENSOR)
+                            .setName("input")
+                            .setIsArray(false)
+                            .setArgIndex(0)
+                            .build()).build());
+
+            argDescriptorProposals.add(ArgDescriptorProposal.builder()
+                    .sourceOfProposal("c++")
+                    .proposalWeight(Double.MAX_VALUE)
+                    .descriptor(OpNamespace.ArgDescriptor.newBuilder()
+                            .setArgType(OpNamespace.ArgDescriptor.ArgType.INPUT_TENSOR)
+                            .setName("idxSegments")
+                            .setIsArray(false)
+                            .setArgIndex(1)
+                            .build()).build());
+
+            argDescriptorProposals.add(ArgDescriptorProposal.builder()
+                    .sourceOfProposal("c++")
.proposalWeight(Double.MAX_VALUE) + .descriptor(OpNamespace.ArgDescriptor.newBuilder() + .setArgType(OpNamespace.ArgDescriptor.ArgType.INPUT_TENSOR) + .setName("numSegments") + .setIsArray(false) + .setArgIndex(2) + .build()).build()); + + + } + if(name.equals("lin_space")) { argDescriptorProposals.add(ArgDescriptorProposal.builder() diff --git a/libnd4j/include/array/NDArray.hXX b/libnd4j/include/array/NDArray.hXX index 6ec86a8da..530cf50c0 100644 --- a/libnd4j/include/array/NDArray.hXX +++ b/libnd4j/include/array/NDArray.hXX @@ -32,189 +32,189 @@ namespace sd { -template <> -ND4J_EXPORT utf8string NDArray::e(const Nd4jLong i) const; -template <> -ND4J_EXPORT std::string NDArray::e(const Nd4jLong i) const; -template <> -ND4J_EXPORT std::u16string NDArray::e(const Nd4jLong i) const; -template <> -ND4J_EXPORT std::u32string NDArray::e(const Nd4jLong i) const; + template <> + ND4J_EXPORT utf8string NDArray::e(const Nd4jLong i) const; + template <> + ND4J_EXPORT std::string NDArray::e(const Nd4jLong i) const; + template <> + ND4J_EXPORT std::u16string NDArray::e(const Nd4jLong i) const; + template <> + ND4J_EXPORT std::u32string NDArray::e(const Nd4jLong i) const; //////////////////////////////////////////////////////////////////////// // copy constructor -NDArray::NDArray(const NDArray& other) { + NDArray::NDArray(const NDArray& other) { - _context = other._context; - _offset = 0; + _context = other._context; + _offset = 0; - setShapeInfo(ShapeDescriptor(other.dataType(), other.ordering(), other.shapeOf(), other.rankOf())); + setShapeInfo(ShapeDescriptor(other.dataType(), other.ordering(), other.shapeOf(), other.rankOf())); - if(!isEmpty()) { - _buffer = std::make_shared(other.lengthOf() * other.sizeOfT(), other.dataType(), other.getContext()->getWorkspace()); - this->assign(&other); + if(!isEmpty()) { + _buffer = std::make_shared(other.lengthOf() * other.sizeOfT(), other.dataType(), other.getContext()->getWorkspace()); + this->assign(&other); + } + else + _buffer = std::make_shared(); } - else - _buffer = std::make_shared(); -} //////////////////////////////////////////////////////////////////////// -NDArray::NDArray(const char order, const std::vector &shape, sd::DataType dtype, sd::LaunchContext * context) { + NDArray::NDArray(const char order, const std::vector &shape, sd::DataType dtype, sd::LaunchContext * context) { - if ((int) shape.size() > MAX_RANK) - throw std::invalid_argument("Rank of NDArray can't exceed 32"); + if ((int) shape.size() > MAX_RANK) + throw std::invalid_argument("Rank of NDArray can't exceed 32"); - _context = context; - _isAttached = _context->getWorkspace() != nullptr; - _offset = 0; + _context = context; + _isAttached = _context->getWorkspace() != nullptr; + _offset = 0; - if (shape.empty()) - setShapeInfo(ShapeDescriptor::emptyDescriptor(dtype)); - else - setShapeInfo(ShapeDescriptor(dtype, order, shape)); - - _buffer = std::make_shared(lengthOf() * DataTypeUtils::sizeOf(dtype), dtype, getContext()->getWorkspace()); - _buffer->setToZeroBuffers(); -} - -//////////////////////////////////////////////////////////////////////// -NDArray::NDArray(const char order, const std::vector &shape, const std::vector& data, sd::DataType dtype, sd::LaunchContext * context) { - - if ((int) shape.size() > MAX_RANK) - throw std::invalid_argument("Rank of NDArray can't exceed 32"); - - _context = context; - _offset = 0; - - if (shape.size() == 0) { - if (data.size() == 0) + if (shape.empty()) setShapeInfo(ShapeDescriptor::emptyDescriptor(dtype)); else - 
setShapeInfo(ShapeDescriptor::scalarDescriptor(dtype)); - } else { + setShapeInfo(ShapeDescriptor(dtype, order, shape)); + + _buffer = std::make_shared(lengthOf() * DataTypeUtils::sizeOf(dtype), dtype, getContext()->getWorkspace()); + _buffer->setToZeroBuffers(); + } + +//////////////////////////////////////////////////////////////////////// + NDArray::NDArray(const char order, const std::vector &shape, const std::vector& data, sd::DataType dtype, sd::LaunchContext * context) { + + if ((int) shape.size() > MAX_RANK) + throw std::invalid_argument("Rank of NDArray can't exceed 32"); + + _context = context; + _offset = 0; + + if (shape.size() == 0) { + if (data.size() == 0) + setShapeInfo(ShapeDescriptor::emptyDescriptor(dtype)); + else + setShapeInfo(ShapeDescriptor::scalarDescriptor(dtype)); + } else { + setShapeInfo(ShapeDescriptor(dtype, order, shape)); + } + + if (lengthOf() != data.size()) { + nd4j_printf("NDArray constructor: data size [%i] doesn't match shape length [%i]\n", data.size(), lengthOf()); + throw std::runtime_error("Data size doesn't match shape"); + } + + _buffer = std::make_shared(lengthOf() * DataTypeUtils::sizeOf(dtype), dtype, getContext()->getWorkspace(), true); + + for(Nd4jLong i=0; i < lengthOf(); ++i) { + BUILD_SINGLE_PARTIAL_SELECTOR(dtype, templatedDoubleAssign<, double>(buffer(), i, reinterpret_cast(data.data()), i), LIBND4J_TYPES); + } + tickWriteHost(); + syncToDevice(); + } + + +//////////////////////////////////////////////////////////////////////// + NDArray::NDArray(const NDArray *other, const bool copyStrides, sd::LaunchContext* context) { + + _context = context; + _offset = 0; + _isAttached = getContext()->getWorkspace() != nullptr; + + if (copyStrides) + setShapeInfo(ShapeDescriptor(other->_shapeInfo)); + else + setShapeInfo(ShapeDescriptor(other->dataType(), other->ordering(), other->shapeOf(), other->rankOf())); + + if (!isEmpty()) + _buffer = std::make_shared(lengthOf() * sizeOfT(), dataType(), getContext()->getWorkspace()); + } + +//////////////////////////////////////////////////////////////////////// + NDArray::NDArray(void* buffer, const char order, const std::vector &shape, sd::DataType dtype, sd::LaunchContext * context, const bool isBuffAlloc) { + + if (shape.empty()) + throw std::runtime_error("NDArray constructor: input shape is empty !"); + + if ((int) shape.size() > MAX_RANK) + throw std::invalid_argument("Rank of NDArray can't exceed 32"); + + _context = context; + _offset = 0; + _isAttached = getContext()->getWorkspace() != nullptr; + setShapeInfo(ShapeDescriptor(dtype, order, shape)); + + _buffer = std::make_shared(buffer, lengthOf() * sizeOfT(), dataType(), isBuffAlloc, getContext()->getWorkspace()); } - if (lengthOf() != data.size()) { - nd4j_printf("NDArray constructor: data size [%i] doesn't match shape length [%i]\n", data.size(), lengthOf()); - throw std::runtime_error("Data size doesn't match shape"); - } - - _buffer = std::make_shared(lengthOf() * DataTypeUtils::sizeOf(dtype), dtype, getContext()->getWorkspace(), true); - - for(Nd4jLong i=0; i < lengthOf(); ++i) { - BUILD_SINGLE_PARTIAL_SELECTOR(dtype, templatedDoubleAssign<, double>(buffer(), i, reinterpret_cast(data.data()), i), LIBND4J_TYPES); - } - tickWriteHost(); - syncToDevice(); -} - - -//////////////////////////////////////////////////////////////////////// -NDArray::NDArray(const NDArray *other, const bool copyStrides, sd::LaunchContext* context) { - - _context = context; - _offset = 0; - _isAttached = getContext()->getWorkspace() != nullptr; - - if (copyStrides) - 
setShapeInfo(ShapeDescriptor(other->_shapeInfo)); - else - setShapeInfo(ShapeDescriptor(other->dataType(), other->ordering(), other->shapeOf(), other->rankOf())); - - if (!isEmpty()) - _buffer = std::make_shared(lengthOf() * sizeOfT(), dataType(), getContext()->getWorkspace()); -} - -//////////////////////////////////////////////////////////////////////// -NDArray::NDArray(void* buffer, const char order, const std::vector &shape, sd::DataType dtype, sd::LaunchContext * context, const bool isBuffAlloc) { - - if (shape.empty()) - throw std::runtime_error("NDArray constructor: input shape is empty !"); - - if ((int) shape.size() > MAX_RANK) - throw std::invalid_argument("Rank of NDArray can't exceed 32"); - - _context = context; - _offset = 0; - _isAttached = getContext()->getWorkspace() != nullptr; - - setShapeInfo(ShapeDescriptor(dtype, order, shape)); - - _buffer = std::make_shared(buffer, lengthOf() * sizeOfT(), dataType(), isBuffAlloc, getContext()->getWorkspace()); -} - //////////////////////////////////////////////////////////////////////// // creates new NDArray using shape information from "shapeInfo" array, set all elements in new array to be zeros -NDArray::NDArray(const Nd4jLong* shapeInfo, const sd::DataType dtype, const bool copyStrides, sd::LaunchContext * context, const bool nullify) { + NDArray::NDArray(const Nd4jLong* shapeInfo, const sd::DataType dtype, const bool copyStrides, sd::LaunchContext * context, const bool nullify) { - if (shapeInfo == nullptr) - throw std::runtime_error("NDArray constructor: can't be initalized without shapeinfo"); + if (shapeInfo == nullptr) + throw std::runtime_error("NDArray constructor: can't be initalized without shapeinfo"); - if ((int) shapeInfo[0] > MAX_RANK) - throw std::invalid_argument("Rank of NDArray can't exceed 32"); + if ((int) shapeInfo[0] > MAX_RANK) + throw std::invalid_argument("Rank of NDArray can't exceed 32"); - _context = context; - _offset = 0; + _context = context; + _offset = 0; - if (copyStrides) - setShapeInfo(ShapeDescriptor(shapeInfo, dtype)); - else - setShapeInfo(ShapeDescriptor(dtype, shape::order(shapeInfo), shape::shapeOf(shapeInfo), shape::rank(shapeInfo))); + if (copyStrides) + setShapeInfo(ShapeDescriptor(shapeInfo, dtype)); + else + setShapeInfo(ShapeDescriptor(dtype, shape::order(shapeInfo), shape::shapeOf(shapeInfo), shape::rank(shapeInfo))); - if (!isEmpty()) { - _buffer = std::make_shared(lengthOf() * sizeOfT(), dtype, getContext()->getWorkspace()); + if (!isEmpty()) { + _buffer = std::make_shared(lengthOf() * sizeOfT(), dtype, getContext()->getWorkspace()); - if (nullify) - _buffer->setToZeroBuffers(); + if (nullify) + _buffer->setToZeroBuffers(); + } } -} //////////////////////////////////////////////////////////////////////// // scalar constructor -NDArray::NDArray(sd::DataType dtype, sd::LaunchContext* context, const bool isScalar) { + NDArray::NDArray(sd::DataType dtype, sd::LaunchContext* context, const bool isScalar) { - _context = context; - _offset = 0; - _isAttached = getContext()->getWorkspace() != nullptr; + _context = context; + _offset = 0; + _isAttached = getContext()->getWorkspace() != nullptr; - if (isScalar) { - setShapeInfo(ShapeDescriptor::scalarDescriptor(dtype)); - _buffer = std::make_shared(sizeOfT(), dtype, getContext()->getWorkspace()); - _buffer->setToZeroBuffers(); + if (isScalar) { + setShapeInfo(ShapeDescriptor::scalarDescriptor(dtype)); + _buffer = std::make_shared(sizeOfT(), dtype, getContext()->getWorkspace()); + _buffer->setToZeroBuffers(); + } + else + 
setShapeInfo(ConstantShapeHelper::getInstance().emptyShapeInfo(dtype)); } - else - setShapeInfo(ConstantShapeHelper::getInstance().emptyShapeInfo(dtype)); -} ////////////////////////////////////////////////////////////////////////// // move constructor -NDArray::NDArray(NDArray&& other) noexcept { + NDArray::NDArray(NDArray&& other) noexcept { - _isView = other._isView; - _buffer = other._buffer; - _shapeInfo = other._shapeInfo; - _shapeInfoD = other._shapeInfoD; - _context = other._context; - _dataType = other._dataType; - _length = other._length; - _offset = other._offset; + _isView = other._isView; + _buffer = other._buffer; + _shapeInfo = other._shapeInfo; + _shapeInfoD = other._shapeInfoD; + _context = other._context; + _dataType = other._dataType; + _length = other._length; + _offset = other._offset; - other._buffer = std::make_shared(); - other._shapeInfo = other._shapeInfoD = nullptr; - other._length = 0; -} + other._buffer = std::make_shared(); + other._shapeInfo = other._shapeInfoD = nullptr; + other._length = 0; + } //////////////////////////////////////////////////////////////////////// //constructor, create empty array at given workspace -NDArray::NDArray(sd::LaunchContext * context) { - _buffer = std::make_shared(); - _shapeInfo = nullptr; - _shapeInfoD = nullptr; - _offset = 0; - _context = context; - _length = 0; -} + NDArray::NDArray(sd::LaunchContext * context) { + _buffer = std::make_shared(); + _shapeInfo = nullptr; + _shapeInfoD = nullptr; + _offset = 0; + _context = context; + _length = 0; + } //////////////////////////////////////////////////////////////////////// // creates new NDArray using shape information from "shapeInfo" array, set all elements in new array to be zeros, set dtype as array type @@ -288,227 +288,227 @@ NDArray::NDArray(sd::LaunchContext * context) { } ////////////////////////////////////////////////////////////////////////// -NDArray::NDArray(std::shared_ptr buffer, const char order, const std::vector &shape, sd::LaunchContext* context) { + NDArray::NDArray(std::shared_ptr buffer, const char order, const std::vector &shape, sd::LaunchContext* context) { - if (shape.empty()) - throw std::runtime_error("NDArray constructor: input shape is empty !"); + if (shape.empty()) + throw std::runtime_error("NDArray constructor: input shape is empty !"); - if ((int) shape.size() > MAX_RANK) - throw std::invalid_argument("NDArray constructor: rank of NDArray can't exceed 32"); + if ((int) shape.size() > MAX_RANK) + throw std::invalid_argument("NDArray constructor: rank of NDArray can't exceed 32"); - _context = context; - _offset = 0; + _context = context; + _offset = 0; - setShapeInfo(ShapeDescriptor(buffer->getDataType(), order, shape)); + setShapeInfo(ShapeDescriptor(buffer->getDataType(), order, shape)); - _buffer = buffer; + _buffer = buffer; - _isView = _length * DataTypeUtils::sizeOf(_dataType) < buffer->getLenInBytes(); -} + _isView = _length * DataTypeUtils::sizeOf(_dataType) < buffer->getLenInBytes(); + } ///////////////////////////////////////////////////////////////////////// // u16 string constructors -NDArray::NDArray(const std::u16string& u16string, sd::DataType dtype, sd::LaunchContext* context) { + NDArray::NDArray(const std::u16string& u16string, sd::DataType dtype, sd::LaunchContext* context) { - if (!DataTypeUtils::isS(dtype)) { - throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used"); - } - - if (!unicode::isStringValidU16(u16string.data(), u16string.data() + u16string.size())) { - throw 
std::invalid_argument("NDArray::NDArray: invalid character in input string"); - } - - // one word that is why used 1 - Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(1); - - Nd4jLong dataLength = [&] { - if (dtype == DataType::UTF16) { - return static_cast(u16string.size() * sizeof(uint16_t)); + if (!DataTypeUtils::isS(dtype)) { + throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used"); } - if (dtype == DataType::UTF32) { - return unicode::offsetUtf16StringInUtf32(u16string.data(), u16string.size()); + + if (!unicode::isStringValidU16(u16string.data(), u16string.data() + u16string.size())) { + throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); } - return unicode::offsetUtf16StringInUtf8(u16string.data(), u16string.size()); - }(); - Nd4jLong offsets[2] = { 0 , dataLength }; + // one word that is why used 1 + Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(1); - _buffer = std::make_shared(headerLength + dataLength, dtype, context->getWorkspace(), true); + Nd4jLong dataLength = [&] { + if (dtype == DataType::UTF16) { + return static_cast(u16string.size() * sizeof(uint16_t)); + } + if (dtype == DataType::UTF32) { + return unicode::offsetUtf16StringInUtf32(u16string.data(), u16string.size()); + } + return unicode::offsetUtf16StringInUtf8(u16string.data(), u16string.size()); + }(); - _context = context; - _isAttached = getContext()->getWorkspace() != nullptr; - _offset = 0; + Nd4jLong offsets[2] = { 0 , dataLength }; - setShapeInfo(ShapeDescriptor::scalarDescriptor(dtype)); + _buffer = std::make_shared(headerLength + dataLength, dtype, context->getWorkspace(), true); - memcpy(bufferAsT(), &offsets[0], 2 * sizeof(Nd4jLong)); + _context = context; + _isAttached = getContext()->getWorkspace() != nullptr; + _offset = 0; - auto data = reinterpret_cast(bufferAsT() + headerLength); - if (dtype == DataType::UTF8) { - unicode::utf16to8(u16string.data(), data, u16string.size()); + setShapeInfo(ShapeDescriptor::scalarDescriptor(dtype)); + + memcpy(bufferAsT(), &offsets[0], 2 * sizeof(Nd4jLong)); + + auto data = reinterpret_cast(bufferAsT() + headerLength); + if (dtype == DataType::UTF8) { + unicode::utf16to8(u16string.data(), data, u16string.size()); + } + else if (dtype == DataType::UTF16) { + memcpy(data, u16string.data(), dataLength); + } + else { + unicode::utf16to32(u16string.data(), data, u16string.size()); + } + + tickWriteHost(); + syncToDevice(); } - else if (dtype == DataType::UTF16) { - memcpy(data, u16string.data(), dataLength); - } - else { - unicode::utf16to32(u16string.data(), data, u16string.size()); - } - - tickWriteHost(); - syncToDevice(); -} ///////////////////////////////////////////////////////////////////////// // u32 string constructors -NDArray::NDArray(const std::u32string& u32string, sd::DataType dtype, sd::LaunchContext* context) { + NDArray::NDArray(const std::u32string& u32string, sd::DataType dtype, sd::LaunchContext* context) { - if (!DataTypeUtils::isS(dtype)) { - throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used"); - } - - if (!unicode::isStringValidU32(u32string.data(), u32string.data() + u32string.size())) { - throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); - } - // one word that is why used 1 - Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(1); - - Nd4jLong dataLength = [&] { - if (dtype == DataType::UTF16) { - return 
unicode::offsetUtf32StringInUtf16(u32string.data(), u32string.size()); + if (!DataTypeUtils::isS(dtype)) { + throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used"); } - if (dtype == DataType::UTF32) { - return static_cast(sizeof(uint32_t) * u32string.size()); + + if (!unicode::isStringValidU32(u32string.data(), u32string.data() + u32string.size())) { + throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); } - return unicode::offsetUtf32StringInUtf8(u32string.data(), u32string.size()); - }(); + // one word that is why used 1 + Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(1); - Nd4jLong offsets[2] = { 0 , dataLength }; + Nd4jLong dataLength = [&] { + if (dtype == DataType::UTF16) { + return unicode::offsetUtf32StringInUtf16(u32string.data(), u32string.size()); + } + if (dtype == DataType::UTF32) { + return static_cast(sizeof(uint32_t) * u32string.size()); + } + return unicode::offsetUtf32StringInUtf8(u32string.data(), u32string.size()); + }(); - _buffer = std::make_shared(headerLength + dataLength, dtype, context->getWorkspace(), true); + Nd4jLong offsets[2] = { 0 , dataLength }; - _context = context; - _isAttached = getContext()->getWorkspace() != nullptr; - _offset = 0; + _buffer = std::make_shared(headerLength + dataLength, dtype, context->getWorkspace(), true); - setShapeInfo(ShapeDescriptor::scalarDescriptor(dtype)); + _context = context; + _isAttached = getContext()->getWorkspace() != nullptr; + _offset = 0; - memcpy(bufferAsT(), &offsets[0], 2 * sizeof(Nd4jLong)); + setShapeInfo(ShapeDescriptor::scalarDescriptor(dtype)); - auto data = reinterpret_cast(bufferAsT() + headerLength); - if (dtype == DataType::UTF8) { - unicode::utf32to8(u32string.data(), data, u32string.size()); + memcpy(bufferAsT(), &offsets[0], 2 * sizeof(Nd4jLong)); + + auto data = reinterpret_cast(bufferAsT() + headerLength); + if (dtype == DataType::UTF8) { + unicode::utf32to8(u32string.data(), data, u32string.size()); + } + else if (dtype == DataType::UTF16) { + unicode::utf32to16(u32string.data(), data, u32string.size()); + } + else { + memcpy(data, u32string.data(), u32string.size() * sizeof(uint32_t)); + } + + tickWriteHost(); + syncToDevice(); } - else if (dtype == DataType::UTF16) { - unicode::utf32to16(u32string.data(), data, u32string.size()); - } - else { - memcpy(data, u32string.data(), u32string.size() * sizeof(uint32_t)); - } - - tickWriteHost(); - syncToDevice(); -} ///////////////////////////////////////////////////////////////////////// // u8 string constructors -NDArray::NDArray(const std::string& str, sd::DataType dtype, sd::LaunchContext* context) { + NDArray::NDArray(const std::string& str, sd::DataType dtype, sd::LaunchContext* context) { - if (!DataTypeUtils::isS(dtype)) { - throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used"); - } - - if (!unicode::isStringValidU8(str.data(), str.data() + str.size())) { - throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); - } - - // one word that is why used 1 - auto headerLength = ShapeUtils::stringBufferHeaderRequirements(1); - - Nd4jLong dataLength = [&] { - if (dtype == DataType::UTF16) { - return unicode::offsetUtf8StringInUtf16(str.data(), str.size()); + if (!DataTypeUtils::isS(dtype)) { + throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used"); } - if (dtype == DataType::UTF32) { - return 
unicode::offsetUtf8StringInUtf32(str.data(), str.size()); - } - return static_cast(str.size()); - }(); - Nd4jLong offsets[2] = { 0 , dataLength }; - - _buffer = std::make_shared(headerLength + dataLength, dtype, context->getWorkspace(), true); - - _context = context; - _isAttached = getContext()->getWorkspace() != nullptr; - _offset = 0; - - setShapeInfo(ShapeDescriptor::scalarDescriptor(dtype)); - - memcpy(bufferAsT(), &offsets[0], 2 * sizeof(Nd4jLong)); - - auto data = reinterpret_cast(bufferAsT() + headerLength); - - if (dtype == DataType::UTF8) { - memcpy(data, str.data(), str.size()); - } - else if (dtype == DataType::UTF16) { - unicode::utf8to16(str.data(), data, str.size()); - } - else { - unicode::utf8to32(str.data(), data, str.size()); - } - - tickWriteHost(); - syncToDevice(); -} -///////////////////////////////////////////////////////////////////////// -// constructors for vector of strings -NDArray::NDArray(const std::vector& shape, const std::vector& string, const sd::DataType dataType, sd::LaunchContext* context) { - - if (!DataTypeUtils::isS(dataType)) - throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used"); - - if (shape::prodLong(shape.data(), shape.size()) != string.size()) - throw std::invalid_argument("NDArray::NDArray: Number of strings should match length of array"); - - for (const auto& str : string) { - if (!unicode::isStringValidU8(str, str + std::char_traits::length(str)) ) { + if (!unicode::isStringValidU8(str.data(), str.data() + str.size())) { throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); } - } - Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(string.size()); + // one word that is why used 1 + auto headerLength = ShapeUtils::stringBufferHeaderRequirements(1); - std::vector offsets(string.size() + 1); - Nd4jLong dataLength = 0; - for (int e = 0; e < string.size(); e++) { - offsets[e] = dataLength; - dataLength += [&] { - if (dataType == DataType::UTF16) - return unicode::offsetUtf8StringInUtf16(string[e], std::char_traits::length(string[e])); - if (dataType == DataType::UTF32) - return unicode::offsetUtf8StringInUtf32(string[e], std::char_traits::length(string[e])); - return static_cast(std::char_traits::length(string[e])); + Nd4jLong dataLength = [&] { + if (dtype == DataType::UTF16) { + return unicode::offsetUtf8StringInUtf16(str.data(), str.size()); + } + if (dtype == DataType::UTF32) { + return unicode::offsetUtf8StringInUtf32(str.data(), str.size()); + } + return static_cast(str.size()); }(); + + Nd4jLong offsets[2] = { 0 , dataLength }; + + _buffer = std::make_shared(headerLength + dataLength, dtype, context->getWorkspace(), true); + + _context = context; + _isAttached = getContext()->getWorkspace() != nullptr; + _offset = 0; + + setShapeInfo(ShapeDescriptor::scalarDescriptor(dtype)); + + memcpy(bufferAsT(), &offsets[0], 2 * sizeof(Nd4jLong)); + + auto data = reinterpret_cast(bufferAsT() + headerLength); + + if (dtype == DataType::UTF8) { + memcpy(data, str.data(), str.size()); + } + else if (dtype == DataType::UTF16) { + unicode::utf8to16(str.data(), data, str.size()); + } + else { + unicode::utf8to32(str.data(), data, str.size()); + } + + tickWriteHost(); + syncToDevice(); } - offsets[string.size()] = dataLength; +///////////////////////////////////////////////////////////////////////// +// constructors for vector of strings + NDArray::NDArray(const std::vector& shape, const std::vector& string, const sd::DataType dataType, sd::LaunchContext* 
context) { - _buffer = std::make_shared(headerLength + dataLength, dataType, context->getWorkspace(), true); + if (!DataTypeUtils::isS(dataType)) + throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used"); - _context = context; - _offset = 0; + if (shape::prodLong(shape.data(), shape.size()) != string.size()) + throw std::invalid_argument("NDArray::NDArray: Number of strings should match length of array"); - setShapeInfo(ShapeDescriptor(dataType, 'c', shape)); + for (const auto& str : string) { + if (!unicode::isStringValidU8(str, str + std::char_traits::length(str)) ) { + throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); + } + } - _isView = false; + Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(string.size()); - setAttached(context->getWorkspace() != nullptr); + std::vector offsets(string.size() + 1); + Nd4jLong dataLength = 0; + for (int e = 0; e < string.size(); e++) { + offsets[e] = dataLength; + dataLength += [&] { + if (dataType == DataType::UTF16) + return unicode::offsetUtf8StringInUtf16(string[e], std::char_traits::length(string[e])); + if (dataType == DataType::UTF32) + return unicode::offsetUtf8StringInUtf32(string[e], std::char_traits::length(string[e])); + return static_cast(std::char_traits::length(string[e])); + }(); + } + offsets[string.size()] = dataLength; - memcpy(bufferAsT(), offsets.data(), offsets.size() * sizeof(Nd4jLong)); + _buffer = std::make_shared(headerLength + dataLength, dataType, context->getWorkspace(), true); - auto data = reinterpret_cast(bufferAsT() + headerLength); + _context = context; + _offset = 0; - auto func = PRAGMA_THREADS_FOR{ - for (auto e = start; e < stop; e++) { + setShapeInfo(ShapeDescriptor(dataType, 'c', shape)); + + _isView = false; + + setAttached(context->getWorkspace() != nullptr); + + memcpy(bufferAsT(), offsets.data(), offsets.size() * sizeof(Nd4jLong)); + + auto data = reinterpret_cast(bufferAsT() + headerLength); + + auto func = PRAGMA_THREADS_FOR{ + for (auto e = start; e < stop; e++) { auto cdata = data + offsets[e]; if (dataType == DataType::UTF16) { unicode::utf8to16(string[e], cdata, std::char_traits::length(string[e])); @@ -519,1125 +519,1125 @@ NDArray::NDArray(const std::vector& shape, const std::vector::length(string[e])); } - } - }; + } + }; - samediff::Threads::parallel_for(func, 0, lengthOf(), 1); + samediff::Threads::parallel_for(func, 0, lengthOf(), 1); - tickWriteHost(); - syncToDevice(); -} + tickWriteHost(); + syncToDevice(); + } ///////////////////////////////////////////////////////////////////////// -NDArray::NDArray(const std::vector& shape, const std::vector& string, const sd::DataType dataType, sd::LaunchContext* context) { + NDArray::NDArray(const std::vector& shape, const std::vector& string, const sd::DataType dataType, sd::LaunchContext* context) { - if (!DataTypeUtils::isS(dataType)) - throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used"); + if (!DataTypeUtils::isS(dataType)) + throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used"); - if (shape::prodLong(shape.data(), shape.size()) != string.size()) - throw std::invalid_argument("NDArray::NDArray: Number of strings should match length of array"); + if (shape::prodLong(shape.data(), shape.size()) != string.size()) + throw std::invalid_argument("NDArray::NDArray: Number of strings should match length of array"); - for (const auto& str : string) { - if 
(!unicode::isStringValidU8(str.data(), str.data() + str.size())) { - throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); + for (const auto& str : string) { + if (!unicode::isStringValidU8(str.data(), str.data() + str.size())) { + throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); + } } - } - Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(string.size()); + Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(string.size()); - std::vector offsets(string.size() + 1); - Nd4jLong dataLength = 0; - for (int e = 0; e < string.size(); e++) { - offsets[e] = dataLength; - dataLength += [&] { - if (dataType == DataType::UTF16) - return unicode::offsetUtf8StringInUtf16(string[e].data(), string[e].size()); - if (dataType == DataType::UTF32) - return unicode::offsetUtf8StringInUtf32(string[e].data(), string[e].size()); - return static_cast(string[e].size()); - }(); - } - - offsets[string.size()] = dataLength; - - _buffer = std::make_shared(headerLength + dataLength, dataType, context->getWorkspace(), true); - - _context = context; - _offset = 0; - - setShapeInfo(ShapeDescriptor(dataType, 'c', shape)); - - _isView = false; - - setAttached(context->getWorkspace() != nullptr); - - memcpy(bufferAsT(), offsets.data(), offsets.size() * sizeof(Nd4jLong)); - - auto data = reinterpret_cast(bufferAsT() + headerLength); - - auto func = PRAGMA_THREADS_FOR{ - for (auto e = start; e < stop; e++) { - auto cdata = data + offsets[e]; - if (dataType == DataType::UTF16) { - unicode::utf8to16(string[e].data(), cdata, string[e].size()); - } - else if (dataType == DataType::UTF32) { - unicode::utf8to32(string[e].data(), cdata, string[e].size()); - } - else { - memcpy(cdata, string[e].data(), string[e].size()); - } + std::vector offsets(string.size() + 1); + Nd4jLong dataLength = 0; + for (int e = 0; e < string.size(); e++) { + offsets[e] = dataLength; + dataLength += [&] { + if (dataType == DataType::UTF16) + return unicode::offsetUtf8StringInUtf16(string[e].data(), string[e].size()); + if (dataType == DataType::UTF32) + return unicode::offsetUtf8StringInUtf32(string[e].data(), string[e].size()); + return static_cast(string[e].size()); + }(); } - }; - samediff::Threads::parallel_for(func, 0, lengthOf(), 1); + offsets[string.size()] = dataLength; + + _buffer = std::make_shared(headerLength + dataLength, dataType, context->getWorkspace(), true); + + _context = context; + _offset = 0; + + setShapeInfo(ShapeDescriptor(dataType, 'c', shape)); + + _isView = false; + + setAttached(context->getWorkspace() != nullptr); + + memcpy(bufferAsT(), offsets.data(), offsets.size() * sizeof(Nd4jLong)); + + auto data = reinterpret_cast(bufferAsT() + headerLength); + + auto func = PRAGMA_THREADS_FOR{ + for (auto e = start; e < stop; e++) { + auto cdata = data + offsets[e]; + if (dataType == DataType::UTF16) { + unicode::utf8to16(string[e].data(), cdata, string[e].size()); + } + else if (dataType == DataType::UTF32) { + unicode::utf8to32(string[e].data(), cdata, string[e].size()); + } + else { + memcpy(cdata, string[e].data(), string[e].size()); + } + } + }; + + samediff::Threads::parallel_for(func, 0, lengthOf(), 1); - tickWriteHost(); - syncToDevice(); -} + tickWriteHost(); + syncToDevice(); + } ///////////////////////////////////////////////////////////////////////// -NDArray::NDArray(const std::vector& shape, const std::vector& string, sd::DataType dtype, sd::LaunchContext* context) { + NDArray::NDArray(const std::vector& shape, const 
std::vector& string, sd::DataType dtype, sd::LaunchContext* context) { - if (!DataTypeUtils::isS(dtype)) - throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used"); + if (!DataTypeUtils::isS(dtype)) + throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used"); - if (shape::prodLong(shape.data(), shape.size()) != string.size()) - throw std::invalid_argument("NDArray::NDArray: Number of strings should match length of array"); + if (shape::prodLong(shape.data(), shape.size()) != string.size()) + throw std::invalid_argument("NDArray::NDArray: Number of strings should match length of array"); - for (const auto& str : string) { - if (!unicode::isStringValidU16(str.data(), str.data() + str.size())) { - throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); + for (const auto& str : string) { + if (!unicode::isStringValidU16(str.data(), str.data() + str.size())) { + throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); + } } - } - Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(string.size()); + Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(string.size()); - std::vector offsets(string.size() + 1); - Nd4jLong dataLength = 0; - for (int e = 0; e < string.size(); e++) { - offsets[e] = dataLength; - dataLength += [&] { - if (dtype == DataType::UTF16) - return static_cast(sizeof(uint16_t) * string[e].size()); - if (dtype == DataType::UTF32) - return unicode::offsetUtf16StringInUtf32(string[e].data(), string[e].size()); - return unicode::offsetUtf16StringInUtf8(string[e].data(), string[e].size()); - }(); - } - offsets[string.size()] = dataLength; - - _buffer = std::make_shared(headerLength + dataLength, dtype, context->getWorkspace(), true); - - _context = context; - _offset = 0; - - setShapeInfo(ShapeDescriptor(dtype, 'c', shape)); - - _isView = false; - - setAttached(context->getWorkspace() != nullptr); - - memcpy(bufferAsT(), offsets.data(), offsets.size() * sizeof(Nd4jLong)); - - auto data = reinterpret_cast(bufferAsT() + headerLength); - - auto func = PRAGMA_THREADS_FOR{ - for (auto e = start; e < stop; e++) { - auto cdata = data + offsets[e]; - if (dtype == DataType::UTF16) { - memcpy(cdata, string[e].data(), string[e].size() * sizeof(uint16_t)); - } - else if (dtype == DataType::UTF32) { - unicode::utf16to32(string[e].data(), cdata, string[e].size()); - } - else { - unicode::utf16to8(string[e].data(), cdata, string[e].size()); - } + std::vector offsets(string.size() + 1); + Nd4jLong dataLength = 0; + for (int e = 0; e < string.size(); e++) { + offsets[e] = dataLength; + dataLength += [&] { + if (dtype == DataType::UTF16) + return static_cast(sizeof(uint16_t) * string[e].size()); + if (dtype == DataType::UTF32) + return unicode::offsetUtf16StringInUtf32(string[e].data(), string[e].size()); + return unicode::offsetUtf16StringInUtf8(string[e].data(), string[e].size()); + }(); } - }; - samediff::Threads::parallel_for(func, 0, lengthOf(), 1); + offsets[string.size()] = dataLength; - tickWriteHost(); - syncToDevice(); -} + _buffer = std::make_shared(headerLength + dataLength, dtype, context->getWorkspace(), true); + + _context = context; + _offset = 0; + + setShapeInfo(ShapeDescriptor(dtype, 'c', shape)); + + _isView = false; + + setAttached(context->getWorkspace() != nullptr); + + memcpy(bufferAsT(), offsets.data(), offsets.size() * sizeof(Nd4jLong)); + + auto data = reinterpret_cast(bufferAsT() + headerLength); 
+ + auto func = PRAGMA_THREADS_FOR{ + for (auto e = start; e < stop; e++) { + auto cdata = data + offsets[e]; + if (dtype == DataType::UTF16) { + memcpy(cdata, string[e].data(), string[e].size() * sizeof(uint16_t)); + } + else if (dtype == DataType::UTF32) { + unicode::utf16to32(string[e].data(), cdata, string[e].size()); + } + else { + unicode::utf16to8(string[e].data(), cdata, string[e].size()); + } + } + }; + samediff::Threads::parallel_for(func, 0, lengthOf(), 1); + + tickWriteHost(); + syncToDevice(); + } ///////////////////////////////////////////////////////////////////////// -NDArray::NDArray(const std::vector& shape, const std::vector& string, sd::DataType dtype, sd::LaunchContext* context) { + NDArray::NDArray(const std::vector& shape, const std::vector& string, sd::DataType dtype, sd::LaunchContext* context) { - if (!DataTypeUtils::isS(dtype)) - throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used"); + if (!DataTypeUtils::isS(dtype)) + throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used"); - if (shape::prodLong(shape.data(), shape.size()) != string.size()) - throw std::invalid_argument("NDArray::NDArray: Number of strings should match length of array"); + if (shape::prodLong(shape.data(), shape.size()) != string.size()) + throw std::invalid_argument("NDArray::NDArray: Number of strings should match length of array"); - for (const auto& str : string) { - if (!unicode::isStringValidU16(str, str + std::char_traits::length(str))) { - throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); + for (const auto& str : string) { + if (!unicode::isStringValidU16(str, str + std::char_traits::length(str))) { + throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); + } } - } - Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(string.size()); + Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(string.size()); - std::vector offsets(string.size() + 1); - Nd4jLong dataLength = 0; - for (int e = 0; e < string.size(); e++) { - offsets[e] = dataLength; - dataLength += [&] { - if (dtype == DataType::UTF16) - return static_cast(sizeof(uint16_t) * std::char_traits::length(string[e])); - if (dtype == DataType::UTF32) - return unicode::offsetUtf16StringInUtf32(string[e], std::char_traits::length(string[e])); - return unicode::offsetUtf16StringInUtf8(string[e], std::char_traits::length(string[e])); - }(); - } - offsets[string.size()] = dataLength; - - _buffer = std::make_shared(headerLength + dataLength, dtype, context->getWorkspace(), true); - - _context = context; - _offset = 0; - - setShapeInfo(ShapeDescriptor(dtype, 'c', shape)); - - _isView = false; - - setAttached(context->getWorkspace() != nullptr); - - memcpy(bufferAsT(), offsets.data(), offsets.size() * sizeof(Nd4jLong)); - - auto data = reinterpret_cast(bufferAsT() + headerLength); - - - auto func = PRAGMA_THREADS_FOR{ - for (auto e = start; e < stop; e++) { - auto cdata = data + offsets[e]; - if (dtype == DataType::UTF16) { - memcpy(cdata, string[e], std::char_traits::length(string[e]) * sizeof(uint16_t)); - } - else if (dtype == DataType::UTF32) { - unicode::utf16to32(string[e], cdata, std::char_traits::length(string[e])); - } - else { - unicode::utf16to8(string[e], cdata, std::char_traits::length(string[e])); - } + std::vector offsets(string.size() + 1); + Nd4jLong dataLength = 0; + for (int e = 0; e < string.size(); e++) { + offsets[e] = dataLength; + 
dataLength += [&] { + if (dtype == DataType::UTF16) + return static_cast(sizeof(uint16_t) * std::char_traits::length(string[e])); + if (dtype == DataType::UTF32) + return unicode::offsetUtf16StringInUtf32(string[e], std::char_traits::length(string[e])); + return unicode::offsetUtf16StringInUtf8(string[e], std::char_traits::length(string[e])); + }(); } - }; - samediff::Threads::parallel_for(func, 0, lengthOf(), 1); + offsets[string.size()] = dataLength; - tickWriteHost(); - syncToDevice(); -} + _buffer = std::make_shared(headerLength + dataLength, dtype, context->getWorkspace(), true); + + _context = context; + _offset = 0; + + setShapeInfo(ShapeDescriptor(dtype, 'c', shape)); + + _isView = false; + + setAttached(context->getWorkspace() != nullptr); + + memcpy(bufferAsT(), offsets.data(), offsets.size() * sizeof(Nd4jLong)); + + auto data = reinterpret_cast(bufferAsT() + headerLength); + + + auto func = PRAGMA_THREADS_FOR{ + for (auto e = start; e < stop; e++) { + auto cdata = data + offsets[e]; + if (dtype == DataType::UTF16) { + memcpy(cdata, string[e], std::char_traits::length(string[e]) * sizeof(uint16_t)); + } + else if (dtype == DataType::UTF32) { + unicode::utf16to32(string[e], cdata, std::char_traits::length(string[e])); + } + else { + unicode::utf16to8(string[e], cdata, std::char_traits::length(string[e])); + } + } + }; + samediff::Threads::parallel_for(func, 0, lengthOf(), 1); + + tickWriteHost(); + syncToDevice(); + } ///////////////////////////////////////////////////////////////////////// -NDArray::NDArray(const std::vector& shape, const std::vector& string, sd::DataType dtype, sd::LaunchContext* context) { + NDArray::NDArray(const std::vector& shape, const std::vector& string, sd::DataType dtype, sd::LaunchContext* context) { - if (!DataTypeUtils::isS(dtype)) - throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used"); + if (!DataTypeUtils::isS(dtype)) + throw std::invalid_argument("NDArray::NDArray: invalid DataType, only string dataTypes have to be used"); - if (shape::prodLong(shape.data(), shape.size()) != string.size()) - throw std::invalid_argument("NDArray::NDArray: Number of strings should match length of array"); + if (shape::prodLong(shape.data(), shape.size()) != string.size()) + throw std::invalid_argument("NDArray::NDArray: Number of strings should match length of array"); - for (auto str : string) { - if (!unicode::isStringValidU32(str.data(), str.data() + str.size())) { - throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); + for (auto str : string) { + if (!unicode::isStringValidU32(str.data(), str.data() + str.size())) { + throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); + } } - } - Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(string.size()); + Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(string.size()); - std::vector offsets(string.size() + 1); + std::vector offsets(string.size() + 1); - Nd4jLong dataLength = 0; - for (int e = 0; e < string.size(); e++) { - offsets[e] = dataLength; - dataLength += [&] { - if (dtype == DataType::UTF16) + Nd4jLong dataLength = 0; + for (int e = 0; e < string.size(); e++) { + offsets[e] = dataLength; + dataLength += [&] { + if (dtype == DataType::UTF16) + return unicode::offsetUtf32StringInUtf16(string[e].data(), string[e].size()); + if (dtype == DataType::UTF32) + return static_cast(sizeof(uint32_t) * string[e].size()); return 
unicode::offsetUtf32StringInUtf16(string[e].data(), string[e].size()); - if (dtype == DataType::UTF32) - return static_cast(sizeof(uint32_t) * string[e].size()); - return unicode::offsetUtf32StringInUtf16(string[e].data(), string[e].size()); - }(); - } - offsets[string.size()] = dataLength; - - _buffer = std::make_shared(headerLength + dataLength, dtype, context->getWorkspace(), true); - - _context = context; - _offset = 0; - - setShapeInfo(ShapeDescriptor(dtype, 'c', shape)); - - _isView = false; - - setAttached(context->getWorkspace() != nullptr); - - memcpy(bufferAsT(), offsets.data(), offsets.size() * sizeof(Nd4jLong)); - - auto data = reinterpret_cast(bufferAsT() + headerLength); - - auto func = PRAGMA_THREADS_FOR{ - for (auto e = start; e < stop; e++) { - auto cdata = data + offsets[e]; - if (dtype == DataType::UTF16) { - unicode::utf32to16(string[e].data(), cdata, string[e].size()); - } - else if (dtype == DataType::UTF32) { - memcpy(cdata, string[e].data(), string[e].size() * sizeof(uint32_t)); - } - else { - unicode::utf32to8(string[e].data(), cdata, string[e].size()); - } + }(); } - }; - samediff::Threads::parallel_for(func, 0, lengthOf(), 1); + offsets[string.size()] = dataLength; - tickWriteHost(); - syncToDevice(); -} + _buffer = std::make_shared(headerLength + dataLength, dtype, context->getWorkspace(), true); + + _context = context; + _offset = 0; + + setShapeInfo(ShapeDescriptor(dtype, 'c', shape)); + + _isView = false; + + setAttached(context->getWorkspace() != nullptr); + + memcpy(bufferAsT(), offsets.data(), offsets.size() * sizeof(Nd4jLong)); + + auto data = reinterpret_cast(bufferAsT() + headerLength); + + auto func = PRAGMA_THREADS_FOR{ + for (auto e = start; e < stop; e++) { + auto cdata = data + offsets[e]; + if (dtype == DataType::UTF16) { + unicode::utf32to16(string[e].data(), cdata, string[e].size()); + } + else if (dtype == DataType::UTF32) { + memcpy(cdata, string[e].data(), string[e].size() * sizeof(uint32_t)); + } + else { + unicode::utf32to8(string[e].data(), cdata, string[e].size()); + } + } + }; + samediff::Threads::parallel_for(func, 0, lengthOf(), 1); + + tickWriteHost(); + syncToDevice(); + } ///////////////////////////////////////////////////////////////////////// -NDArray::NDArray(const std::vector& shape, const std::vector& string, sd::DataType dtype, sd::LaunchContext* context) { + NDArray::NDArray(const std::vector& shape, const std::vector& string, sd::DataType dtype, sd::LaunchContext* context) { - if (!DataTypeUtils::isS(dtype)) - throw std::invalid_argument("NDArray::NDArray: invalid DataType used"); + if (!DataTypeUtils::isS(dtype)) + throw std::invalid_argument("NDArray::NDArray: invalid DataType used"); - if (shape::prodLong(shape.data(), shape.size()) != string.size()) - throw std::invalid_argument("NDArray::NDArray: Number of strings should match length of array"); + if (shape::prodLong(shape.data(), shape.size()) != string.size()) + throw std::invalid_argument("NDArray::NDArray: Number of strings should match length of array"); - for (const auto& str : string) { - if (!unicode::isStringValidU32(str, str + std::char_traits::length(str))) { - throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); + for (const auto& str : string) { + if (!unicode::isStringValidU32(str, str + std::char_traits::length(str))) { + throw std::invalid_argument("NDArray::NDArray: invalid character in input string"); + } } - } - Nd4jLong headerLength = ShapeUtils::stringBufferHeaderRequirements(string.size()); + Nd4jLong headerLength = 
ShapeUtils::stringBufferHeaderRequirements(string.size()); - std::vector offsets(string.size() + 1); + std::vector offsets(string.size() + 1); - Nd4jLong dataLength = 0; - for (int e = 0; e < string.size(); e++) { - offsets[e] = dataLength; - dataLength += [&] { - if (dtype == DataType::UTF16) + Nd4jLong dataLength = 0; + for (int e = 0; e < string.size(); e++) { + offsets[e] = dataLength; + dataLength += [&] { + if (dtype == DataType::UTF16) + return unicode::offsetUtf32StringInUtf16(string[e], std::char_traits::length(string[e])); + if (dtype == DataType::UTF32) + return static_cast(sizeof(uint32_t) * std::char_traits::length(string[e])); return unicode::offsetUtf32StringInUtf16(string[e], std::char_traits::length(string[e])); - if (dtype == DataType::UTF32) - return static_cast(sizeof(uint32_t) * std::char_traits::length(string[e])); - return unicode::offsetUtf32StringInUtf16(string[e], std::char_traits::length(string[e])); - }(); - } - offsets[string.size()] = dataLength; - - _buffer = std::make_shared(headerLength + dataLength, dtype, context->getWorkspace(), true); - - _context = context; - _offset = 0; - - setShapeInfo(ShapeDescriptor(dtype, 'c', shape)); - - _isView = _length * DataTypeUtils::sizeOf(_dataType) < _buffer->getLenInBytes(); - - setAttached(context->getWorkspace() != nullptr); - - memcpy(bufferAsT(), offsets.data(), offsets.size() * sizeof(Nd4jLong)); - - auto data = reinterpret_cast(bufferAsT() + headerLength); - - auto func = PRAGMA_THREADS_FOR{ - for (auto e = start; e < stop; e++) { - auto cdata = data + offsets[e]; - if (dtype == DataType::UTF16) { - unicode::utf32to16(string[e], cdata, std::char_traits::length(string[e])); - } - else if (dtype == DataType::UTF32) { - memcpy(cdata, string[e], std::char_traits::length(string[e]) * sizeof(uint32_t)); - } - else { - unicode::utf32to8(string[e], cdata, std::char_traits::length(string[e])); - } + }(); } - }; - samediff::Threads::parallel_for(func, 0, lengthOf(), 1); + offsets[string.size()] = dataLength; - tickWriteHost(); - syncToDevice(); -} + _buffer = std::make_shared(headerLength + dataLength, dtype, context->getWorkspace(), true); + + _context = context; + _offset = 0; + + setShapeInfo(ShapeDescriptor(dtype, 'c', shape)); + + _isView = _length * DataTypeUtils::sizeOf(_dataType) < _buffer->getLenInBytes(); + + setAttached(context->getWorkspace() != nullptr); + + memcpy(bufferAsT(), offsets.data(), offsets.size() * sizeof(Nd4jLong)); + + auto data = reinterpret_cast(bufferAsT() + headerLength); + + auto func = PRAGMA_THREADS_FOR{ + for (auto e = start; e < stop; e++) { + auto cdata = data + offsets[e]; + if (dtype == DataType::UTF16) { + unicode::utf32to16(string[e], cdata, std::char_traits::length(string[e])); + } + else if (dtype == DataType::UTF32) { + memcpy(cdata, string[e], std::char_traits::length(string[e]) * sizeof(uint32_t)); + } + else { + unicode::utf32to8(string[e], cdata, std::char_traits::length(string[e])); + } + } + }; + samediff::Threads::parallel_for(func, 0, lengthOf(), 1); + + tickWriteHost(); + syncToDevice(); + } //////////////////////////////////////////////////////////////////////// // assignment operator NDArray& NDArray::operator=(const NDArray& other) { - if (this == &other || (_shapeInfo == other._shapeInfo && _shapeInfo == nullptr)) + if (this == &other || (_shapeInfo == other._shapeInfo && _shapeInfo == nullptr)) + return *this; + + if (_shapeInfo != nullptr && shape::equalsTypesAndShapesSoft(_shapeInfo, other._shapeInfo)) { + if(!other.isEmpty()) + this->assign(&other); + } + else { 
+ _context = other._context; + _offset = 0; + setShapeInfo(ShapeDescriptor(other.dataType(), other.ordering(), other.shapeOf(), other.rankOf())); + + if(!other.isEmpty()) { + _buffer = std::make_shared(other.lengthOf() * other.sizeOfT(), other.dataType(), other.getContext()->getWorkspace()); + this->assign(&other); + } + else + _buffer = std::make_shared(); + } return *this; - - if (_shapeInfo != nullptr && shape::equalsTypesAndShapesSoft(_shapeInfo, other._shapeInfo)) { - if(!other.isEmpty()) - this->assign(&other); } - else { - _context = other._context; - _offset = 0; - setShapeInfo(ShapeDescriptor(other.dataType(), other.ordering(), other.shapeOf(), other.rankOf())); - if(!other.isEmpty()) { - _buffer = std::make_shared(other.lengthOf() * other.sizeOfT(), other.dataType(), other.getContext()->getWorkspace()); - this->assign(&other); - } - else - _buffer = std::make_shared(); + +////////////////////////////////////////////////////////////////////////// + bool NDArray::isC() const { + // TODO: this method must be implemented once we add support for complex numbers + return false; } - return *this; -} - ////////////////////////////////////////////////////////////////////////// -bool NDArray::isC() const { - // TODO: this method must be implemented once we add support for complex numbers - return false; -} - -////////////////////////////////////////////////////////////////////////// -bool NDArray::isS() const { - return (dataType() == DataType::UTF8 || - dataType() == DataType::UTF16 || - dataType() == DataType::UTF32); -} - -////////////////////////////////////////////////////////////////////////// -bool NDArray::isR() const { - auto xType = ArrayOptions::dataType(this->_shapeInfo); - return xType == FLOAT32 || xType == HALF || xType == DOUBLE || xType == FLOAT8 || xType == BFLOAT16; -} - -////////////////////////////////////////////////////////////////////////// -bool NDArray::isZ() const { - // TODO: decide if we really want to exclude Bool here - return !isC() && !isR() && !isB() && !isS(); -} - -////////////////////////////////////////////////////////////////////////// -bool NDArray::isB() const { - return ArrayOptions::dataType(this->_shapeInfo) == BOOL; -} - -////////////////////////////////////////////////////////////////////////// -template -std::string NDArray::toStringValue(T value) { - std::ostringstream os ; - //throw the value into the string stream - os << value ; - //convert the string stream into a string and return - return os.str() ; -} - -////////////////////////////////////////////////////////////////////////// -template<> -std::string NDArray::toStringValue(float16 value) { - std::ostringstream os ; - //throw the value into the string stream - os << (float) value ; - //convert the string stream into a string and return - return os.str() ; -} - -////////////////////////////////////////////////////////////////////////// -template<> -std::string NDArray::toStringValue(bfloat16 value) { - std::ostringstream os ; - //throw the value into the string stream - os << (float) value ; - //convert the string stream into a string and return - return os.str() ; -} - -////////////////////////////////////////////////////////////////////////// -std::string NDArray::asIndexedString(Nd4jLong limit) { - std::ostringstream os; - os << "["; - if (limit < 1 || limit > this->lengthOf()) - limit = this->lengthOf(); - for (Nd4jLong e = 0; e < limit; e++) { - os << toStringValue(this->e(e)); - if (e < limit - 1) - os << ", "; + bool NDArray::isS() const { + return (dataType() == DataType::UTF8 
|| + dataType() == DataType::UTF16 || + dataType() == DataType::UTF32); } - os << "]"; - return os.str(); -} ////////////////////////////////////////////////////////////////////////// -std::string NDArray::asString(Nd4jLong limit) { - std::ostringstream os; - os << "["; - if (limit < 1 || limit > this->lengthOf()) - limit = this->lengthOf(); - for (Nd4jLong e = 0; e < limit; e++) { - if (this->isR()) + bool NDArray::isR() const { + auto xType = ArrayOptions::dataType(this->_shapeInfo); + return xType == FLOAT32 || xType == HALF || xType == DOUBLE || xType == FLOAT8 || xType == BFLOAT16; + } + +////////////////////////////////////////////////////////////////////////// + bool NDArray::isZ() const { + // TODO: decide if we really want to exclude Bool here + return !isC() && !isR() && !isB() && !isS(); + } + +////////////////////////////////////////////////////////////////////////// + bool NDArray::isB() const { + return ArrayOptions::dataType(this->_shapeInfo) == BOOL; + } + +////////////////////////////////////////////////////////////////////////// + template + std::string NDArray::toStringValue(T value) { + std::ostringstream os ; + //throw the value into the string stream + os << value ; + //convert the string stream into a string and return + return os.str() ; + } + +////////////////////////////////////////////////////////////////////////// + template<> + std::string NDArray::toStringValue(float16 value) { + std::ostringstream os ; + //throw the value into the string stream + os << (float) value ; + //convert the string stream into a string and return + return os.str() ; + } + +////////////////////////////////////////////////////////////////////////// + template<> + std::string NDArray::toStringValue(bfloat16 value) { + std::ostringstream os ; + //throw the value into the string stream + os << (float) value ; + //convert the string stream into a string and return + return os.str() ; + } + +////////////////////////////////////////////////////////////////////////// + std::string NDArray::asIndexedString(Nd4jLong limit) { + std::ostringstream os; + os << "["; + if (limit < 1 || limit > this->lengthOf()) + limit = this->lengthOf(); + for (Nd4jLong e = 0; e < limit; e++) { os << toStringValue(this->e(e)); - else if (this->isZ()) - os << toStringValue(this->e(e)); - else if (this->isB()) - os << toStringValue(this->e(e)); - else if (this->isS()) // todo add utf16 and utf32 - os << this->e(e); - if (e < limit - 1) - os << ", "; - } - os << "]"; - return os.str(); -} - -//////////////////////////////////////////////////////////////////////// -template -std::vector NDArray::getBufferAsVector() const { - std::vector vector(lengthOf()); - for (Nd4jLong e = 0; e < lengthOf(); e++) - vector[e] = this->e(e); - return vector; -} -BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT std::vector, NDArray::getBufferAsVector() const, LIBND4J_TYPES); - -//////////////////////////////////////////////////////////////////////// -std::vector NDArray::getShapeAsFlatVector() const { - std::vector vector(this->rankOf()); - for (int e = 0; e < this->rankOf(); e++) - vector[e] = static_cast(this->sizeAt(e)); - return vector; -} - -//////////////////////////////////////////////////////////////////////// -std::vector NDArray::getShapeAsVector() const { - - std::vector vector(this->rankOf()); - for (int e = 0; e < this->rankOf(); e++) - vector[e] = this->sizeAt(e); - - return vector; -} - -//////////////////////////////////////////////////////////////////////// -std::vector NDArray::getShapeAsVectorInt() const { - - std::vector 
vector(this->rankOf()); - for (int e = 0; e < this->rankOf(); e++) - vector[e] = static_cast(this->sizeAt(e)); - - return vector; -} - -//////////////////////////////////////////////////////////////////////// -std::vector NDArray::getShapeInfoAsFlatVector() const { - int magicNumber = shape::shapeInfoLength(this->rankOf()); - std::vector vector(magicNumber); - - for (int e = 0; e < magicNumber; e++) - vector[e] = static_cast(_shapeInfo[e]); - - return vector; -} - -//////////////////////////////////////////////////////////////////////// -std::vector NDArray::getShapeInfoAsVector() const { - int magicNumber = shape::shapeInfoLength(this->rankOf()); - std::vector vector(magicNumber); - for (int e = 0; e < magicNumber; e++) - vector[e] = this->_shapeInfo[e]; - return vector; -} - -//////////////////////////////////////////////////////////////////////// -std::vector NDArray::asByteVector() { - - if (isS()) { - // string data type requires special treatment - syncToHost(); - auto numWords = this->lengthOf(); - auto offsetsBuffer = this->bufferAsT(); - auto headerLength = ShapeUtils::stringBufferHeaderRequirements(numWords); - auto dataLength = offsetsBuffer[numWords]; - std::vector result(headerLength + dataLength); - - memcpy(result.data(), buffer(), headerLength + dataLength); - - return result; - } else { - // all other types are linear - std::vector result((unsigned long long) this->lengthOf() * sizeOfT()); - - if (this->isView()) { - auto tmp = this->dup(this->ordering()); - syncToHost(); - memcpy(result.data(), tmp.buffer(), (unsigned long long) lengthOf() * sizeOfT()); - } else { - syncToHost(); - memcpy(result.data(), buffer(), (unsigned long long) lengthOf() * sizeOfT()); + if (e < limit - 1) + os << ", "; } - return result; + os << "]"; + return os.str(); } -} ////////////////////////////////////////////////////////////////////////// -void NDArray::linspace(const double start) { - linspace(start, 1); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::linspace(const double start, const double step) { - if (isS()) - throw std::runtime_error("NDArray::linspace: you can't use this method on String array!"); - Nd4jLong numElements = this->lengthOf(); - for (Nd4jLong e = 0; e < numElements; e++) - this->p(e, start + (step * e)); -} + std::string NDArray::asString(Nd4jLong limit) { + std::ostringstream os; + os << "["; + if (limit < 1 || limit > this->lengthOf()) + limit = this->lengthOf(); + for (Nd4jLong e = 0; e < limit; e++) { + if (this->isR()) + os << toStringValue(this->e(e)); + else if (this->isZ()) + os << toStringValue(this->e(e)); + else if (this->isB()) + os << toStringValue(this->e(e)); + else if (this->isS()) // todo add utf16 and utf32 + os << this->e(e); + if (e < limit - 1) + os << ", "; + } + os << "]"; + return os.str(); + } //////////////////////////////////////////////////////////////////////// -void NDArray::streamline(char o) { - char order = o == 'a' ? 
this->ordering() : o;
-    syncToDevice();
-    std::shared_ptr<DataBuffer> newBuffer = std::make_shared<DataBuffer>(this->lengthOf() * sizeOfT(), dataType(), getContext()->getWorkspace());
-    auto shapeBuffer = ConstantShapeHelper::getInstance().bufferForShapeInfo(dataType(), order, rankOf(), shapeOf());
-    NativeOpExecutioner::execTransformSame(getContext(), transform::Copy, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), newBuffer->primary(),
-                                           shapeBuffer.primary(), newBuffer->special(),
-                                           shapeBuffer.special(), nullptr, nullptr, nullptr);
-    setShapeInfo(shapeBuffer);
-    _buffer = newBuffer;
-    _offset = 0;
-    tickWriteDevice();
-}
+    template <typename T>
+    std::vector<T> NDArray::getBufferAsVector() const {
+        std::vector<T> vector(lengthOf());
+        for (Nd4jLong e = 0; e < lengthOf(); e++)
+            vector[e] = this->e<T>(e);
+        return vector;
+    }
+    BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT std::vector, NDArray::getBufferAsVector() const, LIBND4J_TYPES);
+
+////////////////////////////////////////////////////////////////////////
+    std::vector<Nd4jLong> NDArray::getShapeAsFlatVector() const {
+        std::vector<Nd4jLong> vector(this->rankOf());
+        for (int e = 0; e < this->rankOf(); e++)
+            vector[e] = static_cast<Nd4jLong>(this->sizeAt(e));
+        return vector;
+    }
+
+////////////////////////////////////////////////////////////////////////
+    std::vector<Nd4jLong> NDArray::getShapeAsVector() const {
+
+        std::vector<Nd4jLong> vector(this->rankOf());
+        for (int e = 0; e < this->rankOf(); e++)
+            vector[e] = this->sizeAt(e);
+
+        return vector;
+    }
+
+////////////////////////////////////////////////////////////////////////
+    std::vector<int> NDArray::getShapeAsVectorInt() const {
+
+        std::vector<int> vector(this->rankOf());
+        for (int e = 0; e < this->rankOf(); e++)
+            vector[e] = static_cast<int>(this->sizeAt(e));
+
+        return vector;
+    }
+
+////////////////////////////////////////////////////////////////////////
+    std::vector<Nd4jLong> NDArray::getShapeInfoAsFlatVector() const {
+        int magicNumber = shape::shapeInfoLength(this->rankOf());
+        std::vector<Nd4jLong> vector(magicNumber);
+
+        for (int e = 0; e < magicNumber; e++)
+            vector[e] = static_cast<Nd4jLong>(_shapeInfo[e]);
+
+        return vector;
+    }
+
+////////////////////////////////////////////////////////////////////////
+    std::vector<Nd4jLong> NDArray::getShapeInfoAsVector() const {
+        int magicNumber = shape::shapeInfoLength(this->rankOf());
+        std::vector<Nd4jLong> vector(magicNumber);
+        for (int e = 0; e < magicNumber; e++)
+            vector[e] = this->_shapeInfo[e];
+        return vector;
+    }
+
+////////////////////////////////////////////////////////////////////////
+    std::vector<int8_t> NDArray::asByteVector() {
+
+        if (isS()) {
+            // string data type requires special treatment
+            syncToHost();
+            auto numWords = this->lengthOf();
+            auto offsetsBuffer = this->bufferAsT<Nd4jLong>();
+            auto headerLength = ShapeUtils::stringBufferHeaderRequirements(numWords);
+            auto dataLength = offsetsBuffer[numWords];
+            std::vector<int8_t> result(headerLength + dataLength);
+
+            memcpy(result.data(), buffer(), headerLength + dataLength);
+
+            return result;
+        } else {
+            // all other types are linear
+            std::vector<int8_t> result((unsigned long long) this->lengthOf() * sizeOfT());
+
+            if (this->isView()) {
+                auto tmp = this->dup(this->ordering());
+                syncToHost();
+                memcpy(result.data(), tmp.buffer(), (unsigned long long) lengthOf() * sizeOfT());
+            } else {
+                syncToHost();
+                memcpy(result.data(), buffer(), (unsigned long long) lengthOf() * sizeOfT());
+            }
+            return result;
+        }
+    }
+
+//////////////////////////////////////////////////////////////////////////
+    void NDArray::linspace(const double start) {
+        linspace(start, 1);
+    }
+
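// [editor's note — not part of the patch] A minimal usage sketch for the
// re-indented buffer/shape accessors and linspace overloads above, assuming
// the usual libnd4j headers; the sketch function name is hypothetical.
// linspace(start) delegates to linspace(start, 1), which writes
// start + step * e into element e.
#include <array/NDArray.h>
#include <array/NDArrayFactory.h>

static void linspaceSketch() {
    auto arr = NDArrayFactory::create<float>('c', {2, 3});  // 2x3 array, zero-initialized
    arr.linspace(1.0);                            // elements become 1, 2, ..., 6
    auto dims = arr.getShapeAsVector();           // {2, 3}
    auto data = arr.getBufferAsVector<float>();   // 6 values in buffer order
    arr.printIndexedBuffer("linspace result");    // prints [[1, 2, 3], [4, 5, 6]]
}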
+//////////////////////////////////////////////////////////////////////////
+    void NDArray::linspace(const double start, const double step) {
+        if (isS())
+            throw std::runtime_error("NDArray::linspace: you can't use this method on String array!");
+        Nd4jLong numElements = this->lengthOf();
+        for (Nd4jLong e = 0; e < numElements; e++)
+            this->p(e, start + (step * e));
+    }
+
+////////////////////////////////////////////////////////////////////////
+    void NDArray::streamline(char o) {
+        char order = o == 'a' ? this->ordering() : o;
+        syncToDevice();
+        std::shared_ptr<DataBuffer> newBuffer = std::make_shared<DataBuffer>(this->lengthOf() * sizeOfT(), dataType(), getContext()->getWorkspace());
+        auto shapeBuffer = ConstantShapeHelper::getInstance().bufferForShapeInfo(dataType(), order, rankOf(), shapeOf());
+        NativeOpExecutioner::execTransformSame(getContext(), transform::Copy, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), newBuffer->primary(),
+                                               shapeBuffer.primary(), newBuffer->special(),
+                                               shapeBuffer.special(), nullptr, nullptr, nullptr);
+        setShapeInfo(shapeBuffer);
+        _buffer = newBuffer;
+        _offset = 0;
+        tickWriteDevice();
+    }

////////////////////////////////////////////////////////////////////////
// move assignment operator
-NDArray& NDArray::operator=(NDArray&& other) noexcept {
-    if (this == &other)
+    NDArray& NDArray::operator=(NDArray&& other) noexcept {
+        if (this == &other)
+            return *this;
+
+        _isView = other._isView;
+        _buffer = other._buffer;
+        _shapeInfo = other._shapeInfo;
+        _shapeInfoD = other._shapeInfoD;
+        _context = other._context;
+        _dataType = other._dataType;
+        _length = other._length;
+        _offset = other._offset;
+
+        other._buffer = std::make_shared<DataBuffer>();
+        other._shapeInfo = other._shapeInfoD = nullptr;
+        other._length = 0;
+
+        return *this;
-
-    _isView = other._isView;
-    _buffer = other._buffer;
-    _shapeInfo = other._shapeInfo;
-    _shapeInfoD = other._shapeInfoD;
-    _context = other._context;
-    _dataType = other._dataType;
-    _length = other._length;
-    _offset = other._offset;
-
-    other._buffer = std::make_shared<DataBuffer>();
-    other._shapeInfo = other._shapeInfoD = nullptr;
-    other._length = 0;
-
-    return *this;
-}
+    }

////////////////////////////////////////////////////////////////////////
-template <typename T>
-NDArray& NDArray::operator=(const T scalar) {
-    this->assign(scalar);
-    return *this;
-}
-template ND4J_EXPORT NDArray& NDArray::operator=(const double scalar);
-template ND4J_EXPORT NDArray& NDArray::operator=(const float scalar);
-template ND4J_EXPORT NDArray& NDArray::operator=(const float16 scalar);
-template ND4J_EXPORT NDArray& NDArray::operator=(const bfloat16 scalar);
-template ND4J_EXPORT NDArray& NDArray::operator=(const Nd4jLong scalar);
-template ND4J_EXPORT NDArray& NDArray::operator=(const int scalar);
-template ND4J_EXPORT NDArray& NDArray::operator=(const int8_t scalar);
-template ND4J_EXPORT NDArray& NDArray::operator=(const uint8_t scalar);
-template ND4J_EXPORT NDArray& NDArray::operator=(const uint16_t scalar);
-template ND4J_EXPORT NDArray& NDArray::operator=(const uint32_t scalar);
-template ND4J_EXPORT NDArray& NDArray::operator=(const uint64_t scalar);
-template ND4J_EXPORT NDArray& NDArray::operator=(const int16_t scalar);
-template ND4J_EXPORT NDArray& NDArray::operator=(const bool scalar);
+    template <typename T>
+    NDArray& NDArray::operator=(const T scalar) {
+        this->assign(scalar);
+        return *this;
+    }
+    template ND4J_EXPORT NDArray& NDArray::operator=(const double scalar);
+    template ND4J_EXPORT NDArray& NDArray::operator=(const float scalar);
+    template ND4J_EXPORT NDArray&
NDArray::operator=(const float16 scalar); + template ND4J_EXPORT NDArray& NDArray::operator=(const bfloat16 scalar); + template ND4J_EXPORT NDArray& NDArray::operator=(const Nd4jLong scalar); + template ND4J_EXPORT NDArray& NDArray::operator=(const int scalar); + template ND4J_EXPORT NDArray& NDArray::operator=(const int8_t scalar); + template ND4J_EXPORT NDArray& NDArray::operator=(const uint8_t scalar); + template ND4J_EXPORT NDArray& NDArray::operator=(const uint16_t scalar); + template ND4J_EXPORT NDArray& NDArray::operator=(const uint32_t scalar); + template ND4J_EXPORT NDArray& NDArray::operator=(const uint64_t scalar); + template ND4J_EXPORT NDArray& NDArray::operator=(const int16_t scalar); + template ND4J_EXPORT NDArray& NDArray::operator=(const bool scalar); ////////////////////////////////////////////////////////////////////////// -void NDArray::copyBuffersContinuouslyFrom(const NDArray& other, size_t sizeToCopyInBytes, Nd4jLong offsetThis, Nd4jLong offsetOther) { + void NDArray::copyBuffersContinuouslyFrom(const NDArray& other, size_t sizeToCopyInBytes, Nd4jLong offsetThis, Nd4jLong offsetOther) { - if(offsetThis == 0) - offsetThis = bufferOffset(); - if(offsetOther == 0) - offsetOther = other.bufferOffset(); + if(offsetThis == 0) + offsetThis = bufferOffset(); + if(offsetOther == 0) + offsetOther = other.bufferOffset(); - dataBuffer()->copyBufferFrom(*other.getDataBuffer(), sizeToCopyInBytes, offsetThis, offsetOther); -} + dataBuffer()->copyBufferFrom(*other.getDataBuffer(), sizeToCopyInBytes, offsetThis, offsetOther); + } //////////////////////////////////////////////////////////////////// // This method assigns values of given NDArray to this one -void NDArray::assign(const NDArray& other, bool allowParallelism) { + void NDArray::assign(const NDArray& other, bool allowParallelism) { - if (this == &other) - return; + if (this == &other) + return; - if (other.isEmpty()) { - if (!isEmpty()) { - throw std::runtime_error("Cannot assign empty array to non-empty array"); + if (other.isEmpty()) { + if (!isEmpty()) { + throw std::runtime_error("Cannot assign empty array to non-empty array"); + } + return; } - return; - } - if(isEmpty()) { - *this = other; - return; - } - - if (other.lengthOf() == 1) { - - if(lengthOf() == 1) { - NDArray::preparePrimaryUse({this}, {&other}); - BUILD_DOUBLE_SELECTOR(dataType(), other.dataType(), templatedDoubleAssign, (buffer(), 0, other.buffer(), 0), LIBND4J_TYPES, LIBND4J_TYPES); - NDArray::registerPrimaryUse({this}, {&other}); - this->syncToDevice(); + if(isEmpty()) { + *this = other; + return; } - else { - if (dataType() != other.dataType()) { - auto tmp = other.cast(dataType()); - NDArray::prepareSpecialUse({this}, {&tmp}); - NativeOpExecutioner::execScalar(getContext(), scalar::CopyPws, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr, allowParallelism); - NDArray::registerSpecialUse({this}, {}); + + if (other.lengthOf() == 1) { + + if(lengthOf() == 1) { + NDArray::preparePrimaryUse({this}, {&other}); + BUILD_DOUBLE_SELECTOR(dataType(), other.dataType(), templatedDoubleAssign, (buffer(), 0, other.buffer(), 0), LIBND4J_TYPES, LIBND4J_TYPES); + NDArray::registerPrimaryUse({this}, {&other}); + this->syncToDevice(); } else { - NDArray::prepareSpecialUse({this}, {&other}); - NativeOpExecutioner::execScalar(getContext(), scalar::CopyPws, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), 
buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), nullptr, allowParallelism); - NDArray::registerSpecialUse({this}, {&other}); + if (dataType() != other.dataType()) { + auto tmp = other.cast(dataType()); + NDArray::prepareSpecialUse({this}, {&tmp}); + NativeOpExecutioner::execScalar(getContext(), scalar::CopyPws, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr, allowParallelism); + NDArray::registerSpecialUse({this}, {}); + } + else { + NDArray::prepareSpecialUse({this}, {&other}); + NativeOpExecutioner::execScalar(getContext(), scalar::CopyPws, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), nullptr, allowParallelism); + NDArray::registerSpecialUse({this}, {&other}); + } } } - } - else { - if (other.lengthOf() != lengthOf()) { - auto shapeThis = ShapeUtils::shapeAsString(this); - auto shapeThat = ShapeUtils::shapeAsString(&other); - nd4j_printf("Can't assign array: this shape %s; other shape: %s\n", shapeThis.c_str(), shapeThat.c_str()); - throw std::runtime_error("NDArray::assign: lengths of arrays are mismatched"); - } + else { + if (other.lengthOf() != lengthOf()) { + auto shapeThis = ShapeUtils::shapeAsString(this); + auto shapeThat = ShapeUtils::shapeAsString(&other); + nd4j_printf("Can't assign array: this shape %s; other shape: %s\n", shapeThis.c_str(), shapeThat.c_str()); + throw std::runtime_error("NDArray::assign: lengths of arrays are mismatched"); + } - NDArray::prepareSpecialUse({this}, {&other}); - NativeOpExecutioner::execTransformAny(getContext(), transform::Assign, other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, nullptr, nullptr, allowParallelism); - NDArray::registerSpecialUse({this}, {&other}); + NDArray::prepareSpecialUse({this}, {&other}); + NativeOpExecutioner::execTransformAny(getContext(), transform::Assign, other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, nullptr, nullptr, allowParallelism); + NDArray::registerSpecialUse({this}, {&other}); + } } -} ////////////////////////////////////////////////////////////////////////// // This method assigns values of given NDArray to this one, wrt order -void NDArray::assign(const NDArray *other, bool allowParallelism) { - assign(*other, allowParallelism); -} + void NDArray::assign(const NDArray *other, bool allowParallelism) { + assign(*other, allowParallelism); + } ////////////////////////////////////////////////////////////////////////// -template -void NDArray::assign(const T& value, bool allowParallelism) { - // just fire scalar - auto temp = NDArrayFactory::create(dataType(), value, this->getContext()); + template + void NDArray::assign(const T& value, bool allowParallelism) { + // just fire scalar + auto temp = NDArrayFactory::create(dataType(), value, this->getContext()); - NDArray::prepareSpecialUse(std::vector{this}, std::vector{&temp}); - NativeOpExecutioner::execScalar(getContext(), sd::scalar::CopyPws, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), 
specialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.specialShapeInfo(), nullptr, allowParallelism); - NDArray::registerSpecialUse(std::vector{this}, std::vector{&temp}); -} -template ND4J_EXPORT void NDArray::assign(const double& value, bool allowParallelism); -template ND4J_EXPORT void NDArray::assign(const float& value, bool allowParallelism); -template ND4J_EXPORT void NDArray::assign(const float16& value, bool allowParallelism); -template ND4J_EXPORT void NDArray::assign(const bfloat16& value, bool allowParallelism); -template ND4J_EXPORT void NDArray::assign(const Nd4jLong& value, bool allowParallelism); -template ND4J_EXPORT void NDArray::assign(const int& value, bool allowParallelism); -template ND4J_EXPORT void NDArray::assign(const int8_t& value, bool allowParallelism); -template ND4J_EXPORT void NDArray::assign(const int16_t& value, bool allowParallelism); -template ND4J_EXPORT void NDArray::assign(const uint8_t& value, bool allowParallelism); -template ND4J_EXPORT void NDArray::assign(const uint16_t& value, bool allowParallelism); -template ND4J_EXPORT void NDArray::assign(const uint32_t& value, bool allowParallelism); -template ND4J_EXPORT void NDArray::assign(const uint64_t& value, bool allowParallelism); -template ND4J_EXPORT void NDArray::assign(const bool& value, bool allowParallelism); + NDArray::prepareSpecialUse(std::vector{this}, std::vector{&temp}); + NativeOpExecutioner::execScalar(getContext(), sd::scalar::CopyPws, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), temp.buffer(), temp.shapeInfo(), temp.specialBuffer(), temp.specialShapeInfo(), nullptr, allowParallelism); + NDArray::registerSpecialUse(std::vector{this}, std::vector{&temp}); + } + template ND4J_EXPORT void NDArray::assign(const double& value, bool allowParallelism); + template ND4J_EXPORT void NDArray::assign(const float& value, bool allowParallelism); + template ND4J_EXPORT void NDArray::assign(const float16& value, bool allowParallelism); + template ND4J_EXPORT void NDArray::assign(const bfloat16& value, bool allowParallelism); + template ND4J_EXPORT void NDArray::assign(const Nd4jLong& value, bool allowParallelism); + template ND4J_EXPORT void NDArray::assign(const int& value, bool allowParallelism); + template ND4J_EXPORT void NDArray::assign(const int8_t& value, bool allowParallelism); + template ND4J_EXPORT void NDArray::assign(const int16_t& value, bool allowParallelism); + template ND4J_EXPORT void NDArray::assign(const uint8_t& value, bool allowParallelism); + template ND4J_EXPORT void NDArray::assign(const uint16_t& value, bool allowParallelism); + template ND4J_EXPORT void NDArray::assign(const uint32_t& value, bool allowParallelism); + template ND4J_EXPORT void NDArray::assign(const uint64_t& value, bool allowParallelism); + template ND4J_EXPORT void NDArray::assign(const bool& value, bool allowParallelism); ////////////////////////////////////////////////////////////////////////// -NDArray* NDArray::detach() { + NDArray* NDArray::detach() { - if (!isAttached()) - return this; + if (!isAttached()) + return this; - std::shared_ptr newBuffer = std::make_shared(lengthOf() * sizeOfT(), dataType()); + std::shared_ptr newBuffer = std::make_shared(lengthOf() * sizeOfT(), dataType()); - auto result = new NDArray(newBuffer, ShapeDescriptor(dataType(), ordering(), shapeOf(), rankOf())); + auto result = new NDArray(newBuffer, ShapeDescriptor(dataType(), ordering(), shapeOf(), rankOf())); - 
result->assign(*this); + result->assign(*this); - return result; -} + return result; + } ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::varianceNumber(sd::variance::Ops op, bool biasCorrected) { + NDArray NDArray::varianceNumber(sd::variance::Ops op, bool biasCorrected) { - NDArray res(DataTypeUtils::pickFloatingType(dataType()), getContext()); + NDArray res(DataTypeUtils::pickFloatingType(dataType()), getContext()); - NDArray::prepareSpecialUse({&res}, {this}); - NativeOpExecutioner::execSummaryStatsScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, res.buffer(), res.shapeInfo(), res.specialBuffer(), res.specialShapeInfo(), biasCorrected); - NDArray::registerSpecialUse({&res}, {this}); + NDArray::prepareSpecialUse({&res}, {this}); + NativeOpExecutioner::execSummaryStatsScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, res.buffer(), res.shapeInfo(), res.specialBuffer(), res.specialShapeInfo(), biasCorrected); + NDArray::registerSpecialUse({&res}, {this}); - return res; -} + return res; + } ////////////////////////////////////////////////////////////////////////// // This method returns sum of all elements of this NDArray -NDArray NDArray::sumNumber() const { - if (isS()) - throw std::runtime_error("NDArray::sumNumber: you can't use this method on String array!"); - NDArray res(dataType(), getContext()); + NDArray NDArray::sumNumber() const { + if (isS()) + throw std::runtime_error("NDArray::sumNumber: you can't use this method on String array!"); + NDArray res(dataType(), getContext()); - NDArray::prepareSpecialUse({&res}, {this}); - NativeOpExecutioner::execReduceSameScalar(getContext(), sd::reduce::SameOps::Sum, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, res.buffer(), res.shapeInfo(), res.specialBuffer(), res.specialShapeInfo()); - NDArray::registerSpecialUse({&res}, {this}); + NDArray::prepareSpecialUse({&res}, {this}); + NativeOpExecutioner::execReduceSameScalar(getContext(), sd::reduce::SameOps::Sum, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, res.buffer(), res.shapeInfo(), res.specialBuffer(), res.specialShapeInfo()); + NDArray::registerSpecialUse({&res}, {this}); - return res; -} + return res; + } ////////////////////////////////////////////////////////////////////////// // This method returns mean number of this NDArray -NDArray NDArray::meanNumber() const { + NDArray NDArray::meanNumber() const { - if (isS()) - throw std::runtime_error("NDArray::meanNumber: you can't use this method on String array!"); - NDArray res(DataTypeUtils::pickFloatingType(dataType()), getContext()); + if (isS()) + throw std::runtime_error("NDArray::meanNumber: you can't use this method on String array!"); + NDArray res(DataTypeUtils::pickFloatingType(dataType()), getContext()); - NDArray::prepareSpecialUse({&res}, {this}); - NativeOpExecutioner::execReduceFloatScalar(getContext(), sd::reduce::FloatOps::Mean, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, res.buffer(), res.shapeInfo(), res.specialBuffer(), res.specialShapeInfo()); - NDArray::registerSpecialUse({&res}, {this}); - return res; -} + NDArray::prepareSpecialUse({&res}, {this}); + NativeOpExecutioner::execReduceFloatScalar(getContext(), sd::reduce::FloatOps::Mean, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, res.buffer(), res.shapeInfo(), res.specialBuffer(), res.specialShapeInfo()); + NDArray::registerSpecialUse({&res}, 
{this}); + return res; + } ////////////////////////////////////////////////////////////////////////// -bool NDArray::hasNaNs() { - if (isS()) - throw std::runtime_error("NDArray::hasNaNs: you can't use this method on String array!"); - return this->reduceNumber(sd::reduce::IsNan, nullptr).e(0) > 0; -} + bool NDArray::hasNaNs() { + if (isS()) + throw std::runtime_error("NDArray::hasNaNs: you can't use this method on String array!"); + return this->reduceNumber(sd::reduce::IsNan, nullptr).e(0) > 0; + } ////////////////////////////////////////////////////////////////////////// -bool NDArray::hasInfs() { - if (isS()) - throw std::runtime_error("NDArray::hasInfs: you can't use this method on String array!"); - return this->reduceNumber(sd::reduce::IsInf, nullptr).e(0) > 0; -} + bool NDArray::hasInfs() { + if (isS()) + throw std::runtime_error("NDArray::hasInfs: you can't use this method on String array!"); + return this->reduceNumber(sd::reduce::IsInf, nullptr).e(0) > 0; + } ////////////////////////////////////////////////////////////////////////// -bool NDArray::isFinite() { - if (isS()) - throw std::runtime_error("NDArray::isFinite: you can't use this method on String array!"); - return this->reduceNumber(sd::reduce::IsInfOrNan, nullptr).e(0) == 0; -} + bool NDArray::isFinite() { + if (isS()) + throw std::runtime_error("NDArray::isFinite: you can't use this method on String array!"); + return this->reduceNumber(sd::reduce::IsInfOrNan, nullptr).e(0) == 0; + } ////////////////////////////////////////////////////////////////////////// -template -void NDArray::templatedSet(void *buffer, const Nd4jLong *indices, const void *value) { - auto t = reinterpret_cast(buffer); - const auto y = *(reinterpret_cast(value)); + template + void NDArray::templatedSet(void *buffer, const Nd4jLong *indices, const void *value) { + auto t = reinterpret_cast(buffer); + const auto y = *(reinterpret_cast(value)); - auto xOffset = shape::getOffset(shapeInfo(), indices); - t[xOffset] = static_cast(y); -} -BUILD_DOUBLE_TEMPLATE(template ND4J_EXPORT void NDArray::templatedSet, (void *buffer, const Nd4jLong *indices, const void *value), LIBND4J_TYPES, LIBND4J_TYPES); + auto xOffset = shape::getOffset(shapeInfo(), indices); + t[xOffset] = static_cast(y); + } + BUILD_DOUBLE_TEMPLATE(template ND4J_EXPORT void NDArray::templatedSet, (void *buffer, const Nd4jLong *indices, const void *value), LIBND4J_TYPES, LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// -template -void NDArray::templatedSet(void *buffer, const Nd4jLong offset, const void *value) { - auto t = reinterpret_cast(buffer); - const auto y = *(reinterpret_cast(value)); + template + void NDArray::templatedSet(void *buffer, const Nd4jLong offset, const void *value) { + auto t = reinterpret_cast(buffer); + const auto y = *(reinterpret_cast(value)); - t[offset] = static_cast(y); -} -BUILD_DOUBLE_TEMPLATE(template ND4J_EXPORT void NDArray::templatedSet, (void *buffer, const Nd4jLong offset, const void *value), LIBND4J_TYPES, LIBND4J_TYPES); + t[offset] = static_cast(y); + } + BUILD_DOUBLE_TEMPLATE(template ND4J_EXPORT void NDArray::templatedSet, (void *buffer, const Nd4jLong offset, const void *value), LIBND4J_TYPES, LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// -void NDArray::setContext(sd::LaunchContext *context) { + void NDArray::setContext(sd::LaunchContext *context) { - _context = context; - if (getContext() == nullptr) - _context = sd::LaunchContext ::defaultContext(); // empty 
context for default cases -} + _context = context; + if (getContext() == nullptr) + _context = sd::LaunchContext ::defaultContext(); // empty context for default cases + } ////////////////////////////////////////////////////////////////////////// -void const* NDArray::bufferWithOffset(Nd4jLong offset) const { - return const_cast(buffer() != nullptr ? static_cast(buffer()) + (offset * sizeOfT()) : nullptr); -} + void const* NDArray::bufferWithOffset(Nd4jLong offset) const { + return const_cast(buffer() != nullptr ? static_cast(buffer()) + (offset * sizeOfT()) : nullptr); + } ////////////////////////////////////////////////////////////////////////// -void* NDArray::bufferWithOffset(Nd4jLong offset) { - return const_cast(buffer() != nullptr ? static_cast(buffer()) + (offset * sizeOfT()) : nullptr); -} + void* NDArray::bufferWithOffset(Nd4jLong offset) { + return const_cast(buffer() != nullptr ? static_cast(buffer()) + (offset * sizeOfT()) : nullptr); + } ////////////////////////////////////////////////////////////////////////// // eventually method reduces array by excluding its shapes along axes present in dimensions vector -NDArray NDArray::reduceAlongDimension(sd::reduce::FloatOps op, const std::vector& dimensions, const bool keepDims) const { + NDArray NDArray::reduceAlongDimension(sd::reduce::FloatOps op, const std::vector& dimensions, const bool keepDims) const { - std::vector copy(dimensions); + std::vector copy(dimensions); - auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, isR() ? dataType() : Environment::getInstance().defaultFloatDataType(), keepDims, false, getContext()->getWorkspace()); + auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, isR() ? dataType() : Environment::getInstance().defaultFloatDataType(), keepDims, false, getContext()->getWorkspace()); - NDArray result(newShape, true, getContext()); + NDArray result(newShape, true, getContext()); - this->reduceAlongDimension(op, result, copy, keepDims, false); + this->reduceAlongDimension(op, result, copy, keepDims, false); - return result; -} + return result; + } ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::reduceAlongDimension(sd::reduce::SameOps op, const std::vector& dimensions, const bool keepDims) const { + NDArray NDArray::reduceAlongDimension(sd::reduce::SameOps op, const std::vector& dimensions, const bool keepDims) const { - std::vector copy(dimensions); + std::vector copy(dimensions); - auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, keepDims, false, getContext()->getWorkspace()); + auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, keepDims, false, getContext()->getWorkspace()); - NDArray result(newShape, true, getContext()); + NDArray result(newShape, true, getContext()); - reduceAlongDimension(op, result, copy, keepDims, false); + reduceAlongDimension(op, result, copy, keepDims, false); - return result; -} + return result; + } ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::reduceAlongDimension(sd::reduce::BoolOps op, const std::vector& dimensions, const bool keepDims) const { + NDArray NDArray::reduceAlongDimension(sd::reduce::BoolOps op, const std::vector& dimensions, const bool keepDims) const { - std::vector copy(dimensions); + std::vector copy(dimensions); - auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, DataType::BOOL, keepDims, false, getContext()->getWorkspace()); + auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, 
DataType::BOOL, keepDims, false, getContext()->getWorkspace()); - NDArray result(newShape, true, getContext()); + NDArray result(newShape, true, getContext()); - reduceAlongDimension(op, result, copy, keepDims, false); + reduceAlongDimension(op, result, copy, keepDims, false); - return result; -} + return result; + } ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::reduceAlongDimension(sd::reduce::LongOps op, const std::vector& dimensions, const bool keepDims) const { + NDArray NDArray::reduceAlongDimension(sd::reduce::LongOps op, const std::vector& dimensions, const bool keepDims) const { - std::vector copy(dimensions); + std::vector copy(dimensions); - auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, DataType::INT64, keepDims, false, getContext()->getWorkspace()); + auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, DataType::INT64, keepDims, false, getContext()->getWorkspace()); - NDArray result(newShape, true, getContext()); + NDArray result(newShape, true, getContext()); - reduceAlongDimension(op, result, copy, keepDims, false); + reduceAlongDimension(op, result, copy, keepDims, false); - return result; -} + return result; + } ////////////////////////////////////////////////////////////////////////// // method reduces array by excluding its shapes along axes present in dimensions vector -NDArray NDArray::reduceAlongDimension(sd::reduce::FloatOps op, const std::initializer_list& dimensions, const bool keepDims) const { - return reduceAlongDimension(op, std::vector(dimensions), keepDims); -} - -////////////////////////////////////////////////////////////////////////// -NDArray NDArray::reduceAlongDimension(sd::reduce::SameOps op, const std::initializer_list& dimensions, const bool keepDims) const { - return reduceAlongDimension(op, std::vector(dimensions), keepDims); -} - -////////////////////////////////////////////////////////////////////////// -NDArray NDArray::reduceAlongDimension(sd::reduce::BoolOps op, const std::initializer_list& dimensions, const bool keepDims) const { - return reduceAlongDimension(op, std::vector(dimensions), keepDims); -} - -////////////////////////////////////////////////////////////////////////// -NDArray NDArray::reduceAlongDimension(sd::reduce::LongOps op, const std::initializer_list& dimensions, const bool keepDims) const { - return reduceAlongDimension(op, std::vector(dimensions), keepDims); -} - -////////////////////////////////////////////////////////////////////////// -NDArray NDArray::reduceNumber(sd::reduce::FloatOps op, void *extraParams) const { - if (isS()) - throw std::runtime_error("NDArray::reduceNumber FloatOps: you can't use this method on String array!"); - - auto shape = ConstantShapeHelper::getInstance().scalarShapeInfo(DataTypeUtils::pickFloatingType(dataType())); - NDArray result(shape, true, this->getContext()); - - NDArray::prepareSpecialUse({&result}, {this}); - NativeOpExecutioner::execReduceFloatScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo()); - NDArray::registerSpecialUse({&result}, {this}); - - return result; -} - -////////////////////////////////////////////////////////////////////////// -NDArray NDArray::reduceNumber(sd::reduce::SameOps op, void *extraParams) const { - if (isS()) - throw std::runtime_error("NDArray::reduceNumber SameOps: you can't use this method on String array!"); - - NDArray result(dataType(), 
getContext()); - - NDArray::prepareSpecialUse({&result}, {this}); - NativeOpExecutioner::execReduceSameScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo()); - NDArray::registerSpecialUse({&result}, {this}); - - return result; -} - -////////////////////////////////////////////////////////////////////////// -NDArray NDArray::reduceNumber(sd::reduce::BoolOps op, void *extraParams) const { - if (isS()) - throw std::runtime_error("NDArray::reduceNumber BoolOps: you can't use this method on String array!"); - - auto shape = ConstantShapeHelper::getInstance().scalarShapeInfo(DataType::BOOL); - NDArray result(shape, true, this->getContext()); - - NDArray::prepareSpecialUse({&result}, {this}); - NativeOpExecutioner::execReduceBoolScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo()); - NDArray::registerSpecialUse({&result}, {this}); - - return result; -} - -////////////////////////////////////////////////////////////////////////// -NDArray NDArray::reduceNumber(sd::reduce::LongOps op, void *extraParams) const { - if (isS()) - throw std::runtime_error("NDArray::reduceNumber LongOps: you can't use this method on String array!"); - - auto shape = ConstantShapeHelper::getInstance().scalarShapeInfo(DataType::INT64); - NDArray result(shape, true, this->getContext()); - - NDArray::prepareSpecialUse({&result}, {this}); - NativeOpExecutioner::execReduceLongScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo()); - NDArray::registerSpecialUse({&result}, {this}); - - return result; -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::reduceNumber(sd::reduce::FloatOps op, NDArray& target, void *extraParams) const { - if (isS()) - throw std::runtime_error("NDArray::reduceNumber FloatOps: you can't use this method on String array!"); - if(target.lengthOf() != 1 || target.dataType() != DataTypeUtils::pickFloatingType(dataType())) - throw std::invalid_argument("NDArray::reduceNumber FloatOps: target array should be scalar and have corresponding float type!"); - - NDArray::prepareSpecialUse({&target}, {this}); - NativeOpExecutioner::execReduceFloatScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); - NDArray::registerSpecialUse({&target}, {this}); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::reduceNumber(sd::reduce::SameOps op, NDArray& target, void *extraParams) const { - - if (isS()) - throw std::runtime_error("NDArray::reduceNumber SameOps: you can't use this method on String array!"); - if(target.lengthOf() != 1 || target.dataType() != dataType()) - throw std::invalid_argument("NDArray::reduceNumber SameOps: target array should be scalar and have same type as this array!"); - - NDArray::prepareSpecialUse({&target}, {this}); - NativeOpExecutioner::execReduceSameScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); - NDArray::registerSpecialUse({&target}, {this}); -} - 
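// [editor's note — not part of the patch] A short sketch of the reduceNumber
// contracts handled in this hunk, assuming the usual libnd4j headers; the
// sketch function and local names are hypothetical. The returning overloads
// allocate a scalar whose type follows the op family (FloatOps picks a
// floating type via DataTypeUtils::pickFloatingType, BoolOps yields BOOL,
// LongOps yields INT64), while the target overloads throw unless the
// destination is already a scalar of the matching type.
#include <array/NDArray.h>
#include <array/NDArrayFactory.h>

static void reduceNumberSketch() {
    auto arr = NDArrayFactory::create<float>('c', {2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});

    NDArray sum  = arr.reduceNumber(sd::reduce::SameOps::Sum, nullptr);    // float scalar: 21
    NDArray mean = arr.reduceNumber(sd::reduce::FloatOps::Mean, nullptr);  // float scalar: 3.5

    auto target = NDArrayFactory::create<float>(0.f);               // scalar of the expected type
    arr.reduceNumber(sd::reduce::FloatOps::Mean, target, nullptr);  // fills target in place
}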
-////////////////////////////////////////////////////////////////////////// -void NDArray::reduceNumber(sd::reduce::BoolOps op, NDArray& target, void *extraParams) const { - - if (isS()) - throw std::runtime_error("NDArray::reduceNumber BoolOps: you can't use this method on String array!"); - if(target.lengthOf() != 1 || target.dataType() != DataType::BOOL) - throw std::invalid_argument("NDArray::reduceNumber BoolOps: target array should be scalar and have bool type!"); - - NDArray::prepareSpecialUse({&target}, {this}); - NativeOpExecutioner::execReduceBoolScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); - NDArray::registerSpecialUse({&target}, {this}); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::reduceNumber(sd::reduce::LongOps op, NDArray& target, void *extraParams) const { - - if (isS()) - throw std::runtime_error("NDArray::reduceNumber LongOps: you can't use this method on String array!"); - if(target.lengthOf() != 1 || target.dataType() != DataType::INT64) - throw std::invalid_argument("NDArray::reduceNumber LongOps: target array should be scalar and have long type!"); - - NDArray::prepareSpecialUse({&target}, {this}); - NativeOpExecutioner::execReduceLongScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); - NDArray::registerSpecialUse({&target}, {this}); -} - -////////////////////////////////////////////////////////////////////////// -NDArray NDArray::indexReduceNumber(sd::indexreduce::Ops op, ExtraArguments *extraParams) { - if (isS()) - throw std::runtime_error("NDArray::indexReduceNumber: you can't use this method on String array!"); - - auto res = NDArrayFactory::create(0); - - NDArray::NDArray::prepareSpecialUse({&res}, {this}); - NativeOpExecutioner::execIndexReduceScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams == nullptr ? 
nullptr : extraParams->argumentsAsT(this->dataType()), res.buffer(), res.shapeInfo(), res.specialBuffer(), res.specialShapeInfo()); - NDArray::NDArray::registerSpecialUse({&res}, {this}); - - return res; -} - -////////////////////////////////////////////////////////////////////////// -Nd4jLong NDArray::tensorsAlongDimension(std::initializer_list dimensions) const { - return tensorsAlongDimension(std::vector(dimensions)); -} - -////////////////////////////////////////////////////////////////////////// -Nd4jLong NDArray::tensorsAlongDimension(const std::vector& dimensions) const { - std::vector copy(dimensions); - shape::checkDimensions(rankOf(), copy); - - Nd4jLong tadLength = shape::tadLength(this->_shapeInfo, copy.data(), copy.size()); - Nd4jLong numTads = this->lengthOf() / tadLength; - - return numTads; -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::printShapeInfo(const char * msg) const { - - int rank = shape::rank(_shapeInfo); - int lim = shape::shapeInfoLength(rank); - - if(msg != nullptr) - printf("shapeInfo %s: [", msg); - else - printf("shapeInfo: ["); - - printf("%i, ", rank); - for (int i = 1; i < shape::shapeInfoLength(rank) - 3; i++){ - if(i == rank + 1) - printf(" "); - printf("%lld,", _shapeInfo[i]); + NDArray NDArray::reduceAlongDimension(sd::reduce::FloatOps op, const std::initializer_list& dimensions, const bool keepDims) const { + return reduceAlongDimension(op, std::vector(dimensions), keepDims); } - printf(" %lld,", shape::type(_shapeInfo)); - printf("%lld,", shape::elementWiseStride(_shapeInfo)); - printf("%lld]\n", (Nd4jLong)shape::order(_shapeInfo)); - - fflush(stdout); -} ////////////////////////////////////////////////////////////////////////// -void NDArray::printBuffer(const char* msg, Nd4jLong limit, const bool sync) const{ - if (sync) - syncToHost(); + NDArray NDArray::reduceAlongDimension(sd::reduce::SameOps op, const std::initializer_list& dimensions, const bool keepDims) const { + return reduceAlongDimension(op, std::vector(dimensions), keepDims); + } - if (limit == -1) - limit = (int) this->lengthOf(); +////////////////////////////////////////////////////////////////////////// + NDArray NDArray::reduceAlongDimension(sd::reduce::BoolOps op, const std::initializer_list& dimensions, const bool keepDims) const { + return reduceAlongDimension(op, std::vector(dimensions), keepDims); + } - if (msg != nullptr) - printf("%s: [", msg); - else - printf("["); - if (this->isR()) { - for (Nd4jLong e = 0; e < limit; e++) { - if (e) - printf(", "); - printf("%f", this->e(e)); +////////////////////////////////////////////////////////////////////////// + NDArray NDArray::reduceAlongDimension(sd::reduce::LongOps op, const std::initializer_list& dimensions, const bool keepDims) const { + return reduceAlongDimension(op, std::vector(dimensions), keepDims); + } + +////////////////////////////////////////////////////////////////////////// + NDArray NDArray::reduceNumber(sd::reduce::FloatOps op, void *extraParams) const { + if (isS()) + throw std::runtime_error("NDArray::reduceNumber FloatOps: you can't use this method on String array!"); + + auto shape = ConstantShapeHelper::getInstance().scalarShapeInfo(DataTypeUtils::pickFloatingType(dataType())); + NDArray result(shape, true, this->getContext()); + + NDArray::prepareSpecialUse({&result}, {this}); + NativeOpExecutioner::execReduceFloatScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), 
result.specialBuffer(), result.specialShapeInfo()); + NDArray::registerSpecialUse({&result}, {this}); + + return result; + } + +////////////////////////////////////////////////////////////////////////// + NDArray NDArray::reduceNumber(sd::reduce::SameOps op, void *extraParams) const { + if (isS()) + throw std::runtime_error("NDArray::reduceNumber SameOps: you can't use this method on String array!"); + + NDArray result(dataType(), getContext()); + + NDArray::prepareSpecialUse({&result}, {this}); + NativeOpExecutioner::execReduceSameScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo()); + NDArray::registerSpecialUse({&result}, {this}); + + return result; + } + +////////////////////////////////////////////////////////////////////////// + NDArray NDArray::reduceNumber(sd::reduce::BoolOps op, void *extraParams) const { + if (isS()) + throw std::runtime_error("NDArray::reduceNumber BoolOps: you can't use this method on String array!"); + + auto shape = ConstantShapeHelper::getInstance().scalarShapeInfo(DataType::BOOL); + NDArray result(shape, true, this->getContext()); + + NDArray::prepareSpecialUse({&result}, {this}); + NativeOpExecutioner::execReduceBoolScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo()); + NDArray::registerSpecialUse({&result}, {this}); + + return result; + } + +////////////////////////////////////////////////////////////////////////// + NDArray NDArray::reduceNumber(sd::reduce::LongOps op, void *extraParams) const { + if (isS()) + throw std::runtime_error("NDArray::reduceNumber LongOps: you can't use this method on String array!"); + + auto shape = ConstantShapeHelper::getInstance().scalarShapeInfo(DataType::INT64); + NDArray result(shape, true, this->getContext()); + + NDArray::prepareSpecialUse({&result}, {this}); + NativeOpExecutioner::execReduceLongScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo()); + NDArray::registerSpecialUse({&result}, {this}); + + return result; + } + +////////////////////////////////////////////////////////////////////////// + void NDArray::reduceNumber(sd::reduce::FloatOps op, NDArray& target, void *extraParams) const { + if (isS()) + throw std::runtime_error("NDArray::reduceNumber FloatOps: you can't use this method on String array!"); + if(target.lengthOf() != 1 || target.dataType() != DataTypeUtils::pickFloatingType(dataType())) + throw std::invalid_argument("NDArray::reduceNumber FloatOps: target array should be scalar and have corresponding float type!"); + + NDArray::prepareSpecialUse({&target}, {this}); + NativeOpExecutioner::execReduceFloatScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); + NDArray::registerSpecialUse({&target}, {this}); + } + +////////////////////////////////////////////////////////////////////////// + void NDArray::reduceNumber(sd::reduce::SameOps op, NDArray& target, void *extraParams) const { + + if (isS()) + throw std::runtime_error("NDArray::reduceNumber SameOps: you can't use this method on String array!"); + if(target.lengthOf() != 1 || target.dataType() != dataType()) + throw 
std::invalid_argument("NDArray::reduceNumber SameOps: target array should be scalar and have same type as this array!"); + + NDArray::prepareSpecialUse({&target}, {this}); + NativeOpExecutioner::execReduceSameScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); + NDArray::registerSpecialUse({&target}, {this}); + } + +////////////////////////////////////////////////////////////////////////// + void NDArray::reduceNumber(sd::reduce::BoolOps op, NDArray& target, void *extraParams) const { + + if (isS()) + throw std::runtime_error("NDArray::reduceNumber BoolOps: you can't use this method on String array!"); + if(target.lengthOf() != 1 || target.dataType() != DataType::BOOL) + throw std::invalid_argument("NDArray::reduceNumber BoolOps: target array should be scalar and have bool type!"); + + NDArray::prepareSpecialUse({&target}, {this}); + NativeOpExecutioner::execReduceBoolScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); + NDArray::registerSpecialUse({&target}, {this}); + } + +////////////////////////////////////////////////////////////////////////// + void NDArray::reduceNumber(sd::reduce::LongOps op, NDArray& target, void *extraParams) const { + + if (isS()) + throw std::runtime_error("NDArray::reduceNumber LongOps: you can't use this method on String array!"); + if(target.lengthOf() != 1 || target.dataType() != DataType::INT64) + throw std::invalid_argument("NDArray::reduceNumber LongOps: target array should be scalar and have long type!"); + + NDArray::prepareSpecialUse({&target}, {this}); + NativeOpExecutioner::execReduceLongScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); + NDArray::registerSpecialUse({&target}, {this}); + } + +////////////////////////////////////////////////////////////////////////// + NDArray NDArray::indexReduceNumber(sd::indexreduce::Ops op, ExtraArguments *extraParams) { + if (isS()) + throw std::runtime_error("NDArray::indexReduceNumber: you can't use this method on String array!"); + + auto res = NDArrayFactory::create(0); + + NDArray::NDArray::prepareSpecialUse({&res}, {this}); + NativeOpExecutioner::execIndexReduceScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams == nullptr ? 
nullptr : extraParams->argumentsAsT(this->dataType()), res.buffer(), res.shapeInfo(), res.specialBuffer(), res.specialShapeInfo()); + NDArray::NDArray::registerSpecialUse({&res}, {this}); + + return res; + } + +////////////////////////////////////////////////////////////////////////// + Nd4jLong NDArray::tensorsAlongDimension(std::initializer_list dimensions) const { + return tensorsAlongDimension(std::vector(dimensions)); + } + +////////////////////////////////////////////////////////////////////////// + Nd4jLong NDArray::tensorsAlongDimension(const std::vector& dimensions) const { + std::vector copy(dimensions); + shape::checkDimensions(rankOf(), copy); + + Nd4jLong tadLength = shape::tadLength(this->_shapeInfo, copy.data(), copy.size()); + Nd4jLong numTads = this->lengthOf() / tadLength; + + return numTads; + } + +////////////////////////////////////////////////////////////////////////// + void NDArray::printShapeInfo(const char * msg) const { + + int rank = shape::rank(_shapeInfo); + int lim = shape::shapeInfoLength(rank); + + if(msg != nullptr) + printf("shapeInfo %s: [", msg); + else + printf("shapeInfo: ["); + + printf("%i, ", rank); + for (int i = 1; i < shape::shapeInfoLength(rank) - 3; i++){ + if(i == rank + 1) + printf(" "); + printf("%lld,", _shapeInfo[i]); } + printf(" %lld,", shape::type(_shapeInfo)); + printf("%lld,", shape::elementWiseStride(_shapeInfo)); + printf("%lld]\n", (Nd4jLong)shape::order(_shapeInfo)); + + fflush(stdout); } - else if (this->isZ()) { - for (Nd4jLong e = 0; e < limit; e++) { - if (this->dataType() != sd::DataType::INT64 && this->dataType() != sd::DataType::UINT64) - printf("%d", this->e(e)); - else - printf("%llu", this->e(e)); - if (e < limit - 1) - printf(", "); + +////////////////////////////////////////////////////////////////////////// + void NDArray::printBuffer(const char* msg, Nd4jLong limit, const bool sync) const{ + if (sync) + syncToHost(); + + if (limit == -1) + limit = (int) this->lengthOf(); + + if (msg != nullptr) + printf("%s: [", msg); + else + printf("["); + if (this->isR()) { + for (Nd4jLong e = 0; e < limit; e++) { + if (e) + printf(", "); + printf("%f", this->e(e)); + } } - } - else if (this->isB()) { - for (Nd4jLong e = 0; e < limit; e++) { - if (this->e(e)) - printf("true"); - else - printf("false"); - if (e < limit - 1) - printf(", "); + else if (this->isZ()) { + for (Nd4jLong e = 0; e < limit; e++) { + if (this->dataType() != sd::DataType::INT64 && this->dataType() != sd::DataType::UINT64) + printf("%d", this->e(e)); + else + printf("%llu", this->e(e)); + if (e < limit - 1) + printf(", "); + } } - } - else if (this->isS()) { - // todo do we need this print offsets - /* + else if (this->isB()) { + for (Nd4jLong e = 0; e < limit; e++) { + if (this->e(e)) + printf("true"); + else + printf("false"); + if (e < limit - 1) + printf(", "); + } + } + else if (this->isS()) { + // todo do we need this print offsets + /* for (Nd4jLong e = 0; e < limit; e++) { printf("\"%lld\"", this->getOffset(e)); if (e < limit - 1) @@ -1645,2893 +1645,2893 @@ void NDArray::printBuffer(const char* msg, Nd4jLong limit, const bool sync) cons } printf("]\n["); */ - for (Nd4jLong e = 0; e < limit; e++) { - printf("\"%s\"", this->e(e).c_str()); - if (e < limit - 1) - printf(", "); - } - } - printf("]\n"); - fflush(stdout); -} - -////////////////////////////////////////////////////////////////////////// -// print element by element consequently in a way they (elements) are stored in physical memory -void NDArray::printLinearBuffer() const { - - syncToHost(); 
- - const auto ews = this->ews() > 0 ? this->ews() : 1; - const auto len = this->lengthOf(); - - printf("["); - - if (this->dataType() == sd::DataType::INT32) { - for(Nd4jLong e = 0; e < len; e++) - printf("%d, ", this->bufferAsT()[e * ews]); - } - else if(this->dataType() == sd::DataType::INT64) { - for(Nd4jLong e = 0; e < len; e++) - printf("%lld, ", this->bufferAsT()[e * ews]); - } - else if(this->dataType() == sd::DataType::FLOAT32) { - for(Nd4jLong e = 0; e < len; e++) - printf("%.8f, ", this->bufferAsT()[e * ews]); - } - else if(this->dataType() == sd::DataType::DOUBLE) { - for(Nd4jLong e = 0; e < len; e++) - printf("%.8f, ", this->bufferAsT()[e * ews]); - } - else - throw std::invalid_argument("NDArray::printLinearBuffer: not implemented yet for this data type !"); - - printf("]\n"); - fflush(stdout); -} -////////////////////////////////////////////////////////////////////////// -static void printFormatted(NDArray const* arr, int depth, int limit) { - - if (arr->rankOf() == 1) { - printf("[ "); - for (Nd4jLong i = 0; i < arr->lengthOf(); ++i) { - if (arr->isR()) - printf("%f, ", arr->e(i)); - else if (arr->isZ()) - printf("%lld, ", arr->e(i)); - else if (arr->isB()) - printf("%s, ", arr->e(i)?"true":"false"); - else if (arr->isS()) { - printf("\"%s\", ", arr->e(i).c_str()); + for (Nd4jLong e = 0; e < limit; e++) { + printf("\"%s\"", this->e(e).c_str()); + if (e < limit - 1) + printf(", "); } } printf("]\n"); + fflush(stdout); } - else if (arr->rankOf() == 2) { - Nd4jLong rows = arr->rows(); - Nd4jLong cols = arr->columns(); - char* padding = new char[depth + 1]; - memset(padding, ' ', depth); - padding[depth] = 0; + +////////////////////////////////////////////////////////////////////////// +// print element by element consequently in a way they (elements) are stored in physical memory + void NDArray::printLinearBuffer() const { + + syncToHost(); + + const auto ews = this->ews() > 0 ? 
this->ews() : 1; + const auto len = this->lengthOf(); + printf("["); - for (Nd4jLong row = 0; row < rows; ++row) { - if (row && depth > 0) - printf("%s", padding); - printf("["); - Nd4jLong colLimit = cols > limit?cols:limit; - for (Nd4jLong col = 0; col < colLimit; ++col) { - if (col) - printf(", "); + + if (this->dataType() == sd::DataType::INT32) { + for(Nd4jLong e = 0; e < len; e++) + printf("%d, ", this->bufferAsT()[e * ews]); + } + else if(this->dataType() == sd::DataType::INT64) { + for(Nd4jLong e = 0; e < len; e++) + printf("%lld, ", this->bufferAsT()[e * ews]); + } + else if(this->dataType() == sd::DataType::FLOAT32) { + for(Nd4jLong e = 0; e < len; e++) + printf("%.8f, ", this->bufferAsT()[e * ews]); + } + else if(this->dataType() == sd::DataType::DOUBLE) { + for(Nd4jLong e = 0; e < len; e++) + printf("%.8f, ", this->bufferAsT()[e * ews]); + } + else + throw std::invalid_argument("NDArray::printLinearBuffer: not implemented yet for this data type !"); + + printf("]\n"); + fflush(stdout); + } +////////////////////////////////////////////////////////////////////////// + static void printFormatted(NDArray const* arr, int depth, int limit) { + + if (arr->rankOf() == 1) { + printf("[ "); + for (Nd4jLong i = 0; i < arr->lengthOf(); ++i) { if (arr->isR()) - printf("%f", arr->e(row, col)); + printf("%f, ", arr->e(i)); else if (arr->isZ()) - printf("%lld", arr->e(row, col)); + printf("%lld, ", arr->e(i)); else if (arr->isB()) - printf("%s", arr->e(row, col)?"true":"false"); + printf("%s, ", arr->e(i)?"true":"false"); else if (arr->isS()) { - printf("\"%s\"", arr->e(row * cols + col).c_str()); + printf("\"%s\", ", arr->e(i).c_str()); } } - if (row < rows - 1) - printf("]\n"); - else - printf("]"); + printf("]\n"); + } + else if (arr->rankOf() == 2) { + Nd4jLong rows = arr->rows(); + Nd4jLong cols = arr->columns(); + char* padding = new char[depth + 1]; + memset(padding, ' ', depth); + padding[depth] = 0; + printf("["); + for (Nd4jLong row = 0; row < rows; ++row) { + if (row && depth > 0) + printf("%s", padding); + printf("["); + Nd4jLong colLimit = cols > limit?cols:limit; + for (Nd4jLong col = 0; col < colLimit; ++col) { + if (col) + printf(", "); + if (arr->isR()) + printf("%f", arr->e(row, col)); + else if (arr->isZ()) + printf("%lld", arr->e(row, col)); + else if (arr->isB()) + printf("%s", arr->e(row, col)?"true":"false"); + else if (arr->isS()) { + printf("\"%s\"", arr->e(row * cols + col).c_str()); + } + } + if (row < rows - 1) + printf("]\n"); + else + printf("]"); + } + printf("]"); + if (padding) + delete [] padding; + } + else { + //std::unique_ptr arrs(arr->allTensorsAlongDimension({0})); + size_t restCount = 2; + printf("["); + restCount = ShapeUtils::getNumOfSubArrs(arr->shapeInfo(), {0}); + for (size_t arrIndex = 0; arrIndex < restCount; ++arrIndex) { + NDArray subArr = (*arr)(arrIndex, {0}); + printFormatted(&subArr, depth + 1, limit); + if (arrIndex < restCount - 1) { + for (Nd4jLong i = 1; i < arr->rankOf(); ++i) + printf("\n"); + for (Nd4jLong i = 0; i < depth - 2; ++i) + printf(" "); + } + } + printf("]"); } - printf("]"); - if (padding) - delete [] padding; } - else { - //std::unique_ptr arrs(arr->allTensorsAlongDimension({0})); - size_t restCount = 2; - printf("["); - restCount = ShapeUtils::getNumOfSubArrs(arr->shapeInfo(), {0}); - for (size_t arrIndex = 0; arrIndex < restCount; ++arrIndex) { - NDArray subArr = (*arr)(arrIndex, {0}); - printFormatted(&subArr, depth + 1, limit); - if (arrIndex < restCount - 1) { - for (Nd4jLong i = 1; i < arr->rankOf(); ++i) - 
printf("\n"); - for (Nd4jLong i = 0; i < depth - 2; ++i) - printf(" "); + +////////////////////////////////////////////////////////////////////////// + void NDArray::printIndexedBuffer(const char* msg, Nd4jLong limit) const { + + syncToHost(); + + Nd4jLong rank = this->rankOf(); + + bool rowFlag = (rank < 2) || (rank == 2 && this->sizeAt(0) == 1); + + if (msg) + printf("%s: ", msg); + + if (this->isEmpty()) { + printf("Empty\n"); + } + else if (this->rankOf() == 0) { + if (this->isZ()) + printf("%lld\n", this->e(0)); + else if (this->isR()) + printf("%.8f\n", this->e(0)); + else if (this->isB()) { + printf("%s\n", this->e(0)?"true":"false"); + } + else if (this->isS()) { + // todo do we need this + // printf("\"%lld\"\n", this->getOffset(e)); + printf("\"%s\"\n", this->e(0).c_str()); } } - printf("]"); - } -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::printIndexedBuffer(const char* msg, Nd4jLong limit) const { - - syncToHost(); - - Nd4jLong rank = this->rankOf(); - - bool rowFlag = (rank < 2) || (rank == 2 && this->sizeAt(0) == 1); - - if (msg) - printf("%s: ", msg); - - if (this->isEmpty()) { - printf("Empty\n"); - } - else if (this->rankOf() == 0) { - if (this->isZ()) - printf("%lld\n", this->e(0)); - else if (this->isR()) - printf("%.8f\n", this->e(0)); - else if (this->isB()) { - printf("%s\n", this->e(0)?"true":"false"); - } - else if (this->isS()) { - // todo do we need this - // printf("\"%lld\"\n", this->getOffset(e)); - printf("\"%s\"\n", this->e(0).c_str()); - } - } - else if (rowFlag && ews()==1) - printBuffer(nullptr, limit); - else { - if (msg) + else if (rowFlag && ews()==1) + printBuffer(nullptr, limit); + else { + if (msg) + printf("\n"); + printFormatted(this, 1, limit); printf("\n"); - printFormatted(this, 1, limit); - printf("\n"); + } + fflush(stdout); } - fflush(stdout); -} ////////////////////////////////////////////////////////////////////////// -template -void* NDArray::templatedPointerShift(const Nd4jLong offset) const { - return const_cast(reinterpret_cast(buffer()) + offset); -} -BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void* NDArray::templatedPointerShift, (const Nd4jLong offset) const, LIBND4J_TYPES); + template + void* NDArray::templatedPointerShift(const Nd4jLong offset) const { + return const_cast(reinterpret_cast(buffer()) + offset); + } + BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void* NDArray::templatedPointerShift, (const Nd4jLong offset) const, LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// // method makes copy of this array and applies to the copy transpose operation, this array remains unaffected -NDArray NDArray::transpose() const &{ - NDArray newArr(getDataBuffer(), ShapeDescriptor(shapeInfo()), getContext(), bufferOffset()); - newArr.transposei(); + NDArray NDArray::transpose() const &{ + NDArray newArr(getDataBuffer(), ShapeDescriptor(shapeInfo()), getContext(), bufferOffset()); + newArr.transposei(); - return newArr; -} + return newArr; + } ////////////////////////////////////////////////////////////////////////// // method makes copy of this array and applies to the copy transpose operation, this array remains unaffected -NDArray NDArray::transpose() && { + NDArray NDArray::transpose() && { - this->transposei(); - return std::move(*this); -} + this->transposei(); + return std::move(*this); + } //////////////////////////////////////////////////////////////////////// // method performs transpose operation based on this array and store result in target, 
this array remains unaffected
-void NDArray::transpose(NDArray& target) const {
+    void NDArray::transpose(NDArray& target) const {

-    auto correctShape = ShapeUtils::evalTranspShapeInfo(*this, getContext()->getWorkspace());
-    if(!shape::equalsStrict(correctShape, target.shapeInfo()))
-        throw std::runtime_error("NDArray::transpose method: the shapeInfo of target array is wrong !");
+        auto correctShape = ShapeUtils::evalTranspShapeInfo(*this, getContext()->getWorkspace());
+        if(!shape::equalsStrict(correctShape, target.shapeInfo()))
+            throw std::runtime_error("NDArray::transpose method: the shapeInfo of target array is wrong !");

-    target._buffer = _buffer;
-    target._offset = _offset;
-    target._isView = true;
-}
+        target._buffer = _buffer;
+        target._offset = _offset;
+        target._isView = true;
+    }

////////////////////////////////////////////////////////////////////////
// This method applies in-place transpose to this array, so this array becomes transposed
-void NDArray::transposei() {
-    std::vector<int> perm;
-    for (int e = this->rankOf() - 1; e >= 0; e--)
-        perm.emplace_back(e);
+    void NDArray::transposei() {
+        std::vector<int> perm;
+        for (int e = this->rankOf() - 1; e >= 0; e--)
+            perm.emplace_back(e);

-    this->permutei(perm);
-}
+        this->permutei(perm);
+    }

////////////////////////////////////////////////////////////////////////
-bool NDArray::equalsTo(const NDArray &other, double eps) const {
-    return equalsTo(&other, eps);
-}
+    bool NDArray::equalsTo(const NDArray &other, double eps) const {
+        return equalsTo(&other, eps);
+    }

//////////////////////////////////////////////////////////////////////////
-void NDArray::setAttached(bool reallyAttached) {
-    _isAttached = reallyAttached;
-};
+    void NDArray::setAttached(bool reallyAttached) {
+        _isAttached = reallyAttached;
+    };

//////////////////////////////////////////////////////////////////////////
// calculate strides
-void NDArray::updateStrides(const char order) {
-    throw std::runtime_error("Forbidden method");
-}
+    void NDArray::updateStrides(const char order) {
+        throw std::runtime_error("Forbidden method");
+    }

//////////////////////////////////////////////////////////////////////////
// set new order and shape in case of suitable array length
-bool NDArray::reshapei(const char order, const std::initializer_list<Nd4jLong>& shape, const bool copyToNewBuff) {
-    std::vector<Nd4jLong> vShape(shape);
-    return reshapei(order, vShape, copyToNewBuff);
-}
-
//////////////////////////////////////////////////////////////////////////
-bool NDArray::reshapei(const std::initializer_list<Nd4jLong>& shape, const bool copyToNewBuff) {
-    return reshapei(ordering(), shape, copyToNewBuff);
-}
-
//////////////////////////////////////////////////////////////////////////
-bool NDArray::reshapei(const std::vector<Nd4jLong>& shape, const bool copyToNewBuff) {
-    return reshapei(ordering(), shape, copyToNewBuff);
-}
-
//////////////////////////////////////////////////////////////////////////
-void NDArray::enforce(const std::initializer_list<Nd4jLong> &dimensions, char order) {
-    std::vector<Nd4jLong> dims(dimensions);
-    enforce(dims, order);
-}
-
//////////////////////////////////////////////////////////////////////////
-void NDArray::enforce(std::vector<Nd4jLong> &dimensions, char o) {
-
-    Nd4jLong prod = 1;
-    for (int e = 0; e < dimensions.size(); e++)
-        prod *= dimensions[e];
-
-    if (prod != this->lengthOf()) {
-        std::string current = ShapeUtils::shapeAsString(this);
-        std::string enforced = ShapeUtils::shapeAsString(dimensions);
-        nd4j_printf("Can't enforce new shape, lengths mismatch. Original shape: %s; Requested shape: %s\n", current.c_str(), enforced.c_str());
-        throw std::runtime_error("Incompatible shape");
+    bool NDArray::reshapei(const char order, const std::initializer_list<Nd4jLong>& shape, const bool copyToNewBuff) {
+        std::vector<Nd4jLong> vShape(shape);
+        return reshapei(order, vShape, copyToNewBuff);
    }

-    char order = o == 'a' ? this->ordering() : o;
-    setShapeInfo(ShapeDescriptor(dataType(), order, dimensions));
-}
+//////////////////////////////////////////////////////////////////////////
+    bool NDArray::reshapei(const std::initializer_list<Nd4jLong>& shape, const bool copyToNewBuff) {
+        return reshapei(ordering(), shape, copyToNewBuff);
+    }

//////////////////////////////////////////////////////////////////////////
-Nd4jLong NDArray::argMax(std::initializer_list<int> dimensions) {
+    bool NDArray::reshapei(const std::vector<Nd4jLong>& shape, const bool copyToNewBuff) {
+        return reshapei(ordering(), shape, copyToNewBuff);
+    }

-    if (isS())
-        throw std::runtime_error("NDArray::argMax: you can't use this method on String array!");
+//////////////////////////////////////////////////////////////////////////
+    void NDArray::enforce(const std::initializer_list<Nd4jLong> &dimensions, char order) {
+        std::vector<Nd4jLong> dims(dimensions);
+        enforce(dims, order);
+    }

-    if (dimensions.size() == 0) {
-        Nd4jLong max = 0;
-        auto mv = -DataTypeUtils::max<float>();
-        for (Nd4jLong e = 0; e < this->lengthOf(); e++) {
-            auto val = this->e<float>(e);
-            if (mv < val) {
-                mv = val;
-                max = e;
-            }
+//////////////////////////////////////////////////////////////////////////
+    void NDArray::enforce(std::vector<Nd4jLong> &dimensions, char o) {
+
+        Nd4jLong prod = 1;
+        for (int e = 0; e < dimensions.size(); e++)
+            prod *= dimensions[e];
+
+        if (prod != this->lengthOf()) {
+            std::string current = ShapeUtils::shapeAsString(this);
+            std::string enforced = ShapeUtils::shapeAsString(dimensions);
+            nd4j_printf("Can't enforce new shape, lengths mismatch. Original shape: %s; Requested shape: %s\n", current.c_str(), enforced.c_str());
+            throw std::runtime_error("Incompatible shape");
        }
-        return max;
+
+        char order = o == 'a' ? this->ordering() : o;
+        setShapeInfo(ShapeDescriptor(dataType(), order, dimensions));
+    }
+
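// Editor's note: a minimal usage sketch for the reshape/enforce pair above
// (illustrative only, not part of the patch; shapes and values are assumed):
//
//     auto a = NDArrayFactory::create<float>('c', {4, 3});   // 12 elements
//     a.reshapei('c', {2, 6}, false);       // same buffer, new 2x6 view
//     std::vector<Nd4jLong> dims{6, 2};
//     a.enforce(dims, 'c');                 // also length-preserving: 6*2 == 12
//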
+//////////////////////////////////////////////////////////////////////////
+    Nd4jLong NDArray::argMax(std::initializer_list<int> dimensions) {
+
+        if (isS())
+            throw std::runtime_error("NDArray::argMax: you can't use this method on String array!");
+
+        if (dimensions.size() == 0) {
+            Nd4jLong max = 0;
+            auto mv = -DataTypeUtils::max<float>();
+            for (Nd4jLong e = 0; e < this->lengthOf(); e++) {
+                auto val = this->e<float>(e);
+                if (mv < val) {
+                    mv = val;
+                    max = e;
+                }
+            }
+            return max;
+        }
+        else
+            throw std::runtime_error("Not implemented yet");
    }
-    else
-        throw std::runtime_error("Not implemented yet");
-}

//////////////////////////////////////////////////////////////////////////
// create new array with corresponding order and shape, new array will point to the same _buffer as this array
-NDArray NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff) const & {
-    NDArray newArr(getDataBuffer(), ShapeDescriptor(shapeInfo()), getContext(), bufferOffset());
-    newArr.reshapei(order, shape, copyToNewBuff);
-    return newArr;
-}
+    NDArray NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff) const & {
+        NDArray newArr(getDataBuffer(), ShapeDescriptor(shapeInfo()), getContext(), bufferOffset());
+        newArr.reshapei(order, shape, copyToNewBuff);
+        return newArr;
+    }

//////////////////////////////////////////////////////////////////////////
-NDArray NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff) && {
+    NDArray NDArray::reshape(const char order, const std::vector<Nd4jLong>& shape, const bool copyToNewBuff) && {

-    this->reshapei(order, shape, copyToNewBuff);
-    return std::move(*this);
-}
+        this->reshapei(order, shape, copyToNewBuff);
+        return std::move(*this);
+    }

//////////////////////////////////////////////////////////////////////////
// change an array by repeating it the number of times given by reps.
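// Editor's note: a minimal usage sketch for tile/tilei (illustrative only, not
// part of the patch; shapes and values are assumed):
//
//     auto a = NDArrayFactory::create<float>('c', {1, 3}, {1.f, 2.f, 3.f});
//     a.tilei({2, 1});             // in place: a now has shape {2, 3}
//     auto b = a.tile({1, 2});     // copy: b has shape {2, 6}, a is unchanged
//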
-void NDArray::tilei(const std::vector<Nd4jLong>& reps) {
-    *this = this->tile(reps);
-}
+    void NDArray::tilei(const std::vector<Nd4jLong>& reps) {
+        *this = this->tile(reps);
+    }

//////////////////////////////////////////////////////////////////////////
-Nd4jLong NDArray::sizeAt(const int dim) const {
+    Nd4jLong NDArray::sizeAt(const int dim) const {

-    if (dim >= this->rankOf() || dim < -this->rankOf())
-        throw std::runtime_error("NDArray::sizeAt: bad size index requested");
+        if (dim >= this->rankOf() || dim < -this->rankOf())
+            throw std::runtime_error("NDArray::sizeAt: bad size index requested");

-    if (dim >= 0)
-        return shape::shapeOf(_shapeInfo)[dim];
-    else
-        return shape::shapeOf(_shapeInfo)[this->rankOf() + dim];
-}
+        if (dim >= 0)
+            return shape::shapeOf(_shapeInfo)[dim];
+        else
+            return shape::shapeOf(_shapeInfo)[this->rankOf() + dim];
+    }

//////////////////////////////////////////////////////////////////////////
-Nd4jLong NDArray::strideAt(const int dim) const {
+    Nd4jLong NDArray::strideAt(const int dim) const {

-    if (dim >= this->rankOf() || dim < -this->rankOf())
-        throw std::runtime_error("NDArray::strideAt: Bad size index requested");
+        if (dim >= this->rankOf() || dim < -this->rankOf())
+            throw std::runtime_error("NDArray::strideAt: Bad size index requested");

-    if (dim >= 0)
-        return shape::stride(_shapeInfo)[dim];
-    else
-        return shape::stride(_shapeInfo)[this->rankOf() + dim];
-}
+        if (dim >= 0)
+            return shape::stride(_shapeInfo)[dim];
+        else
+            return shape::stride(_shapeInfo)[this->rankOf() + dim];
+    }

//////////////////////////////////////////////////////////////////////////
-bool NDArray::permutei(const std::initializer_list<int>& dimensions) {
-    std::vector<int> vec(dimensions);
-    return permutei(vec);
-}
+    bool NDArray::permutei(const std::initializer_list<int>& dimensions) {
+        std::vector<int> vec(dimensions);
+        return permutei(vec);
+    }

//////////////////////////////////////////////////////////////////////////
-bool NDArray::permutei(const std::vector<int>& dimensions) {
-    return permutei(dimensions.data(), rankOf());
-}
+    bool NDArray::permutei(const std::vector<int>& dimensions) {
+        return permutei(dimensions.data(), rankOf());
+    }

//////////////////////////////////////////////////////////////////////////
-bool NDArray::permutei(const std::initializer_list<Nd4jLong>& dimensions) {
-    std::vector<Nd4jLong> vec(dimensions);
-    std::vector<int> ivec(dimensions.size());
+    bool NDArray::permutei(const std::initializer_list<Nd4jLong>& dimensions) {
+        std::vector<Nd4jLong> vec(dimensions);
+        std::vector<int> ivec(dimensions.size());

-    for (int e = 0; e < vec.size(); e++)
-        ivec[e] = static_cast<int>(vec[e]);
+        for (int e = 0; e < vec.size(); e++)
+            ivec[e] = static_cast<int>(vec[e]);

-    return permutei(ivec);
-}
+        return permutei(ivec);
+    }

//////////////////////////////////////////////////////////////////////////
-bool NDArray::permutei(const std::vector<Nd4jLong>& dimensions) {
+    bool NDArray::permutei(const std::vector<Nd4jLong>& dimensions) {

-    std::vector<int> ivec(dimensions.size());
+        std::vector<int> ivec(dimensions.size());

-    for (int e = 0; e < dimensions.size(); e++)
-        ivec[e] = dimensions[e];
+        for (int e = 0; e < dimensions.size(); e++)
+            ivec[e] = dimensions[e];

-    return permutei(ivec.data(), rankOf());
-}
+        return permutei(ivec.data(), rankOf());
+    }

//////////////////////////////////////////////////////////////////////////
-NDArray NDArray::permute(const int* dimensions, const int rank) const & {
+    NDArray NDArray::permute(const int* dimensions, const int rank) const & {

-    // evaluate shapeInfo for output (permuted) array ret
-    auto shapeInfoPermuted =
ShapeUtils::evalPermShapeInfo(dimensions, rank, *this, getContext()->getWorkspace()); - NDArray ret(getDataBuffer(), ShapeDescriptor(shapeInfoPermuted), getContext(), bufferOffset()); - ret._isView = true; - return ret; -} + // evaluate shapeInfo for output (permuted) array ret + auto shapeInfoPermuted = ShapeUtils::evalPermShapeInfo(dimensions, rank, *this, getContext()->getWorkspace()); + NDArray ret(getDataBuffer(), ShapeDescriptor(shapeInfoPermuted), getContext(), bufferOffset()); + ret._isView = true; + return ret; + } ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::permute(const int* dimensions, const int rank) && { + NDArray NDArray::permute(const int* dimensions, const int rank) && { - this->permutei(dimensions, rank); - return std::move(*this); -} + this->permutei(dimensions, rank); + return std::move(*this); + } ///////////////////////////////////////////////////////////////////////// -NDArray NDArray::permute(const Nd4jLong* dimensions, const int rank) const &{ - int tempDims[MAX_RANK]; - shape::convertT(const_cast(dimensions), tempDims, rank); - return permute(tempDims, rank); -} + NDArray NDArray::permute(const Nd4jLong* dimensions, const int rank) const &{ + int tempDims[MAX_RANK]; + shape::convertT(const_cast(dimensions), tempDims, rank); + return permute(tempDims, rank); + } ///////////////////////////////////////////////////////////////////////// -NDArray NDArray::permute(const Nd4jLong* dimensions, const int rank) && { + NDArray NDArray::permute(const Nd4jLong* dimensions, const int rank) && { - this->permutei(dimensions, rank); - return std::move(*this); -} + this->permutei(dimensions, rank); + return std::move(*this); + } ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::permute(const std::vector& dimensions) const &{ + NDArray NDArray::permute(const std::vector& dimensions) const &{ - return permute(dimensions.data(), rankOf()); -} + return permute(dimensions.data(), rankOf()); + } ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::permute(const std::vector& dimensions) && { + NDArray NDArray::permute(const std::vector& dimensions) && { - this->permutei(dimensions); - return std::move(*this); -} + this->permutei(dimensions); + return std::move(*this); + } ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::permute(const std::vector& dimensions) const & { + NDArray NDArray::permute(const std::vector& dimensions) const & { - return permute(dimensions.data(), rankOf()); -} + return permute(dimensions.data(), rankOf()); + } ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::permute(const std::vector& dimensions) && { + NDArray NDArray::permute(const std::vector& dimensions) && { - this->permutei(dimensions); - return std::move(*this); -} + this->permutei(dimensions); + return std::move(*this); + } ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::permute(const std::initializer_list& dimensions) const &{ + NDArray NDArray::permute(const std::initializer_list& dimensions) const &{ - std::vector vec(dimensions); - return permute(vec); -} + std::vector vec(dimensions); + return permute(vec); + } ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::permute(const std::initializer_list& dimensions) && { + NDArray NDArray::permute(const std::initializer_list& dimensions) && { - 
this->permutei(dimensions); - return std::move(*this); -} + this->permutei(dimensions); + return std::move(*this); + } ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::permute(const std::initializer_list& dimensions) const & { - std::vector vec(dimensions); - return permute(vec); -} + NDArray NDArray::permute(const std::initializer_list& dimensions) const & { + std::vector vec(dimensions); + return permute(vec); + } ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::permute(const std::initializer_list& dimensions) && { + NDArray NDArray::permute(const std::initializer_list& dimensions) && { - this->permutei(dimensions); - return std::move(*this); -} + this->permutei(dimensions); + return std::move(*this); + } ////////////////////////////////////////////////////////////////////////// -void NDArray::permute(const int* dimensions, const int rank, NDArray& target) const { - if (!nonNull() || !target.nonNull() || rank != rankOf() || rank != target.rankOf() ) - throw std::runtime_error("NDArray::permute method: either arrays are nullptr or ranks are not suitable!"); + void NDArray::permute(const int* dimensions, const int rank, NDArray& target) const { + if (!nonNull() || !target.nonNull() || rank != rankOf() || rank != target.rankOf() ) + throw std::runtime_error("NDArray::permute method: either arrays are nullptr or ranks are not suitable!"); - auto shapeInfoNew = ShapeUtils::evalPermShapeInfo(dimensions, rank, *this, target.getContext()->getWorkspace()); + auto shapeInfoNew = ShapeUtils::evalPermShapeInfo(dimensions, rank, *this, target.getContext()->getWorkspace()); - target.setShapeInfo(shapeInfoNew); - target._buffer = _buffer; - target._offset = _offset; -} + target.setShapeInfo(shapeInfoNew); + target._buffer = _buffer; + target._offset = _offset; + } ////////////////////////////////////////////////////////////////////////// -void NDArray::permute(const Nd4jLong *dimensions, const int rank, NDArray& target) const { - if (!nonNull() || !target.nonNull() || rank != rankOf() || rank != target.rankOf() ) - throw std::runtime_error("NDArray::permute method: either arrays are nullptr or ranks are not suitable!"); + void NDArray::permute(const Nd4jLong *dimensions, const int rank, NDArray& target) const { + if (!nonNull() || !target.nonNull() || rank != rankOf() || rank != target.rankOf() ) + throw std::runtime_error("NDArray::permute method: either arrays are nullptr or ranks are not suitable!"); - auto shapeInfoNew = ShapeUtils::evalPermShapeInfo(dimensions, rank, *this, target.getContext()->getWorkspace()); + auto shapeInfoNew = ShapeUtils::evalPermShapeInfo(dimensions, rank, *this, target.getContext()->getWorkspace()); - target.setShapeInfo(shapeInfoNew); - target._buffer = _buffer; - target._offset = _offset; -} + target.setShapeInfo(shapeInfoNew); + target._buffer = _buffer; + target._offset = _offset; + } ////////////////////////////////////////////////////////////////////////// -void NDArray::permute(const std::vector& dimensions, NDArray& target) const { - permute(dimensions.data(), rankOf(), target); -} + void NDArray::permute(const std::vector& dimensions, NDArray& target) const { + permute(dimensions.data(), rankOf(), target); + } ////////////////////////////////////////////////////////////////////////// -void NDArray::permute(const std::vector& dimensions, NDArray& target) const { - permute(dimensions.data(), rankOf(), target); -} + void NDArray::permute(const std::vector& dimensions, NDArray& 
target) const { + permute(dimensions.data(), rankOf(), target); + } ////////////////////////////////////////////////////////////////////////// // check whether array is identity matrix -bool NDArray::isIdentityMatrix() { - if (isS()) - throw std::runtime_error("NDArray::isIdentityMatrix: you can't use this method on String array!"); - if(rankOf() !=2 || rows() != columns()) - throw std::runtime_error("isIdentityMatrix method: matrix must be square and have rank = 2 !"); + bool NDArray::isIdentityMatrix() { + if (isS()) + throw std::runtime_error("NDArray::isIdentityMatrix: you can't use this method on String array!"); + if(rankOf() !=2 || rows() != columns()) + throw std::runtime_error("isIdentityMatrix method: matrix must be square and have rank = 2 !"); - const double eps = 1e-5f; - for(Nd4jLong i=0; i(i,i) - 1.f) > eps) - return false; - - for(Nd4jLong i=0; i(i,j)) > eps) + const double eps = 1e-5f; + for(Nd4jLong i=0; i(i,i) - 1.f) > eps) return false; - } + + for(Nd4jLong i=0; i(i,j)) > eps) + return false; + } + } + return true; } - return true; -} ////////////////////////////////////////////////////////////////////////// // check whether array is unitary matrix -bool NDArray::isUnitary() { - if (isS()) - throw std::runtime_error("NDArray::isUnitary: you can't use this method on String array!"); - if(rankOf() != 2 || rows() != columns()) - throw std::runtime_error("isUnitary method: matrix must be square and have rank = 2 !"); + bool NDArray::isUnitary() { + if (isS()) + throw std::runtime_error("NDArray::isUnitary: you can't use this method on String array!"); + if(rankOf() != 2 || rows() != columns()) + throw std::runtime_error("isUnitary method: matrix must be square and have rank = 2 !"); - auto tr = this->transpose(); - auto trMul = MmulHelper::mmul(this, &tr, nullptr, 1.f, 0.f); + auto tr = this->transpose(); + auto trMul = MmulHelper::mmul(this, &tr, nullptr, 1.f, 0.f); - bool result = trMul->isIdentityMatrix(); - delete trMul; + bool result = trMul->isIdentityMatrix(); + delete trMul; - return result; -} - -////////////////////////////////////////////////////////////////////////// -template <> -const std::string* ND4J_EXPORT NDArray::bufferAsT() const { - throw std::runtime_error("This method is NOT supposed to be used"); -} - -////////////////////////////////////////////////////////////////////////// -template -const T* NDArray::bufferAsT() const { - // FIXME: do we REALLY want sync here? 
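// Editor's note (not part of the patch): the const bufferAsT<T>() deliberately
// skips the device-to-host sync that the FIXME above questions, so a caller
// reading on the host must make the buffer current first. A hedged sketch,
// assuming a build where the data may live on the device:
//
//     arr.syncToHost();                          // pull latest data to host
//     const float* p = arr.bufferAsT<float>();   // now safe to read on host
//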
- // syncToHost(); - - return reinterpret_cast(buffer()); -} -BUILD_SINGLE_UNCHAINED_TEMPLATE(template ND4J_EXPORT const, * NDArray::bufferAsT() const, LIBND4J_TYPES); - -template -T* NDArray::bufferAsT() { - syncToHost(); - return reinterpret_cast(buffer()); -} -BUILD_SINGLE_UNCHAINED_TEMPLATE(template ND4J_EXPORT, * NDArray::bufferAsT(), LIBND4J_TYPES); - -//////////////////////////////////////////////////////////////////////// -NDArray NDArray::subarray(IndicesList& idx) const { - - const int idxSize = idx.size(); - if (idxSize != this->rankOf()) - throw std::runtime_error("NDArray::subarray: number of indices should match"); - - std::vector indexes(3 * idxSize); - - // convert IndicesList to vector - for (int d = 0; d < idxSize; ++d) { - - if (idx.at(d)->isAll()) { - indexes[3 * d] = 0; // first - indexes[3 * d + 1] = 0; // last - indexes[3 * d + 2] = 1; // stride - } - else if (idx.at(d)->isPoint()) { - indexes[3 * d] = idx.at(d)->getIndices().at(0); // first - indexes[3 * d + 1] = indexes[3 * d] + 1; // last - indexes[3 * d + 2] = 1; // stride - } - else if (idx.at(d)->isInterval()) { - indexes[3 * d] = idx.at(d)->getIndices().at(0); // first - indexes[3 * d + 1] = idx.at(d)->getIndices().size();// last - indexes[3 * d + 2] = idx.at(d)->stride(); // stride - } - else { - indexes[3 * d] = idx.at(d)->getIndices().at(0); // first - indexes[3 * d + 1] = idx.at(d)->getIndices().at(1); // last - indexes[3 * d + 2] = idx.at(d)->getIndices().at(2); // stride - } - } - return NDArray((*this)(indexes, true, true)); -} - -//////////////////////////////////////////////////////////////////////// -NDArray NDArray::subarray(const std::initializer_list& idx) const { - - const int idxSize = idx.size(); - if (idxSize != this->rankOf()) - throw std::runtime_error("NDArray::subarray: number of indices should match the array rank"); - - std::vector indexes(3 * idxSize); - - // convert NDIndex to vector - int d = 0; - for (const auto& item : idx) { - - if (item->isAll()) { - indexes[3 * d] = 0; // first - indexes[3 * d + 1] = 0; // last - indexes[3 * d + 2] = 1; // stride - } - else if (item->isPoint()) { - indexes[3 * d] = item->getIndices().at(0); // first - indexes[3 * d + 1] = indexes[3 * d] + 1; // last - indexes[3 * d + 2] = 1; // stride - } - else if (item->isInterval()) { - indexes[3 * d] = item->getIndices().at(0); // first - indexes[3 * d + 1] = item->getIndices().size(); // last - indexes[3 * d + 2] = item->stride(); // stride - } - else { - indexes[3 * d] = item->getIndices().at(0); // first - indexes[3 * d + 1] = item->getIndices().at(1); // last - indexes[3 * d + 2] = item->getIndices().at(2); // stride - } - ++d; + return result; } - // release NDIndices - for (auto i: idx) - delete i; - - return NDArray((*this)(indexes, true, true)); -} - -//////////////////////////////////////////////////////////////////////// -NDArray NDArray::subarray(const Intervals& idx) const { - - const int idxSize = idx.size(); - if (idxSize != this->rankOf()) - throw std::runtime_error("NDArray::subarray: number of indices should match the rank of array!"); - - std::vector indexes(2 * idxSize); - - // convert Intervals to vector - for (int d = 0; d < idxSize; ++d) { - - if (idx[d].empty()) { - indexes[2 * d] = 0; // first - indexes[2 * d + 1] = 0; // last - } - else { - indexes[2 * d] = idx[d][0]; // first - indexes[2 * d + 1] = idx[d][1]; // last - } +////////////////////////////////////////////////////////////////////////// + template <> + const std::string* ND4J_EXPORT NDArray::bufferAsT() const { + throw 
std::runtime_error("This method is NOT supposed to be used"); } - return NDArray((*this)(indexes, true)); -} +////////////////////////////////////////////////////////////////////////// + template + const T* NDArray::bufferAsT() const { + // FIXME: do we REALLY want sync here? + // syncToHost(); + + return reinterpret_cast(buffer()); + } + BUILD_SINGLE_UNCHAINED_TEMPLATE(template ND4J_EXPORT const, * NDArray::bufferAsT() const, LIBND4J_TYPES); + + template + T* NDArray::bufferAsT() { + syncToHost(); + return reinterpret_cast(buffer()); + } + BUILD_SINGLE_UNCHAINED_TEMPLATE(template ND4J_EXPORT, * NDArray::bufferAsT(), LIBND4J_TYPES); + +//////////////////////////////////////////////////////////////////////// + NDArray NDArray::subarray(IndicesList& idx) const { + + const int idxSize = idx.size(); + if (idxSize != this->rankOf()) + throw std::runtime_error("NDArray::subarray: number of indices should match"); + + std::vector indexes(3 * idxSize); + + // convert IndicesList to vector + for (int d = 0; d < idxSize; ++d) { + + if (idx.at(d)->isAll()) { + indexes[3 * d] = 0; // first + indexes[3 * d + 1] = 0; // last + indexes[3 * d + 2] = 1; // stride + } + else if (idx.at(d)->isPoint()) { + indexes[3 * d] = idx.at(d)->getIndices().at(0); // first + indexes[3 * d + 1] = indexes[3 * d] + 1; // last + indexes[3 * d + 2] = 1; // stride + } + else if (idx.at(d)->isInterval()) { + indexes[3 * d] = idx.at(d)->getIndices().at(0); // first + indexes[3 * d + 1] = idx.at(d)->getIndices().size();// last + indexes[3 * d + 2] = idx.at(d)->stride(); // stride + } + else { + indexes[3 * d] = idx.at(d)->getIndices().at(0); // first + indexes[3 * d + 1] = idx.at(d)->getIndices().at(1); // last + indexes[3 * d + 2] = idx.at(d)->getIndices().at(2); // stride + } + } + return NDArray((*this)(indexes, true, true)); + } + +//////////////////////////////////////////////////////////////////////// + NDArray NDArray::subarray(const std::initializer_list& idx) const { + + const int idxSize = idx.size(); + if (idxSize != this->rankOf()) + throw std::runtime_error("NDArray::subarray: number of indices should match the array rank"); + + std::vector indexes(3 * idxSize); + + // convert NDIndex to vector + int d = 0; + for (const auto& item : idx) { + + if (item->isAll()) { + indexes[3 * d] = 0; // first + indexes[3 * d + 1] = 0; // last + indexes[3 * d + 2] = 1; // stride + } + else if (item->isPoint()) { + indexes[3 * d] = item->getIndices().at(0); // first + indexes[3 * d + 1] = indexes[3 * d] + 1; // last + indexes[3 * d + 2] = 1; // stride + } + else if (item->isInterval()) { + indexes[3 * d] = item->getIndices().at(0); // first + indexes[3 * d + 1] = item->getIndices().size(); // last + indexes[3 * d + 2] = item->stride(); // stride + } + else { + indexes[3 * d] = item->getIndices().at(0); // first + indexes[3 * d + 1] = item->getIndices().at(1); // last + indexes[3 * d + 2] = item->getIndices().at(2); // stride + } + ++d; + } + + // release NDIndices + for (auto i: idx) + delete i; + + return NDArray((*this)(indexes, true, true)); + } + +//////////////////////////////////////////////////////////////////////// + NDArray NDArray::subarray(const Intervals& idx) const { + + const int idxSize = idx.size(); + if (idxSize != this->rankOf()) + throw std::runtime_error("NDArray::subarray: number of indices should match the rank of array!"); + + std::vector indexes(2 * idxSize); + + // convert Intervals to vector + for (int d = 0; d < idxSize; ++d) { + + if (idx[d].empty()) { + indexes[2 * d] = 0; // first + indexes[2 * d + 1] = 
0; // last + } + else { + indexes[2 * d] = idx[d][0]; // first + indexes[2 * d + 1] = idx[d][1]; // last + } + } + + return NDArray((*this)(indexes, true)); + } ////////////////////////////////////////////////////////////////////////// -template -NDArray NDArray::asT() const{ + template + NDArray NDArray::asT() const{ - auto result = isScalar() ? NDArray('c', {}, std::vector{0.}, DataTypeUtils::fromT(), this->getContext()) : NDArray(ordering(), getShapeAsVector(), DataTypeUtils::fromT(), this->getContext()); + auto result = isScalar() ? NDArray('c', {}, std::vector{0.}, DataTypeUtils::fromT(), this->getContext()) : NDArray(ordering(), getShapeAsVector(), DataTypeUtils::fromT(), this->getContext()); - NDArray::prepareSpecialUse({&result}, {this}); - NativeOpExecutioner::execTransformAny(getContext(), transform::AnyOps::Assign, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), nullptr, nullptr, nullptr); - NDArray::registerSpecialUse({&result}, {this}); + NDArray::prepareSpecialUse({&result}, {this}); + NativeOpExecutioner::execTransformAny(getContext(), transform::AnyOps::Assign, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), nullptr, nullptr, nullptr); + NDArray::registerSpecialUse({&result}, {this}); - return result; -} -BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT NDArray NDArray::asT, () const, LIBND4J_TYPES); + return result; + } + BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT NDArray NDArray::asT, () const, LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// -template -NDArray NDArray::asS() const { + template + NDArray NDArray::asS() const { - if (!isS()) - throw std::runtime_error("NDArray::asS: you can use this method only for String array!"); + if (!isS()) + throw std::runtime_error("NDArray::asS: you can use this method only for String array!"); - auto dtype = DataTypeUtils::fromT(); + auto dtype = DataTypeUtils::fromT(); - if (!(DataTypeUtils::isS(dtype))) - throw std::invalid_argument("NDArray::asS: invalid DataType used"); + if (!(DataTypeUtils::isS(dtype))) + throw std::invalid_argument("NDArray::asS: invalid DataType used"); - if (dtype == dataType()) { + if (dtype == dataType()) { + + Nd4jLong offsetsLength = ShapeUtils::stringBufferHeaderRequirements(lengthOf()); + const auto nInputoffsets = bufferAsT(); + std::shared_ptr pBuffer = std::make_shared(offsetsLength + nInputoffsets[lengthOf()], dtype, getContext()->getWorkspace(), true); + + NDArray res(pBuffer, ShapeDescriptor(dtype, ordering(), getShapeAsVector()), getContext()); + res.setAttached(getContext()->getWorkspace() != nullptr); + + preparePrimaryUse({ &res }, { this }); + memcpy(res.bufferAsT(), nInputoffsets, offsetsLength); + auto data = res.bufferAsT() + offsetsLength; + const auto inData = bufferAsT() + offsetsLength; + memcpy(data, inData, nInputoffsets[lengthOf()]); + + registerPrimaryUse({ &res }, { this }); + return res; + } Nd4jLong offsetsLength = ShapeUtils::stringBufferHeaderRequirements(lengthOf()); + + std::vector offsets(lengthOf() + 1); + const auto nInputoffsets = bufferAsT(); - std::shared_ptr pBuffer = std::make_shared(offsetsLength + nInputoffsets[lengthOf()], dtype, getContext()->getWorkspace(), true); + + Nd4jLong start = 0, stop = 0; + Nd4jLong dataLength = 0; + + auto data = bufferAsT() + offsetsLength; + for (Nd4jLong e = 0; e < lengthOf(); e++) { + offsets[e] 
= dataLength; + start = nInputoffsets[e]; + stop = nInputoffsets[e + 1]; + if (dataType() == DataType::UTF8) { + dataLength += (dtype == DataType::UTF16) ? unicode::offsetUtf8StringInUtf16(data + start, stop) + : unicode::offsetUtf8StringInUtf32(data + start, stop); + } + else if (dataType() == DataType::UTF16) { + dataLength += (dtype == DataType::UTF32) ? unicode::offsetUtf16StringInUtf32(data + start, (stop / sizeof(char16_t)) ) + : unicode::offsetUtf16StringInUtf8(data + start, (stop / sizeof(char16_t))); + } + else { + dataLength += (dtype == DataType::UTF16) ? unicode::offsetUtf32StringInUtf16(data + start, (stop / sizeof(char32_t))) + : unicode::offsetUtf32StringInUtf8(data + start, (stop / sizeof(char32_t))); + } + } + offsets[lengthOf()] = dataLength; + + std::shared_ptr pBuffer = std::make_shared(offsetsLength + dataLength, dtype, getContext()->getWorkspace(), true); NDArray res(pBuffer, ShapeDescriptor(dtype, ordering(), getShapeAsVector()), getContext()); res.setAttached(getContext()->getWorkspace() != nullptr); preparePrimaryUse({ &res }, { this }); - memcpy(res.bufferAsT(), nInputoffsets, offsetsLength); - auto data = res.bufferAsT() + offsetsLength; + + memcpy(res.bufferAsT(), offsets.data(), offsets.size() * sizeof(Nd4jLong)); + + auto outData = res.bufferAsT() + offsetsLength; const auto inData = bufferAsT() + offsetsLength; - memcpy(data, inData, nInputoffsets[lengthOf()]); - registerPrimaryUse({ &res }, { this }); - return res; - } - - Nd4jLong offsetsLength = ShapeUtils::stringBufferHeaderRequirements(lengthOf()); - - std::vector offsets(lengthOf() + 1); - - const auto nInputoffsets = bufferAsT(); - - Nd4jLong start = 0, stop = 0; - Nd4jLong dataLength = 0; - - auto data = bufferAsT() + offsetsLength; - for (Nd4jLong e = 0; e < lengthOf(); e++) { - offsets[e] = dataLength; - start = nInputoffsets[e]; - stop = nInputoffsets[e + 1]; - if (dataType() == DataType::UTF8) { - dataLength += (dtype == DataType::UTF16) ? unicode::offsetUtf8StringInUtf16(data + start, stop) - : unicode::offsetUtf8StringInUtf32(data + start, stop); - } - else if (dataType() == DataType::UTF16) { - dataLength += (dtype == DataType::UTF32) ? unicode::offsetUtf16StringInUtf32(data + start, (stop / sizeof(char16_t)) ) - : unicode::offsetUtf16StringInUtf8(data + start, (stop / sizeof(char16_t))); - } - else { - dataLength += (dtype == DataType::UTF16) ? 
unicode::offsetUtf32StringInUtf16(data + start, (stop / sizeof(char32_t))) - : unicode::offsetUtf32StringInUtf8(data + start, (stop / sizeof(char32_t))); - } - } - offsets[lengthOf()] = dataLength; - - std::shared_ptr pBuffer = std::make_shared(offsetsLength + dataLength, dtype, getContext()->getWorkspace(), true); - - NDArray res(pBuffer, ShapeDescriptor(dtype, ordering(), getShapeAsVector()), getContext()); - res.setAttached(getContext()->getWorkspace() != nullptr); - - preparePrimaryUse({ &res }, { this }); - - memcpy(res.bufferAsT(), offsets.data(), offsets.size() * sizeof(Nd4jLong)); - - auto outData = res.bufferAsT() + offsetsLength; - const auto inData = bufferAsT() + offsetsLength; - - auto func = PRAGMA_THREADS_FOR{ - for (int e = start; e < stop; e++) { - auto cdata = outData + offsets[e]; - auto end = nInputoffsets[e + 1]; - auto idata = inData + nInputoffsets[e]; - if (dtype == DataType::UTF16) { - if (dataType() == DataType::UTF8) { - unicode::utf8to16(idata, outData, end); - } - else { - unicode::utf32to16(idata, outData, (end / sizeof(char32_t))); - } - } - else if (dtype == DataType::UTF32) { - if (dataType() == DataType::UTF8) { - unicode::utf8to32(idata, cdata, end); - } - else { - unicode::utf16to32(idata, outData, (end / sizeof(char16_t))); - } - } - else { - if (dataType() == DataType::UTF16) { - unicode::utf16to8(idata, outData, (end / sizeof(char16_t))); - } - else { - unicode::utf32to8(idata, outData, (end / sizeof(char32_t))); - } - } - } - }; - - samediff::Threads::parallel_for(func, 0, lengthOf(), 1); - - registerPrimaryUse({ &res }, { this }); - - return res; -} -BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT NDArray NDArray::asS, () const, LIBND4J_STRINGTYPES); - -//////////////////////////////////////////////////////////////////////// -NDArray NDArray::asT(DataType dtype) const { - - if (isS() && !DataTypeUtils::isS(dtype)) - throw std::runtime_error("NDArray::asT: you can't use this method on String array with not string DataType!"); - - if (!isS() && DataTypeUtils::isS(dtype)) - throw std::runtime_error("NDArray::asT: you can't use this method on not String array with string DataType!"); - - if (isS()){ - BUILD_SINGLE_SELECTOR(dtype, return asS, (), LIBND4J_STRINGTYPES); - } else { - BUILD_SINGLE_SELECTOR(dtype, return asT, (), LIBND4J_TYPES); - } - - return NDArray(); -} - -//////////////////////////////////////////////////////////////////////// -NDArray NDArray::cast(DataType dtype) const { - - if (isS() && !DataTypeUtils::isS(dtype)) - throw std::runtime_error("NDArray::cast: you can't use this method on String array with not string DataType!"); - - if (!isS() && DataTypeUtils::isS(dtype)) - throw std::runtime_error("NDArray::cast: you can't use this method on not String array with string DataType!"); - - return this->asT(dtype); -} - -//////////////////////////////////////////////////////////////////////// -void NDArray::cast(NDArray& target, DataType dtype) { - if (isS()) - throw std::runtime_error("NDArray::cast: you can't use this method on String array!"); - // TODO: to be implemented properly - target.assign(this); -} - -//////////////////////////////////////////////////////////////////////// -void NDArray::operator+=(const NDArray& other) { - - if (isS()) - throw std::runtime_error("NDArray::operator+=: you can't use this method on String array!"); - if (!Environment::getInstance().isExperimentalBuild() && this->dataType() != other.dataType() && (this->dataType() != DataType::BOOL || other.dataType() != BOOL)) - throw 
sd::datatype_exception::build("NDArray operator+=: Cannot add different types", this->dataType(), other.dataType()); - - if (this->lengthOf() != 1 && other.lengthOf() == 1) { - NDArray::prepareSpecialUse({this}, {this, &other}); - NativeOpExecutioner::execScalar(getContext(), sd::scalar::Add, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({this}, {this, &other}); - } - else if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { - NDArray::prepareSpecialUse({this}, {this, &other}); - NativeOpExecutioner::execPairwiseTransform(getContext(), sd::pairwise::Add, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({this}, {this, &other}); - } - else{ - const Nd4jLong *bShape = nullptr; - if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, bShape, getContext()->getWorkspace())) - throw std::invalid_argument("NDArray::operator+=: the shapes of this and other arrays are not suitable for broadcast operation !"); - - if(shape::equalsTypesAndShapesSoft(shapeInfo(), bShape)) { - this->applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), other, *this, false); - } - else { - NDArray result(bShape, true, getContext()); - this->applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), other, result, false); - *this = std::move(result); // move assignment operator, zero cost copy - } - } -} - -//////////////////////////////////////////////////////////////////////// -void NDArray::operator-=(const NDArray& other) { - if (isS()) - throw std::runtime_error("NDArray::operator-=: you can't use this method on String array!"); - - if (!Environment::getInstance().isExperimentalBuild() && this->dataType() != other.dataType() && (this->dataType() != DataType::BOOL || other.dataType() != BOOL)) - throw sd::datatype_exception::build("NDArray operator-=: Cannot subtract different types", this->dataType(), other.dataType()); - - if (lengthOf() != 1 && other.lengthOf() == 1) { - NDArray::prepareSpecialUse({this}, {this, &other}); - NativeOpExecutioner::execScalar(getContext(), sd::scalar::Subtract, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({this}, {this, &other}); - } - else if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { - NDArray::prepareSpecialUse({this}, {this, &other}); - NativeOpExecutioner::execPairwiseTransform(getContext(), sd::pairwise::Subtract, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({this}, {this, &other}); - } - else{ - const Nd4jLong *bShape = nullptr; - if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, bShape, getContext()->getWorkspace())) - throw std::invalid_argument("NDArray::operator-=: the shapes of this and other arrays are not suitable for broadcast operation !"); - - if(shape::equalsTypesAndShapesSoft(shapeInfo(), bShape)) { - 
this->applyTrueBroadcast(sd::BroadcastOpsTuple::Subtract(), other, *this, false); - } - else { - NDArray result(bShape, true, getContext()); - this->applyTrueBroadcast(sd::BroadcastOpsTuple::Subtract(), other, result, false); - *this = std::move(result); // move assignment operator, zero cost copy - } - } -} - -//////////////////////////////////////////////////////////////////////// -void NDArray::operator*=(const NDArray& other) { - if (isS()) - throw std::runtime_error("NDArray::operator*=: you can't use this method on String array!"); - if (!Environment::getInstance().isExperimentalBuild() && this->dataType() != other.dataType() && (this->dataType() != DataType::BOOL || other.dataType() != BOOL)) - throw sd::datatype_exception::build("NDArray operator*=: Cannot multiply different types", this->dataType(), other.dataType()); - - if (lengthOf() != 1 && other.lengthOf() == 1) { - NDArray::prepareSpecialUse({this}, {this, &other}); - NativeOpExecutioner::execScalar(getContext(), sd::scalar::Multiply, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({this}, {this, &other}); - } - else if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { - NDArray::prepareSpecialUse({this}, {this, &other}); - NativeOpExecutioner::execPairwiseTransform(getContext(), sd::pairwise::Multiply, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({this}, {this, &other}); - } - else{ - const Nd4jLong *bShape = nullptr; - if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, bShape, getContext()->getWorkspace())) - throw std::invalid_argument("NDArray::operator*=: the shapes of this and other arrays are not suitable for broadcast operation !"); - - if(shape::equalsTypesAndShapesSoft(_shapeInfo, bShape)) { - this->applyTrueBroadcast(sd::BroadcastOpsTuple::Multiply(), other, *this, false); - } - else { - NDArray result(bShape, true, getContext()); - this->applyTrueBroadcast(sd::BroadcastOpsTuple::Multiply(), other, result, false); - *this = std::move(result); // move assignment operator, zero cost copy - } - } -} - -//////////////////////////////////////////////////////////////////////// -void NDArray::operator/=(const NDArray& other) { - if (isS() || other.isS()) - throw std::runtime_error("NDArray::operator/=: you can't use this method on String array!"); - if (other.isB()) - throw std::runtime_error("NDArray::operator/=: you can't divide by bool array!"); - - if (!Environment::getInstance().isExperimentalBuild() && this->dataType() != other.dataType()) { - throw sd::datatype_exception::build("NDArray operator/=: Cannot divide different types", this->dataType(), other.dataType()); - } - - if (lengthOf() != 1 && other.lengthOf() == 1) { - NDArray::prepareSpecialUse({this}, {this, &other}); - NativeOpExecutioner::execScalar(getContext(), sd::scalar::Divide, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({this}, {this, &other}); - } - else if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) { - 
NDArray::prepareSpecialUse({this}, {this, &other}); - NativeOpExecutioner::execPairwiseTransform(getContext(), sd::pairwise::Divide, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({this}, {this, &other}); - } - else{ - const Nd4jLong *bShape = nullptr; - if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, bShape, getContext()->getWorkspace())) - throw std::invalid_argument("NDArray::operator/=: the shapes of this and other arrays are not suitable for broadcast operation !"); - - if(shape::equalsTypesAndShapesSoft(_shapeInfo, bShape)) { - this->applyTrueBroadcast(sd::BroadcastOpsTuple::Divide(), other, *this, false); - } - else { - NDArray result(bShape, true, getContext()); - this->applyTrueBroadcast(sd::BroadcastOpsTuple::Divide(), other, result, false); - *this = std::move(result); // move assignment operator, zero cost copy - } - } -} - -//////////////////////////////////////////////////////////////////////// -template -void NDArray::operator+=(const T value) { - if (isS()) - throw std::runtime_error("NDArray::operator+=: you can't use this method on String array!"); - - auto other = NDArrayFactory::create(this->dataType(), value, getContext()); - - NDArray::prepareSpecialUse({this}, {this, &other}); - NativeOpExecutioner::execScalar(getContext(), sd::scalar::Add, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({this}, {this, &other}); -} -template ND4J_EXPORT void NDArray::operator+=(const double value); -template ND4J_EXPORT void NDArray::operator+=(const float value); -template ND4J_EXPORT void NDArray::operator+=(const float16 value); -template ND4J_EXPORT void NDArray::operator+=(const bfloat16 value); -template ND4J_EXPORT void NDArray::operator+=(const Nd4jLong value); -template ND4J_EXPORT void NDArray::operator+=(const int value); -template ND4J_EXPORT void NDArray::operator+=(const bool value); - -//////////////////////////////////////////////////////////////////////// -template -void NDArray::operator-=(const T value) { - if (isS()) - throw std::runtime_error("NDArray::operator-=: you can't use this method on String array!"); - - auto other = NDArrayFactory::create(dataType(), value, getContext()); - - NDArray::prepareSpecialUse({this}, {this, &other}); - NativeOpExecutioner::execScalar(getContext(), sd::scalar::Subtract, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({this}, {this, &other}); -} -template ND4J_EXPORT void NDArray::operator-=(const double value); -template ND4J_EXPORT void NDArray::operator-=(const float value); -template ND4J_EXPORT void NDArray::operator-=(const float16 value); -template ND4J_EXPORT void NDArray::operator-=(const bfloat16 value); -template ND4J_EXPORT void NDArray::operator-=(const Nd4jLong value); -template ND4J_EXPORT void NDArray::operator-=(const int value); -template ND4J_EXPORT void NDArray::operator-=(const bool value); - -//////////////////////////////////////////////////////////////////////// -template -void NDArray::operator*=(const T scalar) { - 
if (isS()) - throw std::runtime_error("NDArray::operator*=: you can't use this method on String array!"); - - auto other = NDArrayFactory::create(this->dataType(), scalar, getContext()); - NDArray::prepareSpecialUse({this}, {this, &other}); - NativeOpExecutioner::execScalar(getContext(), sd::scalar::Multiply, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({this}, {this, &other}); -} -template ND4J_EXPORT void NDArray::operator*=(const double scalar); -template ND4J_EXPORT void NDArray::operator*=(const float scalar); -template ND4J_EXPORT void NDArray::operator*=(const float16 scalar); -template ND4J_EXPORT void NDArray::operator*=(const bfloat16 scalar); -template ND4J_EXPORT void NDArray::operator*=(const Nd4jLong scalar); -template ND4J_EXPORT void NDArray::operator*=(const int scalar); -template ND4J_EXPORT void NDArray::operator*=(const int16_t scalar); -template ND4J_EXPORT void NDArray::operator*=(const int8_t scalar); -template ND4J_EXPORT void NDArray::operator*=(const uint8_t scalar); -template ND4J_EXPORT void NDArray::operator*=(const bool scalar); - -//////////////////////////////////////////////////////////////////////// -template -void NDArray::operator/=(const T scalar) { - if (isS()) - throw std::runtime_error("NDArray::operator/=: you can't use this method on String array!"); - - auto other = NDArrayFactory::create(this->dataType(), scalar, getContext()); - NDArray::prepareSpecialUse({this}, {this, &other}); - NativeOpExecutioner::execScalar(getContext(), sd::scalar::Divide, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({this}, {this, &other}); -} -template ND4J_EXPORT void NDArray::operator/=(const double scalar); -template ND4J_EXPORT void NDArray::operator/=(const float scalar); -template ND4J_EXPORT void NDArray::operator/=(const float16 scalar); -template ND4J_EXPORT void NDArray::operator/=(const bfloat16 scalar); -template ND4J_EXPORT void NDArray::operator/=(const Nd4jLong scalar); -template ND4J_EXPORT void NDArray::operator/=(const int scalar); -template ND4J_EXPORT void NDArray::operator/=(const int16_t scalar); -template ND4J_EXPORT void NDArray::operator/=(const int8_t scalar); -template ND4J_EXPORT void NDArray::operator/=(const uint8_t scalar); -template ND4J_EXPORT void NDArray::operator/=(const bool scalar); - -//////////////////////////////////////////////////////////////////////// -// negative operator, it makes all array elements = -elements -NDArray NDArray::operator-() const & { - if (isS()) - throw std::runtime_error("NDArray::negative-: you can't use this method on String array!"); - - NDArray result(shapeInfo(), false, getContext()); - - NDArray::prepareSpecialUse({&result}, {this}); - NativeOpExecutioner::execTransformSame(getContext(), sd::transform::Neg, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), nullptr, nullptr, nullptr); - NDArray::registerSpecialUse({&result}, {this}); - - return result; -} - -//////////////////////////////////////////////////////////////////////// -NDArray NDArray::operator-() && { - if (isS()) - throw 
std::runtime_error("NDArray::negative-: you can't use this method on String array!"); - - NDArray::prepareSpecialUse({this}, {this}); - NativeOpExecutioner::execTransformSame(getContext(), sd::transform::Neg, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, nullptr, nullptr); - NDArray::registerSpecialUse({this}, {this}); - - return std::move(*this); -} - -//////////////////////////////////////////////////////////////////////// -// mathematical multiplication of two arrays -NDArray mmul(const NDArray& left, const NDArray& right) { - if (left.isS() || right.isS()) - throw std::runtime_error("mmul friend function: you can't use this function on String array!"); - auto ptr = MmulHelper::mmul(const_cast(&left), const_cast(&right), nullptr, 1., 0.); - NDArray result(std::move(*ptr)); - delete ptr; - return result; -} - -//////////////////////////////////////////////////////////////////////// -void NDArray::tileToShape(const std::vector& shape, NDArray& target) { - if(&target != this) { - this->tile(target); - return; - } - - std::vector thisShape(rankOf()); - for(int i = 0; i < rankOf(); ++i) - thisShape[i] = sizeAt(i); - - if(!ShapeUtils::areShapesBroadcastable(shape, thisShape)) - throw std::runtime_error("NDArray::tileToShape method: the shape of this array and input shape are not suitable for broadcast operation !"); - - const int newRank = shape.size(); - std::vector repeats(newRank); - - for(int i = 1; i <= newRank; ++i) { - if(i > rankOf()) - repeats[newRank-i] = shape[newRank - i]; - else - repeats[newRank-i] = shape[newRank - i] / thisShape[rankOf() - i]; - } - - tilei(repeats); -} - -//////////////////////////////////////////////////////////////////////// -void NDArray::tileToShape(const std::initializer_list& shape, NDArray& target) { - tileToShape(std::vector(shape), target); -} - -//////////////////////////////////////////////////////////////////////// -NDArray NDArray::tileToShape(const Nd4jLong* shapeInfo) { - - NDArray result(const_cast(shapeInfo), false, getContext()); - tile(result); - return result; -} - -//////////////////////////////////////////////////////////////////////// -double NDArray::getTrace() const { - if (isS()) - throw std::runtime_error("NDArray::getTrace: you can't use this method on String array!"); - - int rank = rankOf(); - auto shape = shapeOf(); - int minDim = 100000000; - - Nd4jLong indices[MAX_RANK]; - for(int j = 0; j < rank; ++j) - indices[j] = 1; - - auto offset = shape::getOffset(shapeInfo(), indices); - - for(int i = 0; i < rank; ++i) - if(minDim > shape[i]) - minDim = shape[i]; - - double sum = 0.; - - for(int i = 0; i < minDim; ++i) - sum += e(i * offset); - - return sum; -} - -//////////////////////////////////////////////////////////////////////// -NDArray NDArray::quantize(const NDArray& array) { - - if(!array.isR()) - throw std::invalid_argument("NDArray::quantize: type of array should be from real space!"); - - auto ws = array.getContext()->getWorkspace(); - - Nd4jLong* shapeInfo = ShapeBuilders::copyShapeInfo(array.shapeInfo(), true, ws); - ArrayOptions::setPropertyBit(shapeInfo, ARRAY_QUANTIZED); - - std::shared_ptr buffer = std::make_shared(TypeCast::estimateQuantizedSize(array.lengthOf()), ArrayOptions::dataType(shapeInfo), ws); - - NDArray result(buffer, ShapeDescriptor(shapeInfo), array.getContext()); - - return result; -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple 
op, const NDArray& other, NDArray& target, const bool checkTargetShape, ExtraArguments *extraArgs) const { - - if (isS()) - throw std::runtime_error("NDArray::applyTrueBroadcast: you can't use this method on String array!"); - - if(((op.s == scalar::Divide || op.s == scalar::FloorDiv || op.s == scalar::FloorMod) && other.isB()) || (op.s == scalar::ReverseDivide && this->isB())) - throw std::runtime_error("NDArray::applyTrueBroadcast method: you can't divide by bool array !"); - - if (isEmpty() || other.isEmpty()) - return; - - // if (lengthOf() == 1) { - // target.assign(this); - // target.applyPairwiseTransform(op.p, other, extraArgs); - // return; - // } - // if (other.lengthOf() == 1) { - // const_cast(this)->applyScalarArr(op.s, other, target, extraArgs); - // return; - // } - - if(checkTargetShape) { - const Nd4jLong* newShapeInfo = nullptr; - if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, newShapeInfo, getContext()->getWorkspace())) // the rank of target array must be equal to max->rankOf)() - throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !"); - if(!shape::equalsTypesAndShapesSoft(target.shapeInfo(), newShapeInfo)) - throw std::runtime_error("NDArray::applyTrueBroadcast method: the shape or type of target array is wrong !"); - } - - Nd4jLong const* xShapeInfoH = shapeInfo(); - Nd4jLong const* yShapeInfoH = other.shapeInfo(); - Nd4jLong const* xShapeInfoD = specialShapeInfo(); - Nd4jLong const* yShapeInfoD = other.specialShapeInfo(); - - if(!isSameShape(target)) { - auto xPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), shapeInfo(), getContext()->getWorkspace()); - xShapeInfoH = xPack.primary(); - xShapeInfoD = xPack.special(); - } - if(!other.isSameShape(target)) { - auto yPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), other.shapeInfo(), other.getContext()->getWorkspace()); - yShapeInfoH = yPack.primary(); - yShapeInfoD = yPack.special(); - } - - NDArray::prepareSpecialUse({&target}, {this, &other}); - NativeOpExecutioner::execBroadcast(getContext(), op.b, buffer(), xShapeInfoH, specialBuffer(), xShapeInfoD, other.buffer(), yShapeInfoH, other.specialBuffer(), yShapeInfoD, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); - registerSpecialUse({&target}, {this, &other}); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::applyTrueBroadcast(sd::BroadcastBoolOpsTuple op, const NDArray& other, NDArray& target, const bool checkTargetShape, ExtraArguments *extraArgs) const { - - if (isS()) - throw std::runtime_error("NDArray::applyTrueBroadcast bool: you can't use this method on String array!"); - - if (isEmpty() || other.isEmpty()) - return; - - // if (lengthOf() == 1) { - // NDArray temp(target._shapeInfo, dataType(), false, getContext()); - // temp.assign(this); - // temp.applyPairwiseTransform(op.p, other, target, extraArgs); - // return; - // } - // if (other.lengthOf() == 1) { - // this->applyScalarArr(op.s, other, target, extraArgs); - // return; - // } - - if(checkTargetShape) { - const Nd4jLong* newShapeInfo = nullptr; - if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, newShapeInfo, getContext()->getWorkspace())) // the rank of target array must be equal to max->rankOf)() - throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not 
suitable for broadcast operation !"); - if(!shape::equalsSoft(target._shapeInfo, newShapeInfo) || target.dataType() != DataType::BOOL) - throw std::runtime_error("NDArray::applyTrueBroadcast bool method: the shape or type of target array is wrong !"); - if(dataType() != other.dataType()) - throw std::invalid_argument("NDArray::applyTrueBroadcast bool method: this and other arrays must have the same type !"); - } - - Nd4jLong const* xShapeInfoH = shapeInfo(); - Nd4jLong const* yShapeInfoH = other.shapeInfo(); - Nd4jLong const* xShapeInfoD = specialShapeInfo(); - Nd4jLong const* yShapeInfoD = other.specialShapeInfo(); - - if(!isSameShape(target)) { - auto xPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), shapeInfo(), getContext()->getWorkspace()); - xShapeInfoH = xPack.primary(); - xShapeInfoD = xPack.special(); - } - if(!other.isSameShape(target)) { - auto yPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), other.shapeInfo(), other.getContext()->getWorkspace()); - yShapeInfoH = yPack.primary(); - yShapeInfoD = yPack.special(); - } - - NDArray::prepareSpecialUse({&target}, {this, &other}); - NativeOpExecutioner::execBroadcastBool(getContext(), op.b, buffer(), xShapeInfoH, specialBuffer(), xShapeInfoD, other.buffer(), yShapeInfoH, other.specialBuffer(), yShapeInfoD, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr); - registerSpecialUse({&target}, {this, &other}); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::applyTrueBroadcast(sd::BroadcastIntOpsTuple op, const NDArray& other, NDArray& target, const bool checkTargetShape, ExtraArguments *extraArgs) const { - - if (isS()) - throw std::runtime_error("NDArray::applyTrueBroadcast bool: you can't use this method on String array!"); - - if (isEmpty() || other.isEmpty()) - return; - - // if (lengthOf() == 1) { - // NDArray temp(target._shapeInfo, dataType(), false, getContext()); - // temp.assign(this); - // temp.applyPairwiseTransform(op.p, other, target, extraArgs); - // return; - // } - // if (other.lengthOf() == 1) { - // this->applyScalarArr(op.s, other, target, extraArgs); - // return; - // } - - if(checkTargetShape) { - const Nd4jLong* newShapeInfo = nullptr; - if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, false, newShapeInfo, getContext()->getWorkspace())) // the rank of target array must be equal to max->rankOf)() - throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !"); - if(!shape::equalsSoft(target._shapeInfo, newShapeInfo) || target.dataType() != this->dataType()) - throw std::runtime_error("NDArray::applyTrueBroadcast int method: the shape or type of target array is wrong !"); - if(dataType() != other.dataType()) - throw std::invalid_argument("NDArray::applyTrueBroadcast int method: this and other arrays must have the same type !"); - } - - Nd4jLong const* xShapeInfoH = shapeInfo(); - Nd4jLong const* yShapeInfoH = other.shapeInfo(); - Nd4jLong const* xShapeInfoD = specialShapeInfo(); - Nd4jLong const* yShapeInfoD = other.specialShapeInfo(); - - if(!isSameShape(target)) { - auto xPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), shapeInfo(), getContext()->getWorkspace()); - xShapeInfoH = reinterpret_cast(xPack.primary()); - xShapeInfoD = reinterpret_cast(xPack.special()); - } - 
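As a usage sketch of the target-writing applyTrueBroadcast overloads above: a [2,3] array plus a [3] row. The NDArrayFactory::create<T>(order, shape, data) helper and printIndexedBuffer are assumed from their use elsewhere in libnd4j, so treat this as illustrative rather than canonical:

    #include <array/NDArray.h>
    #include <array/NDArrayFactory.h>
    using namespace sd;

    int main() {
        auto x = NDArrayFactory::create<float>('c', {2, 3}, {1, 2, 3, 4, 5, 6});
        auto y = NDArrayFactory::create<float>('c', {3},    {10, 20, 30});
        auto z = NDArrayFactory::create<float>('c', {2, 3});

        // checkTargetShape = true validates z against the evaluated broadcast
        // shape (and its type) before the kernel is launched.
        x.applyTrueBroadcast(BroadcastOpsTuple::Add(), y, z, true);

        z.printIndexedBuffer("x + y");   // 11 22 33  14 25 36
    }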
-//////////////////////////////////////////////////////////////////////////
-void NDArray::applyTrueBroadcast(sd::BroadcastIntOpsTuple op, const NDArray& other, NDArray& target, const bool checkTargetShape, ExtraArguments *extraArgs) const {
-
-    if (isS())
-        throw std::runtime_error("NDArray::applyTrueBroadcast bool: you can't use this method on String array!");
-
-    if (isEmpty() || other.isEmpty())
-        return;
-
-    // if (lengthOf() == 1) {
-    //     NDArray temp(target._shapeInfo, dataType(), false, getContext());
-    //     temp.assign(this);
-    //     temp.applyPairwiseTransform(op.p, other, target, extraArgs);
-    //     return;
-    // }
-    // if (other.lengthOf() == 1) {
-    //     this->applyScalarArr(op.s, other, target, extraArgs);
-    //     return;
-    // }
-
-    if(checkTargetShape) {
-        const Nd4jLong* newShapeInfo = nullptr;
-        if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, false, newShapeInfo, getContext()->getWorkspace()))          // the rank of target array must be equal to max->rankOf)()
-            throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !");
-        if(!shape::equalsSoft(target._shapeInfo, newShapeInfo) || target.dataType() != this->dataType())
-            throw std::runtime_error("NDArray::applyTrueBroadcast int method: the shape or type of target array is wrong !");
-        if(dataType() != other.dataType())
-            throw std::invalid_argument("NDArray::applyTrueBroadcast int method: this and other arrays must have the same type !");
-    }
-
-    Nd4jLong const* xShapeInfoH = shapeInfo();
-    Nd4jLong const* yShapeInfoH = other.shapeInfo();
-    Nd4jLong const* xShapeInfoD = specialShapeInfo();
-    Nd4jLong const* yShapeInfoD = other.specialShapeInfo();
-
-    if(!isSameShape(target)) {
-        auto xPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), shapeInfo(), getContext()->getWorkspace());
-        xShapeInfoH = reinterpret_cast<Nd4jLong const*>(xPack.primary());
-        xShapeInfoD = reinterpret_cast<Nd4jLong const*>(xPack.special());
-    }
-    if(!other.isSameShape(target)) {
-        auto yPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), other.shapeInfo(), other.getContext()->getWorkspace());
-        yShapeInfoH = reinterpret_cast<Nd4jLong const*>(yPack.primary());
-        yShapeInfoD = reinterpret_cast<Nd4jLong const*>(yPack.special());
-    }
-
-    NDArray::prepareSpecialUse({&target}, {this, &other});
-    NativeOpExecutioner::execBroadcastInt(getContext(), op.b, buffer(), xShapeInfoH, specialBuffer(), xShapeInfoD, other.buffer(), yShapeInfoH, other.specialBuffer(), yShapeInfoD, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo());
-    registerSpecialUse({&target}, {this, &other});
-}
-
-//////////////////////////////////////////////////////////////////////////
-NDArray NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple op, const NDArray& other, ExtraArguments *extraArgs) const & {
-    if (isEmpty() || other.isEmpty()) {
-        if (isEmpty())
-            return NDArray(*this);
-        else
-            return NDArray(other);
-    }
-
-    const Nd4jLong* newShapeInfo = nullptr;
-    if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, newShapeInfo, getContext()->getWorkspace()))          // the rank of new array = max->rankOf)()
-        throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !");
-    NDArray result(newShapeInfo, true, getContext());
-
-    this->applyTrueBroadcast(op, other, result, false, extraArgs);
-
-    return result;
-}
-
-//////////////////////////////////////////////////////////////////////////
-NDArray NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple op, NDArray&& other, ExtraArguments *extraArgs) const & {
-    if (isEmpty() || other.isEmpty()) {
-        if (isEmpty())
-            return NDArray(*this);
-        else
-            return NDArray(other);
-    }
-
-    const Nd4jLong* newShapeInfo = nullptr;
-    if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, newShapeInfo, getContext()->getWorkspace()))          // the rank of new array = max->rankOf)()
-        throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !");
-
-    if(!shape::shapeEquals(newShapeInfo, other.shapeInfo())) {
-
-        NDArray result(newShapeInfo, true, getContext());
-        this->applyTrueBroadcast(op, other, result, false, extraArgs);
-        return std::move(result);
-    }
-
-    this->applyTrueBroadcast(op, other, other, false, extraArgs);
-    return std::move(other);
-}
-
-//////////////////////////////////////////////////////////////////////////
-NDArray NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple op, const NDArray& other, ExtraArguments *extraArgs) && {
-    if (isEmpty() || other.isEmpty()) {
-        if (isEmpty())
-            return NDArray(*this);
-        else
-            return NDArray(other);
-    }
-
-    const Nd4jLong* newShapeInfo = nullptr;
-    if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, newShapeInfo, getContext()->getWorkspace()))          // the rank of new array = max->rankOf)()
-        throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !");
-
-    if(!shape::shapeEquals(newShapeInfo, shapeInfo())) {
-
-        NDArray result(newShapeInfo, true, getContext());
-        this->applyTrueBroadcast(op, other, result, false, extraArgs);
-        return std::move(result);
-    }
-
-    this->applyTrueBroadcast(op, other, *this, false, extraArgs);
-    return std::move(*this);
-}
-
-//////////////////////////////////////////////////////////////////////////
-NDArray NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple op, NDArray&& other, ExtraArguments *extraArgs) && {
-    if (isEmpty() || other.isEmpty()) {
-        if (isEmpty())
-            return NDArray(*this);
-        else
-            return NDArray(other);
-    }
-
-    const Nd4jLong* newShapeInfo = nullptr;
-    if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, newShapeInfo, getContext()->getWorkspace()))          // the rank of new array = max->rankOf)()
-        throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !");
-
-    const bool thisMove  = shape::shapeEquals(newShapeInfo, shapeInfo());
-    const bool otherMove = shape::shapeEquals(newShapeInfo, other.shapeInfo());
-
-    if(!thisMove && !otherMove) {
-
-        NDArray result(newShapeInfo, true, getContext());
-        this->applyTrueBroadcast(op, other, result, false, extraArgs);
-        return std::move(result);
-    }
-
-    if(thisMove) {
-        this->applyTrueBroadcast(op, other, *this, false, extraArgs);
-        return std::move(*this);
-    }
-
-    // otherMove
-    this->applyTrueBroadcast(op, other, other, false, extraArgs);
-    return std::move(other);
-}
-//////////////////////////////////////////////////////////////////////////
-void NDArray::applyBroadcast(sd::broadcast::Ops op, const std::vector<int>& dimensions, const NDArray& other, NDArray& target, ExtraArguments* extraArgs) {
-
-    if (dimensions.size() == 0)
-        return;
-
-    if (isS())
-        throw std::runtime_error("NDArray::applyBroadcast: you can't use this method on String array!");
-    if(((op == broadcast::Divide || op == broadcast::FloorDiv || op == broadcast::FloorMod) && other.isB()) || (op == broadcast::ReverseDivide && this->isB()))
-        throw std::runtime_error("NDArray::applyBroadcast: you can't divide by array!");
-    if(isEmpty() || other.isEmpty()) {
-        if(!target.isEmpty())
-            throw std::runtime_error("NDArray::applyBroadcast method: when some of input arrays (or both) is empty, target array must be empty as well !");
-        return;
-    }
-
-    // if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
-    //     NDArray::prepareSpecialUse({&target}, {this, &other});
-    //     NativeOpExecutioner::execPairwiseTransform(getContext(), fromBroadcastToPairwise(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.special(), nullptr);
-    //     NDArray::registerSpecialUse({&target}, {this, &other});
-    //     return;
-    // }
-
-    if(target.dataType() != DataTypeUtils::pickPairwiseResultType(shapeInfo(), other.shapeInfo()))
-        throw std::invalid_argument("NDArray::applyBroadcast method: wrong type of target array !");
-    if(!target.isSameShape(this) && !target.isSameShape(other))
-        throw std::invalid_argument("NDArray::applyBroadcast method: one of of two input arrays (this or other) should has the same shape as target array!");
-
-    std::vector<int> copy(dimensions);
-
-    if (dimensions.size() > 1)
-        std::sort(copy.begin(), copy.end());
-
-    Nd4jLong const* xShapeInfoH = shapeInfo();
-    Nd4jLong const* yShapeInfoH = other.shapeInfo();
-    Nd4jLong const* xShapeInfoD = specialShapeInfo();
-    Nd4jLong const* yShapeInfoD = other.specialShapeInfo();
-
-    if(!isSameShape(target)) {
-        auto xPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), shapeInfo(), getContext()->getWorkspace(), copy);
-        xShapeInfoH = reinterpret_cast<Nd4jLong const*>(xPack.primary());
-        xShapeInfoD = reinterpret_cast<Nd4jLong const*>(xPack.special());
-    }
-    if(!other.isSameShape(target)) {
-        auto yPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), other.shapeInfo(), other.getContext()->getWorkspace(), copy);
-        yShapeInfoH = reinterpret_cast<Nd4jLong const*>(yPack.primary());
-        yShapeInfoD = reinterpret_cast<Nd4jLong const*>(yPack.special());
-    }
-
-    NDArray::prepareSpecialUse({&target}, {this, &other});
-    NativeOpExecutioner::execBroadcast(getContext(), op, buffer(), xShapeInfoH, specialBuffer(), xShapeInfoD, other.buffer(), yShapeInfoH, other.specialBuffer(), yShapeInfoD, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo());
-    registerSpecialUse({&target}, {this, &other});
-}
-
-//////////////////////////////////////////////////////////////////////////
-void NDArray::applyBroadcast(sd::broadcast::BoolOps op, const std::vector<int>& dimensions, const NDArray& other, NDArray& target, ExtraArguments* extraArgs) {
-
-    if (dimensions.size() == 0)
-        return;
-
-    if (isS())
-        throw std::runtime_error("NDArray::applyBroadcast BoolOps: you can't use this method on String array!");
-    if(isEmpty() || other.isEmpty()) {
-        if(!target.isEmpty())
-            throw std::runtime_error("NDArray::applyBroadcast BoolOps: when some of input arrays (or both) is empty, target array must be empty as well !");
-        return;
-    }
-
-    // if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
-    //     NDArray::prepareSpecialUse({&target}, {this, &other});
-    //     NativeOpExecutioner::execPairwiseBoolTransform(getContext(), fromBroadcastToPairwiseBool(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.special(), nullptr);
-    //     NDArray::registerSpecialUse({&target}, {this, &other});
-    //     return;
-    // }
-
-    if(target.dataType() != DataType::BOOL)
-        throw std::invalid_argument("NDArray::applyBroadcast bool method: type of target array must be BOOL!");
-    if(!target.isSameShape(this) && !target.isSameShape(other))
-        throw std::invalid_argument("NDArray::applyBroadcast bool method: one of of two input arrays (this or other) should has the same shape as target array!");
-    if(_dataType != other._dataType)
-        throw std::invalid_argument("NDArray::applyBroadcast bool method: this and other arrays must have the same type !");
-
-    std::vector<int> copy(dimensions);
-
-    if (dimensions.size() > 1)
-        std::sort(copy.begin(), copy.end());
-
-    Nd4jLong const* xShapeInfoH = shapeInfo();
-    Nd4jLong const* yShapeInfoH = other.shapeInfo();
-    Nd4jLong const* xShapeInfoD = specialShapeInfo();
-    Nd4jLong const* yShapeInfoD = other.specialShapeInfo();
-
-    if(!isSameShape(target)) {
-        auto xPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), shapeInfo(), getContext()->getWorkspace(), copy);
-        xShapeInfoH = reinterpret_cast<Nd4jLong const*>(xPack.primary());
-        xShapeInfoD = reinterpret_cast<Nd4jLong const*>(xPack.special());
-    }
-    if(!other.isSameShape(target)) {
-        auto yPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), other.shapeInfo(), other.getContext()->getWorkspace(), copy);
-        yShapeInfoH = reinterpret_cast<Nd4jLong const*>(yPack.primary());
-        yShapeInfoD = reinterpret_cast<Nd4jLong const*>(yPack.special());
-    }
-
-    NDArray::prepareSpecialUse({&target}, {this, &other});
-    NativeOpExecutioner::execBroadcastBool(getContext(), op, buffer(), xShapeInfoH, specialBuffer(), xShapeInfoD, other.buffer(), yShapeInfoH, other.specialBuffer(), yShapeInfoD, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr);
-    registerSpecialUse({&target}, {this, &other});
-}
-
-//////////////////////////////////////////////////////////////////////////
-void NDArray::applyBroadcast(sd::broadcast::IntOps op, const std::vector<int>& dimensions, const NDArray& other, NDArray& target, ExtraArguments* extraArgs) {
-
-    if (dimensions.empty())
-        return;
-
-    if (!isZ())
-        throw std::runtime_error("NDArray::applyBroadcast IntOps: you can't use this method on non-Integer array!");
-    if(isEmpty() || other.isEmpty()) {
-        if(!target.isEmpty())
-            throw std::runtime_error("NDArray::applyBroadcast IntOps: when some of input arrays (or both) is empty, target array must be empty as well !");
-        return;
-    }
-
-    // if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
-    //     NDArray::prepareSpecialUse({&target}, {this, &other});
-    //     NativeOpExecutioner::execPairwiseIntTransform(getContext(), fromBroadcastToPairwiseInt(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.special(), nullptr);
-    //     NDArray::registerSpecialUse({&target}, {this, &other});
-    //     return;
-    // }
-
-    if(target.dataType() != dataType())
-        throw std::invalid_argument("NDArray::applyBroadcast int method: type of target array must be the same as input!");
-    if(!target.isSameShape(this) && !target.isSameShape(other))
-        throw std::invalid_argument("NDArray::applyBroadcast int method: one of of two input arrays (this or other) should has the same shape as target array!");
-    if(_dataType != other._dataType)
-        throw std::invalid_argument("NDArray::applyBroadcast int method: this and other arrays must have the same type !");
-
-    std::vector<int> copy(dimensions);
-
-    if (dimensions.size() > 1)
-        std::sort(copy.begin(), copy.end());
-
-    Nd4jLong const* xShapeInfoH = shapeInfo();
-    Nd4jLong const* yShapeInfoH = other.shapeInfo();
-    Nd4jLong const* xShapeInfoD = specialShapeInfo();
-    Nd4jLong const* yShapeInfoD = other.specialShapeInfo();
-
-    if(!isSameShape(target)) {
-        auto xPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), shapeInfo(), getContext()->getWorkspace(), copy);
-        xShapeInfoH = reinterpret_cast<Nd4jLong const*>(xPack.primary());
-        xShapeInfoD = reinterpret_cast<Nd4jLong const*>(xPack.special());
-    }
-    if(!other.isSameShape(target)) {
-        auto yPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), other.shapeInfo(), other.getContext()->getWorkspace(), copy);
-        yShapeInfoH = reinterpret_cast<Nd4jLong const*>(yPack.primary());
-        yShapeInfoD = reinterpret_cast<Nd4jLong const*>(yPack.special());
-    }
-
-    NDArray::prepareSpecialUse({&target}, {this, &other});
-    NativeOpExecutioner::execBroadcastInt(getContext(), op, buffer(), xShapeInfoH, specialBuffer(), xShapeInfoD, other.buffer(), yShapeInfoH, other.specialBuffer(), yShapeInfoD, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo());
-    registerSpecialUse({&target}, {this, &other});
-}
-
-//////////////////////////////////////////////////////////////////////////
-void NDArray::applyBroadcast(sd::broadcast::Ops op, const std::initializer_list<int> dimensions, const NDArray& tadArray, NDArray& target, ExtraArguments* extraArgs) {
-    std::vector<int> vec(dimensions);
-    applyBroadcast(op, vec, tadArray, target, extraArgs);
-}
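applyBroadcast differs from applyTrueBroadcast in that the caller names the dimensions along which the smaller (TAD) array is applied. A hedged usage sketch under the same factory assumption, adding a per-column bias along dimension 1:

    auto x    = NDArrayFactory::create<float>('c', {2, 3}, {1, 2, 3, 4, 5, 6});
    auto bias = NDArrayFactory::create<float>('c', {3},    {10, 20, 30});
    auto out  = NDArrayFactory::create<float>('c', {2, 3});

    // Each row of x (a TAD along dimension 1) is combined with 'bias'.
    x.applyBroadcast(broadcast::Add, {1}, bias, out);   // 11 22 33  14 25 36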
-////////////////////////////////////////////////////////////////////////
-void* NDArray::operator new(size_t i) {
-    if (sd::memory::MemoryRegistrator::getInstance().hasWorkspaceAttached()) {
-        sd::memory::Workspace* ws = sd::memory::MemoryRegistrator::getInstance().getWorkspace();
-        return ws->allocateBytes((Nd4jLong) i);
-    }
-    else {
-        auto p = malloc(i);
-        CHECK_ALLOC(p, "Failed to allocate new NDArray", i);
-        return p;
-    }
-}
-
-////////////////////////////////////////////////////////////////////////
-void NDArray::operator delete(void* p) {
-    if (!sd::memory::MemoryRegistrator::getInstance().hasWorkspaceAttached())
-        free(p);
-}
-
-////////////////////////////////////////////////////////////////////////
-template <typename T>
-std::vector<T> NDArray::asVectorT() {
-
-    std::vector<T> result(this->lengthOf());
-
-    PRAGMA_OMP_SIMD
-    for (int e = 0; e < this->lengthOf(); e++)
-        result[e] = this->e<T>(e);
-
-    return result;
-}
-BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT std::vector, NDArray::asVectorT(), LIBND4J_TYPES);
-
-//////////////////////////////////////////////////////////////////////////
-// set new order and shape in case of suitable array length
-bool NDArray::reshapei(const char order, const std::vector<Nd4jLong>& cshape, const bool copyToNewBuff) {
-
-    // check firstly whether cshape is identical to shape of array, if yes then reshape is unnecessary
-    if(order == ordering() && shape::shapeEquals(rankOf(), shapeOf(), cshape.size(), cshape.data()))
-        return true;
-
-    const bool isOutShapeEmpty = std::find(cshape.begin(), cshape.end(), 0) != cshape.end();
-
-    if(isEmpty() && !isOutShapeEmpty)
-        throw std::invalid_argument("NDArray::reshapei: can't reshape empty array to non-empty !");
-    if(!isEmpty() && isOutShapeEmpty)
-        throw std::invalid_argument("NDArray::reshapei: can't reshape non-empty array to empty !");
-    if(isEmpty() && isOutShapeEmpty) {
-        Nd4jLong* shapeInfoNew = ShapeBuilders::emptyShapeInfo(dataType(), order, cshape, getContext()->getWorkspace());
-        setShapeInfo(shapeInfoNew);
-        RELEASE(shapeInfoNew, getContext()->getWorkspace());
-        return true;
-    }
-
-    std::vector<Nd4jLong> shape(cshape);
-    int rank = shape.size();
-
-    // looking for negative in shape
-
-    int numberNegativesOnes = 0;
-
-    Nd4jLong* shape_ = shape.data();
-    for (int i = 0; i < (int) shape.size(); i++) {
-        if (shape[i] < 0) {
-            if (numberNegativesOnes >= 1)
-                throw std::runtime_error("NDArray::reshapei: only one dimension can be negative at once");
-
-            numberNegativesOnes++;
-
-            int shapeLength = 1;
-            for (int j = 0; j < (int) shape.size(); j++)
-                if (i != j)
-                    shapeLength *= shape_[j];
-
-            Nd4jLong realShape = sd::math::nd4j_abs<Nd4jLong>(lengthOf() / shapeLength);
-            auto thisNewShape = new Nd4jLong[shape.size()];
-
-            for (int j = 0; j < (int) shape.size(); j++)
-                if (i != j)
-                    thisNewShape[j] = shape_[j];
-                else
-                    thisNewShape[j] = realShape;
-
-            shape_ = thisNewShape;
-        }
-    }
-
-    for (int e = 0; e < (int) shape.size(); e++)
-        shape[e] = shape_[e];
-
-    if (numberNegativesOnes > 0)
-        delete[] shape_;
-
-    Nd4jLong arrLength = 1;
-    for(const auto& item : shape)
-        arrLength *= item;
-
-    if(platformBuffer() == nullptr || arrLength != this->lengthOf()) {
-        this->printShapeInfo("Mismatched shape");
-        sd::Logger::printv("Shape requested: ", shape);
-        nd4j_debug("Requested length in reshape: %i; Existing length: %i;\n", arrLength, this->lengthOf());
-        throw std::runtime_error("NDArray::reshapei: bad input shape!");
-    }
-
-    Nd4jLong *shapeInfoNew;
-    ALLOCATE(shapeInfoNew, getContext()->getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong);
-
-    bool canReshape = shape::reshapeC(shapeInfo(), order, shape.size(), shape.data(), shapeInfoNew);
-
-    if (canReshape) {
-        setShapeInfo(shapeInfoNew);
-    }
-    else {
-        NDArray temp(order, shape, dataType(), getContext());
-        if(copyToNewBuff)
-            this->applyTransform(transform::Assign, temp, nullptr);
-        *this = std::move(temp);
-    }
-
-    RELEASE(shapeInfoNew, getContext()->getWorkspace());
-
-    return canReshape;
-}
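reshapei's negative-dimension handling is the part worth internalizing: at most one entry of the requested shape may be negative, and it is inferred so the dimension product matches the array length. A standalone rendering of just that inference, in plain C++:

    #include <cstdio>
    #include <stdexcept>
    #include <vector>

    // Infer a single -1 dimension, as reshapei does above.
    std::vector<long long> inferShape(std::vector<long long> shape, long long length) {
        int negAt = -1;
        long long known = 1;
        for (size_t i = 0; i < shape.size(); ++i) {
            if (shape[i] < 0) {
                if (negAt >= 0)
                    throw std::runtime_error("only one dimension can be negative at once");
                negAt = static_cast<int>(i);
            } else {
                known *= shape[i];
            }
        }
        if (negAt >= 0)
            shape[negAt] = length / known;
        return shape;
    }

    int main() {
        auto s = inferShape({3, -1}, 12);               // length 12 -> {3, 4}
        std::printf("{%lld, %lld}\n", s[0], s[1]);
    }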
-//////////////////////////////////////////////////////////////////////////
-void NDArray::nullify() {
-    if (isEmpty())
-        return;
-
-    if (isView() || ews() != 1)
-        assign(0);
-    else
-        _buffer->setToZeroBuffers();
-}
-
-////////////////////////////////////////////////////////////////////////
-template <typename T>
-void NDArray::templatedSet(void *buffer, const Nd4jLong xOfsset, sd::DataType dtype, const void *value) {
-    BUILD_SINGLE_PARTIAL_SELECTOR(dtype, templatedSet< , T>(buffer, xOfsset, value), LIBND4J_TYPES);
-}
-BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void NDArray::templatedSet, (void *buffer, const Nd4jLong xOfsset, sd::DataType dtype, const void *value), LIBND4J_TYPES);
-
-////////////////////////////////////////////////////////////////////////
-void NDArray::applyPairwiseTransform(sd::pairwise::Ops op, const NDArray& other, NDArray& target, ExtraArguments *extraParams) const {
-    if (isS())
-        throw std::runtime_error("NDArray::applyPairwiseTransform: you can't use this method on String array!");
-    if (other.lengthOf() != target.lengthOf())
-        throw std::invalid_argument("NDArray::applyPairwiseTransform method - lengths of arrays are mismatched");
-    if (target.dataType() != this->dataType() && target.dataType() != other.dataType())
-        throw std::invalid_argument("NDArray::applyPairwiseTransform method - type of target array must be the same as type of this or other array !");
-
-    NDArray::prepareSpecialUse({&target}, {this, &other});
-    NativeOpExecutioner::execPairwiseTransform(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) : nullptr);
-    NDArray::registerSpecialUse({&target}, {this, &other});
-
-    if (extraParams != nullptr)
-        synchronize("NDArray::applyPairwiseTransform");
-}
-
-////////////////////////////////////////////////////////////////////////
-void NDArray::applyPairwiseTransform(sd::pairwise::BoolOps op, const NDArray& other, NDArray& target, ExtraArguments *extraParams) const {
-    if (isS())
-        throw std::runtime_error("NDArray::applyPairwiseTransform BoolOps: you can't use this method on String array!");
-    if (other.lengthOf() != target.lengthOf())
-        throw std::invalid_argument("NDArray::applyPairwiseTransform BoolOps method - lengths of arrays are mismatched");
-    if (!target.isB())
-        throw std::invalid_argument("NDArray::applyPairwiseTransform BoolOps method - result must have bool type");
-    if (dataType() != other.dataType())
-        throw std::invalid_argument("NDArray::applyPairwiseTransform BoolOps method - this and other arrays must have the same type !");
-
-    NDArray::prepareSpecialUse({&target}, {this, &other});
-    NativeOpExecutioner::execPairwiseBoolTransform(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) : nullptr);
-    NDArray::registerSpecialUse({&target}, {this, &other});
-}
-
-////////////////////////////////////////////////////////////////////////
-void NDArray::applyPairwiseTransform(sd::pairwise::IntOps op, const NDArray& other, NDArray& target, ExtraArguments *extraParams) const {
-    if (isS())
-        throw std::runtime_error("NDArray::applyPairwiseTransform IntOps: you can't use this method on String array!");
-    if (other.lengthOf() != target.lengthOf())
-        throw std::invalid_argument("NDArray::applyPairwiseTransform IntOps method - lengths of arrays are mismatched");
-    if (!target.isZ())
-        throw std::invalid_argument("NDArray::applyPairwiseTransform IntOps method - result must have bool type");
-    if (dataType() != other.dataType())
-        throw std::invalid_argument("NDArray::applyPairwiseTransform IntOps method - this and other arrays must have the same type !");
-
-    NDArray::prepareSpecialUse({&target}, {this, &other});
-    NativeOpExecutioner::execPairwiseIntTransform(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) : nullptr);
-    NDArray::registerSpecialUse({&target}, {this, &other});
-}
-
-//////////////////////////////////////////////////////////////////////////
-void NDArray::applyPairwiseTransform(sd::pairwise::Ops op, const NDArray& other, ExtraArguments *extraParams) {
-    applyPairwiseTransform(op, other, *this, extraParams);
-}
-
-////////////////////////////////////////////////////////////////////////
-template <typename X, typename Y>
-void NDArray::templatedDoubleAssign(void *xBuffer, const Nd4jLong xOffset, const void *yBuffer, const Nd4jLong yOffset) const {
-    auto x = reinterpret_cast<X*>(xBuffer);
-    const auto y = reinterpret_cast<const Y*>(yBuffer);
-    x[xOffset] = static_cast<X>(y[yOffset]);
-}
-BUILD_DOUBLE_TEMPLATE(template ND4J_EXPORT void NDArray::templatedDoubleAssign, (void *xBuffer, const Nd4jLong xOffset, const void *yBuffer, const Nd4jLong yOffset) const, LIBND4J_TYPES, LIBND4J_TYPES);
-
-////////////////////////////////////////////////////////////////////////
-void NDArray::varianceAlongDimension(sd::variance::Ops op, NDArray& target, const bool biasCorrected, const std::vector<int>& dimensions) const {
-
-    if (isS())
-        throw std::runtime_error("NDArray::varianceAlongDimension: you can't use this method on String array!");
-
-    if (!target.isR())
-        throw std::runtime_error("NDArray::varianceAlongDimension: target array must have FLOAT type");
-
-    NDArray::prepareSpecialUse({&target}, {this});
-
-    if(rankOf() == dimensions.size() || dimensions.empty())
-        NativeOpExecutioner::execSummaryStatsScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), biasCorrected);
-    else {
-        std::vector<int> copy(dimensions);
-        auto pDims = sd::Environment::getInstance().isCPU() ? copy.data() : nullptr;
-        auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), dimensions);
-        NativeOpExecutioner::execSummaryStats(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), pDims, dimensions.size(), packX.platformShapeInfo(), packX.platformOffsets(), biasCorrected);
-        synchronize("NDArray::varianceAlongDimension");
-    }
-
-    NDArray::registerSpecialUse({&target}, {this});
-}
-
-////////////////////////////////////////////////////////////////////////
-NDArray NDArray::varianceAlongDimension(sd::variance::Ops op, const bool biasCorrected, const std::vector<int>& dimensions) const {
-    if (isS())
-        throw std::runtime_error("NDArray::varianceAlongDimension: you can't use this method on String array!");
-
-    std::vector<int> copy(dimensions);
-    if (copy.size() > 1)
-        std::sort(copy.begin(), copy.end());
-
-    auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, DataTypeUtils::pickFloatingType(dataType()), false, false, getContext()->getWorkspace());
-    NDArray result(newShape, true, getContext());
-
-    this->varianceAlongDimension(op, result, biasCorrected, dimensions);
-
-    return result;
-}
-
-////////////////////////////////////////////////////////////////////////
-NDArray NDArray::varianceAlongDimension(sd::variance::Ops op, const bool biasCorrected, const std::initializer_list<int>& dimensions) const {
-    return varianceAlongDimension(op, biasCorrected, std::vector<int>(dimensions));
-}
-
-////////////////////////////////////////////////////////////////////////
-void NDArray::varianceAlongDimension(sd::variance::Ops op, NDArray &target, const bool biasCorrected, const std::initializer_list<int>& dimensions) const {
-    varianceAlongDimension(op, target, biasCorrected, std::vector<int>(dimensions));
-}
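The biasCorrected flag threaded through the variance overloads above selects between the population variance (divide by n) and the sample estimator (divide by n - 1). In plain, runnable C++:

    #include <cstdio>
    #include <vector>

    double variance(const std::vector<double>& x, bool biasCorrected) {
        const double n = static_cast<double>(x.size());
        double mean = 0.0;
        for (double v : x) mean += v;
        mean /= n;
        double ss = 0.0;
        for (double v : x) ss += (v - mean) * (v - mean);
        return ss / (biasCorrected ? n - 1.0 : n);      // sample vs. population
    }

    int main() {
        std::vector<double> x{1, 2, 3, 4};
        std::printf("%.4f %.4f\n", variance(x, false), variance(x, true));  // 1.2500 1.6667
    }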
-////////////////////////////////////////////////////////////////////////
-// This method returns new copy of this NDArray, optionally in different order
-NDArray NDArray::dup(const char newOrder) const {
-
-    if (isEmpty())
-        return NDArrayFactory::empty(dataType(), getContext());
-
-    char order = newOrder == 'a' ? ordering() : newOrder;
-
-    // for now string arrays require special treatment
-    if (isS()) {
-        if (dataType() == DataType::UTF8) {
-            std::vector<std::string> strings(lengthOf());
-
-            auto func = PRAGMA_THREADS_FOR{
-                for (auto i = start; i < stop; i++) {
-                    strings[i] = std::move(this->e<std::string>(i));
-                }
-            };
-
-            samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
-
-            return NDArray(getShapeAsVector(), strings, dataType(), getContext());
-        }
-        if (dataType() == DataType::UTF16) {
-            std::vector<std::u16string> strings(lengthOf());
-
-            auto func = PRAGMA_THREADS_FOR{
-                for (auto i = start; i < stop; i++) {
-                    strings[i] = std::move(this->e<std::u16string>(i));
-                }
-            };
-
-            samediff::Threads::parallel_for(func, 0, lengthOf(), 1);
-
-            return NDArray(getShapeAsVector(), strings, dataType(), getContext());
-        }
-
-        std::vector<std::u32string> strings(lengthOf());
-        auto func = PRAGMA_THREADS_FOR{
-            for (auto i = start; i < stop; i++) {
-                strings[i] = std::move(this->e<std::u32string>(i));
+            for (int e = start; e < stop; e++) {
+                auto cdata = outData + offsets[e];
+                auto end = nInputoffsets[e + 1];
+                auto idata = inData + nInputoffsets[e];
+                if (dtype == DataType::UTF16) {
+                    if (dataType() == DataType::UTF8) {
+                        unicode::utf8to16(idata, outData, end);
+                    }
+                    else {
+                        unicode::utf32to16(idata, outData, (end / sizeof(char32_t)));
+                    }
+                }
+                else if (dtype == DataType::UTF32) {
+                    if (dataType() == DataType::UTF8) {
+                        unicode::utf8to32(idata, cdata, end);
+                    }
+                    else {
+                        unicode::utf16to32(idata, outData, (end / sizeof(char16_t)));
+                    }
+                }
+                else {
+                    if (dataType() == DataType::UTF16) {
+                        unicode::utf16to8(idata, outData, (end / sizeof(char16_t)));
+                    }
+                    else {
+                        unicode::utf32to8(idata, outData, (end / sizeof(char32_t)));
+                    }
+                }
             }
         };

         samediff::Threads::parallel_for(func, 0, lengthOf(), 1);

-        return NDArray(getShapeAsVector(), strings, dataType(), getContext());
+        registerPrimaryUse({ &res }, { this });
+
+        return res;
+    }
+    BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT NDArray NDArray::asS, () const, LIBND4J_STRINGTYPES);
+
+////////////////////////////////////////////////////////////////////////
+    NDArray NDArray::asT(DataType dtype) const {
+
+        if (isS() && !DataTypeUtils::isS(dtype))
+            throw std::runtime_error("NDArray::asT: you can't use this method on String array with not string DataType!");
+
+        if (!isS() && DataTypeUtils::isS(dtype))
+            throw std::runtime_error("NDArray::asT: you can't use this method on not String array with string DataType!");
+
+        if (isS()){
+            BUILD_SINGLE_SELECTOR(dtype, return asS, (), LIBND4J_STRINGTYPES);
+        } else {
+            BUILD_SINGLE_SELECTOR(dtype, return asT, (), LIBND4J_TYPES);
+        }
+
+        return NDArray();
+    }
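The new asS/asT plumbing routes string arrays through the unicode:: re-encoders seen above. A hedged sketch of the intended call path; the vector-of-strings constructor is assumed from its use in dup, so treat the exact signature as illustrative:

    // UTF-8 string array re-encoded as UTF-16 storage (hypothetical usage).
    std::vector<std::string> words{"alpha", "beta"};
    NDArray u8(std::vector<Nd4jLong>{2}, words, DataType::UTF8);
    NDArray u16 = u8.asT(DataType::UTF16);   // dispatches to asS<std::u16string>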
-    NDArray result(order, isScalar() ? std::vector<Nd4jLong>({0}) : getShapeAsVector(), dataType(), getContext());
-    result.assign(*this);
+
+////////////////////////////////////////////////////////////////////////
+    NDArray NDArray::cast(DataType dtype) const {
-
-    return result;
-}
+        if (isS() && !DataTypeUtils::isS(dtype))
+            throw std::runtime_error("NDArray::cast: you can't use this method on String array with not string DataType!");
+
+        if (!isS() && DataTypeUtils::isS(dtype))
+            throw std::runtime_error("NDArray::cast: you can't use this method on not String array with string DataType!");
+
+        return this->asT(dtype);
+    }
+
+////////////////////////////////////////////////////////////////////////
+    void NDArray::cast(NDArray& target, DataType dtype) {
+        if (isS())
+            throw std::runtime_error("NDArray::cast: you can't use this method on String array!");
+        // TODO: to be implemented properly
+        target.assign(this);
+    }
+
+////////////////////////////////////////////////////////////////////////
+    void NDArray::operator+=(const NDArray& other) {
+
+        if (isS())
+            throw std::runtime_error("NDArray::operator+=: you can't use this method on String array!");
+        if (!Environment::getInstance().isExperimentalBuild() && this->dataType() != other.dataType() && (this->dataType() != DataType::BOOL || other.dataType() != BOOL))
+            throw sd::datatype_exception::build("NDArray operator+=: Cannot add different types", this->dataType(), other.dataType());
+
+        if (this->lengthOf() != 1 && other.lengthOf() == 1) {
+            NDArray::prepareSpecialUse({this}, {this, &other});
+            NativeOpExecutioner::execScalar(getContext(), sd::scalar::Add, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), nullptr);
+            NDArray::registerSpecialUse({this}, {this, &other});
+        }
+        else if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
+            NDArray::prepareSpecialUse({this}, {this, &other});
+            NativeOpExecutioner::execPairwiseTransform(getContext(), sd::pairwise::Add, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr);
+            NDArray::registerSpecialUse({this}, {this, &other});
+        }
+        else {
+            const Nd4jLong *bShape = nullptr;
+            if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, bShape, getContext()->getWorkspace()))
+                throw std::invalid_argument("NDArray::operator+=: the shapes of this and other arrays are not suitable for broadcast operation !");
+
+            if(shape::equalsTypesAndShapesSoft(shapeInfo(), bShape)) {
+                this->applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), other, *this, false);
+            }
+            else {
+                NDArray result(bShape, true, getContext());
+                this->applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), other, result, false);
+                *this = std::move(result);      // move assignment operator, zero cost copy
+            }
+        }
+    }
+
+////////////////////////////////////////////////////////////////////////
+    void NDArray::operator-=(const NDArray& other) {
+        if (isS())
+            throw std::runtime_error("NDArray::operator-=: you can't use this method on String array!");
+
+        if (!Environment::getInstance().isExperimentalBuild() && this->dataType() != other.dataType() && (this->dataType() != DataType::BOOL || other.dataType() != BOOL))
+            throw sd::datatype_exception::build("NDArray operator-=: Cannot subtract different types", this->dataType(), other.dataType());
+
+        if (lengthOf() != 1 && other.lengthOf() == 1) {
+            NDArray::prepareSpecialUse({this}, {this, &other});
+            NativeOpExecutioner::execScalar(getContext(), sd::scalar::Subtract, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), nullptr);
+            NDArray::registerSpecialUse({this}, {this, &other});
+        }
+        else if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
+            NDArray::prepareSpecialUse({this}, {this, &other});
+            NativeOpExecutioner::execPairwiseTransform(getContext(), sd::pairwise::Subtract, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr);
+            NDArray::registerSpecialUse({this}, {this, &other});
+        }
+        else {
+            const Nd4jLong *bShape = nullptr;
+            if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, bShape, getContext()->getWorkspace()))
+                throw std::invalid_argument("NDArray::operator-=: the shapes of this and other arrays are not suitable for broadcast operation !");
+
+            if(shape::equalsTypesAndShapesSoft(shapeInfo(), bShape)) {
+                this->applyTrueBroadcast(sd::BroadcastOpsTuple::Subtract(), other, *this, false);
+            }
+            else {
+                NDArray result(bShape, true, getContext());
+                this->applyTrueBroadcast(sd::BroadcastOpsTuple::Subtract(), other, result, false);
+                *this = std::move(result);      // move assignment operator, zero cost copy
+            }
+        }
+    }
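Note the three-way dispatch in these compound operators: scalar kernel, same-shape pairwise kernel, then true broadcast; in the last case the "in-place" update may actually reallocate and move-assign. A hedged sketch under the same factory assumption:

    auto a = NDArrayFactory::create<float>('c', {3},    {1, 2, 3});
    auto b = NDArrayFactory::create<float>('c', {2, 3}, {0, 0, 0, 1, 1, 1});

    a += b;   // broadcast shape {2,3} != a's shape, so a is rebuilt via the
              // move-assignment branch and ends up with shape {2,3}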
+
+////////////////////////////////////////////////////////////////////////
+    void NDArray::operator*=(const NDArray& other) {
+        if (isS())
+            throw std::runtime_error("NDArray::operator*=: you can't use this method on String array!");
+        if (!Environment::getInstance().isExperimentalBuild() && this->dataType() != other.dataType() && (this->dataType() != DataType::BOOL || other.dataType() != BOOL))
+            throw sd::datatype_exception::build("NDArray operator*=: Cannot multiply different types", this->dataType(), other.dataType());
+
+        if (lengthOf() != 1 && other.lengthOf() == 1) {
+            NDArray::prepareSpecialUse({this}, {this, &other});
+            NativeOpExecutioner::execScalar(getContext(), sd::scalar::Multiply, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), nullptr);
+            NDArray::registerSpecialUse({this}, {this, &other});
+        }
+        else if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
+            NDArray::prepareSpecialUse({this}, {this, &other});
+            NativeOpExecutioner::execPairwiseTransform(getContext(), sd::pairwise::Multiply, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr);
+            NDArray::registerSpecialUse({this}, {this, &other});
+        }
+        else {
+            const Nd4jLong *bShape = nullptr;
+            if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, bShape, getContext()->getWorkspace()))
+                throw std::invalid_argument("NDArray::operator*=: the shapes of this and other arrays are not suitable for broadcast operation !");
+
+            if(shape::equalsTypesAndShapesSoft(_shapeInfo, bShape)) {
+                this->applyTrueBroadcast(sd::BroadcastOpsTuple::Multiply(), other, *this, false);
+            }
+            else {
+                NDArray result(bShape, true, getContext());
+                this->applyTrueBroadcast(sd::BroadcastOpsTuple::Multiply(), other, result, false);
+                *this = std::move(result);      // move assignment operator, zero cost copy
+            }
+        }
+    }
+
+////////////////////////////////////////////////////////////////////////
+    void NDArray::operator/=(const NDArray& other) {
+        if (isS() || other.isS())
+            throw std::runtime_error("NDArray::operator/=: you can't use this method on String array!");
+        if (other.isB())
+            throw std::runtime_error("NDArray::operator/=: you can't divide by bool array!");
+
+        if (!Environment::getInstance().isExperimentalBuild() && this->dataType() != other.dataType()) {
+            throw sd::datatype_exception::build("NDArray operator/=: Cannot divide different types", this->dataType(), other.dataType());
+        }
+
+        if (lengthOf() != 1 && other.lengthOf() == 1) {
+            NDArray::prepareSpecialUse({this}, {this, &other});
+            NativeOpExecutioner::execScalar(getContext(), sd::scalar::Divide, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), nullptr);
+            NDArray::registerSpecialUse({this}, {this, &other});
+        }
+        else if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
+            NDArray::prepareSpecialUse({this}, {this, &other});
+            NativeOpExecutioner::execPairwiseTransform(getContext(), sd::pairwise::Divide, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr);
+            NDArray::registerSpecialUse({this}, {this, &other});
+        }
+        else {
+            const Nd4jLong *bShape = nullptr;
+            if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, bShape, getContext()->getWorkspace()))
+                throw std::invalid_argument("NDArray::operator/=: the shapes of this and other arrays are not suitable for broadcast operation !");
+
+            if(shape::equalsTypesAndShapesSoft(_shapeInfo, bShape)) {
+                this->applyTrueBroadcast(sd::BroadcastOpsTuple::Divide(), other, *this, false);
+            }
+            else {
+                NDArray result(bShape, true, getContext());
+                this->applyTrueBroadcast(sd::BroadcastOpsTuple::Divide(), other, result, false);
+                *this = std::move(result);      // move assignment operator, zero cost copy
+            }
+        }
+    }
+
+////////////////////////////////////////////////////////////////////////
+    template <typename T>
+    void NDArray::operator+=(const T value) {
+        if (isS())
+            throw std::runtime_error("NDArray::operator+=: you can't use this method on String array!");
+
+        auto other = NDArrayFactory::create(this->dataType(), value, getContext());
+
+        NDArray::prepareSpecialUse({this}, {this, &other});
+        NativeOpExecutioner::execScalar(getContext(), sd::scalar::Add, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), nullptr);
+        NDArray::registerSpecialUse({this}, {this, &other});
+    }
+    template ND4J_EXPORT void NDArray::operator+=(const double value);
+    template ND4J_EXPORT void NDArray::operator+=(const float value);
+    template ND4J_EXPORT void NDArray::operator+=(const float16 value);
+    template ND4J_EXPORT void NDArray::operator+=(const bfloat16 value);
+    template ND4J_EXPORT void NDArray::operator+=(const Nd4jLong value);
+    template ND4J_EXPORT void NDArray::operator+=(const int value);
+    template ND4J_EXPORT void NDArray::operator+=(const bool value);
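In the scalar overloads the wrapped scalar is materialized with the array's own dtype via NDArrayFactory::create(dataType(), value, ctx), so mixed-type calls do not promote the array. A hedged sketch:

    auto f = NDArrayFactory::create<float>('c', {2}, {1.f, 2.f});
    f += 2;       // int overload; 2 becomes a FLOAT32 scalar internally
    f *= 0.5f;    // float overload; f stays FLOAT32 throughout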
+
+////////////////////////////////////////////////////////////////////////
+    template <typename T>
+    void NDArray::operator-=(const T value) {
+        if (isS())
+            throw std::runtime_error("NDArray::operator-=: you can't use this method on String array!");
+
+        auto other = NDArrayFactory::create(dataType(), value, getContext());
+
+        NDArray::prepareSpecialUse({this}, {this, &other});
+        NativeOpExecutioner::execScalar(getContext(), sd::scalar::Subtract, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), nullptr);
+        NDArray::registerSpecialUse({this}, {this, &other});
+    }
+    template ND4J_EXPORT void NDArray::operator-=(const double value);
+    template ND4J_EXPORT void NDArray::operator-=(const float value);
+    template ND4J_EXPORT void NDArray::operator-=(const float16 value);
+    template ND4J_EXPORT void NDArray::operator-=(const bfloat16 value);
+    template ND4J_EXPORT void NDArray::operator-=(const Nd4jLong value);
+    template ND4J_EXPORT void NDArray::operator-=(const int value);
+    template ND4J_EXPORT void NDArray::operator-=(const bool value);
+
+////////////////////////////////////////////////////////////////////////
+    template <typename T>
+    void NDArray::operator*=(const T scalar) {
+        if (isS())
+            throw std::runtime_error("NDArray::operator*=: you can't use this method on String array!");
+
+        auto other = NDArrayFactory::create(this->dataType(), scalar, getContext());
+        NDArray::prepareSpecialUse({this}, {this, &other});
+        NativeOpExecutioner::execScalar(getContext(), sd::scalar::Multiply, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), nullptr);
+        NDArray::registerSpecialUse({this}, {this, &other});
+    }
+    template ND4J_EXPORT void NDArray::operator*=(const double scalar);
+    template ND4J_EXPORT void NDArray::operator*=(const float scalar);
+    template ND4J_EXPORT void NDArray::operator*=(const float16 scalar);
+    template ND4J_EXPORT void NDArray::operator*=(const bfloat16 scalar);
+    template ND4J_EXPORT void NDArray::operator*=(const Nd4jLong scalar);
+    template ND4J_EXPORT void NDArray::operator*=(const int scalar);
+    template ND4J_EXPORT void NDArray::operator*=(const int16_t scalar);
+    template ND4J_EXPORT void NDArray::operator*=(const int8_t scalar);
+    template ND4J_EXPORT void NDArray::operator*=(const uint8_t scalar);
+    template ND4J_EXPORT void NDArray::operator*=(const bool scalar);
+
+////////////////////////////////////////////////////////////////////////
+    template <typename T>
+    void NDArray::operator/=(const T scalar) {
+        if (isS())
+            throw std::runtime_error("NDArray::operator/=: you can't use this method on String array!");
+
+        auto other = NDArrayFactory::create(this->dataType(), scalar, getContext());
+        NDArray::prepareSpecialUse({this}, {this, &other});
+        NativeOpExecutioner::execScalar(getContext(), sd::scalar::Divide, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), nullptr);
+        NDArray::registerSpecialUse({this}, {this, &other});
+    }
+    template ND4J_EXPORT void NDArray::operator/=(const double scalar);
+    template ND4J_EXPORT void NDArray::operator/=(const float scalar);
+    template ND4J_EXPORT void NDArray::operator/=(const float16 scalar);
+    template ND4J_EXPORT void NDArray::operator/=(const bfloat16 scalar);
+    template ND4J_EXPORT void NDArray::operator/=(const Nd4jLong scalar);
+    template ND4J_EXPORT void NDArray::operator/=(const int scalar);
+    template ND4J_EXPORT void NDArray::operator/=(const int16_t scalar);
+    template ND4J_EXPORT void NDArray::operator/=(const int8_t scalar);
+    template ND4J_EXPORT void NDArray::operator/=(const uint8_t scalar);
+    template ND4J_EXPORT void NDArray::operator/=(const bool scalar);
+
+////////////////////////////////////////////////////////////////////////
+// negative operator, it makes all array elements = -elements
+    NDArray NDArray::operator-() const & {
+        if (isS())
+            throw std::runtime_error("NDArray::negative-: you can't use this method on String array!");
+
+        NDArray result(shapeInfo(), false, getContext());
+
+        NDArray::prepareSpecialUse({&result}, {this});
+        NativeOpExecutioner::execTransformSame(getContext(), sd::transform::Neg, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), nullptr, nullptr, nullptr);
+        NDArray::registerSpecialUse({&result}, {this});
+
+        return result;
+    }
+
+////////////////////////////////////////////////////////////////////////
+    NDArray NDArray::operator-() && {
+        if (isS())
+            throw std::runtime_error("NDArray::negative-: you can't use this method on String array!");
+
+        NDArray::prepareSpecialUse({this}, {this});
+        NativeOpExecutioner::execTransformSame(getContext(), sd::transform::Neg, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, nullptr, nullptr);
+        NDArray::registerSpecialUse({this}, {this});
+
+        return std::move(*this);
+    }
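The paired negation overloads trade allocation for buffer reuse, mirroring the applyTrueBroadcast rvalue overloads earlier in this diff. Sketch, same factory assumption:

    auto x = NDArrayFactory::create<float>('c', {3}, {1, -2, 3});
    auto y = -x;              // const & overload: x untouched, new buffer for y
    auto z = -std::move(x);   // && overload: negates x's buffer in place, moves it out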
+
+////////////////////////////////////////////////////////////////////////
+// mathematical multiplication of two arrays
+    NDArray mmul(const NDArray& left, const NDArray& right) {
+        if (left.isS() || right.isS())
+            throw std::runtime_error("mmul friend function: you can't use this function on String array!");
+        auto ptr = MmulHelper::mmul(const_cast<NDArray*>(&left), const_cast<NDArray*>(&right), nullptr, 1., 0.);
+        NDArray result(std::move(*ptr));
+        delete ptr;
+        return result;
+    }
+
+////////////////////////////////////////////////////////////////////////
+    void NDArray::tileToShape(const std::vector<Nd4jLong>& shape, NDArray& target) {
+        if(&target != this) {
+            this->tile(target);
+            return;
+        }
+
+        std::vector<Nd4jLong> thisShape(rankOf());
+        for(int i = 0; i < rankOf(); ++i)
+            thisShape[i] = sizeAt(i);
+
+        if(!ShapeUtils::areShapesBroadcastable(shape, thisShape))
+            throw std::runtime_error("NDArray::tileToShape method: the shape of this array and input shape are not suitable for broadcast operation !");
+
+        const int newRank = shape.size();
+        std::vector<Nd4jLong> repeats(newRank);
+
+        for(int i = 1; i <= newRank; ++i) {
+            if(i > rankOf())
+                repeats[newRank-i] = shape[newRank - i];
+            else
+                repeats[newRank-i] = shape[newRank - i] / thisShape[rankOf() - i];
+        }
+
+        tilei(repeats);
+    }
+
+////////////////////////////////////////////////////////////////////////
+    void NDArray::tileToShape(const std::initializer_list<Nd4jLong>& shape, NDArray& target) {
+        tileToShape(std::vector<Nd4jLong>(shape), target);
+    }
+
+////////////////////////////////////////////////////////////////////////
+    NDArray NDArray::tileToShape(const Nd4jLong* shapeInfo) {
+
+        NDArray result(const_cast<Nd4jLong*>(shapeInfo), false, getContext());
+        tile(result);
+        return result;
+    }
+
+////////////////////////////////////////////////////////////////////////
+    double NDArray::getTrace() const {
+        if (isS())
+            throw std::runtime_error("NDArray::getTrace: you can't use this method on String array!");
+
+        int rank = rankOf();
+        auto shape = shapeOf();
+        int minDim = 100000000;
+
+        Nd4jLong indices[MAX_RANK];
+        for(int j = 0; j < rank; ++j)
+            indices[j] = 1;
+
+        auto offset = shape::getOffset(shapeInfo(), indices);
+
+        for(int i = 0; i < rank; ++i)
+            if(minDim > shape[i])
+                minDim = shape[i];
+
+        double sum = 0.;
+
+        for(int i = 0; i < minDim; ++i)
+            sum += e<double>(i * offset);
+
+        return sum;
+    }
+
+////////////////////////////////////////////////////////////////////////
+    NDArray NDArray::quantize(const NDArray& array) {
+
+        if(!array.isR())
+            throw std::invalid_argument("NDArray::quantize: type of array should be from real space!");
+
+        auto ws = array.getContext()->getWorkspace();
+
+        Nd4jLong* shapeInfo = ShapeBuilders::copyShapeInfo(array.shapeInfo(), true, ws);
+        ArrayOptions::setPropertyBit(shapeInfo, ARRAY_QUANTIZED);
+
+        std::shared_ptr<DataBuffer> buffer = std::make_shared<DataBuffer>(TypeCast::estimateQuantizedSize(array.lengthOf()), ArrayOptions::dataType(shapeInfo), ws);
+
+        NDArray result(buffer, ShapeDescriptor(shapeInfo), array.getContext());
+
+        return result;
+    }
+
+//////////////////////////////////////////////////////////////////////////
+    void NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple op, const NDArray& other, NDArray& target, const bool checkTargetShape, ExtraArguments *extraArgs) const {
+
+        if (isS())
+            throw std::runtime_error("NDArray::applyTrueBroadcast: you can't use this method on String array!");
+
+        if(((op.s == scalar::Divide || op.s == scalar::FloorDiv || op.s == scalar::FloorMod) && other.isB()) || (op.s == scalar::ReverseDivide && this->isB()))
+            throw std::runtime_error("NDArray::applyTrueBroadcast method: you can't divide by bool array !");
+
+        if (isEmpty() || other.isEmpty())
+            return;
+
+        // if (lengthOf() == 1) {
+        //     target.assign(this);
+        //     target.applyPairwiseTransform(op.p, other, extraArgs);
+        //     return;
+        // }
+        // if (other.lengthOf() == 1) {
+        //     const_cast<NDArray*>(this)->applyScalarArr(op.s, other, target, extraArgs);
+        //     return;
+        // }
+
+        if(checkTargetShape) {
+            const Nd4jLong* newShapeInfo = nullptr;
+            if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, newShapeInfo, getContext()->getWorkspace()))          // the rank of target array must be equal to max->rankOf)()
+                throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !");
+            if(!shape::equalsTypesAndShapesSoft(target.shapeInfo(), newShapeInfo))
+                throw std::runtime_error("NDArray::applyTrueBroadcast method: the shape or type of target array is wrong !");
+        }
+
+        Nd4jLong const* xShapeInfoH = shapeInfo();
+        Nd4jLong const* yShapeInfoH = other.shapeInfo();
+        Nd4jLong const* xShapeInfoD = specialShapeInfo();
+        Nd4jLong const* yShapeInfoD = other.specialShapeInfo();
+
+        if(!isSameShape(target)) {
+            auto xPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), shapeInfo(), getContext()->getWorkspace());
+            xShapeInfoH = xPack.primary();
+            xShapeInfoD = xPack.special();
+        }
+        if(!other.isSameShape(target)) {
+            auto yPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), other.shapeInfo(), other.getContext()->getWorkspace());
+            yShapeInfoH = yPack.primary();
+            yShapeInfoD = yPack.special();
+        }
+
+        NDArray::prepareSpecialUse({&target}, {this, &other});
+        NativeOpExecutioner::execBroadcast(getContext(), op.b, buffer(), xShapeInfoH, specialBuffer(), xShapeInfoD, other.buffer(), yShapeInfoH, other.specialBuffer(), yShapeInfoD, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo());
+        registerSpecialUse({&target}, {this, &other});
+    }
this->applyScalarArr(op.s, other, target, extraArgs); + // return; + // } + + if(checkTargetShape) { + const Nd4jLong* newShapeInfo = nullptr; + if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, false, newShapeInfo, getContext()->getWorkspace())) // the rank of target array must be equal to max->rankOf)() + throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !"); + if(!shape::equalsSoft(target._shapeInfo, newShapeInfo) || target.dataType() != this->dataType()) + throw std::runtime_error("NDArray::applyTrueBroadcast int method: the shape or type of target array is wrong !"); + if(dataType() != other.dataType()) + throw std::invalid_argument("NDArray::applyTrueBroadcast int method: this and other arrays must have the same type !"); + } + + Nd4jLong const* xShapeInfoH = shapeInfo(); + Nd4jLong const* yShapeInfoH = other.shapeInfo(); + Nd4jLong const* xShapeInfoD = specialShapeInfo(); + Nd4jLong const* yShapeInfoD = other.specialShapeInfo(); + + if(!isSameShape(target)) { + auto xPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), shapeInfo(), getContext()->getWorkspace()); + xShapeInfoH = reinterpret_cast(xPack.primary()); + xShapeInfoD = reinterpret_cast(xPack.special()); + } + if(!other.isSameShape(target)) { + auto yPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), other.shapeInfo(), other.getContext()->getWorkspace()); + yShapeInfoH = reinterpret_cast(yPack.primary()); + yShapeInfoD = reinterpret_cast(yPack.special()); + } + + NDArray::prepareSpecialUse({&target}, {this, &other}); + NativeOpExecutioner::execBroadcastInt(getContext(), op.b, buffer(), xShapeInfoH, specialBuffer(), xShapeInfoD, other.buffer(), yShapeInfoH, other.specialBuffer(), yShapeInfoD, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); + registerSpecialUse({&target}, {this, &other}); + } + +////////////////////////////////////////////////////////////////////////// + NDArray NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple op, const NDArray& other, ExtraArguments *extraArgs) const & { + if (isEmpty() || other.isEmpty()) { + if (isEmpty()) + return NDArray(*this); + else + return NDArray(other); + } + + const Nd4jLong* newShapeInfo = nullptr; + if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, newShapeInfo, getContext()->getWorkspace())) // the rank of new array = max->rankOf)() + throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !"); + NDArray result(newShapeInfo, true, getContext()); + + this->applyTrueBroadcast(op, other, result, false, extraArgs); + + return result; + } + +////////////////////////////////////////////////////////////////////////// + NDArray NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple op, NDArray&& other, ExtraArguments *extraArgs) const & { + if (isEmpty() || other.isEmpty()) { + if (isEmpty()) + return NDArray(*this); + else + return NDArray(other); + } + + const Nd4jLong* newShapeInfo = nullptr; + if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, newShapeInfo, getContext()->getWorkspace())) // the rank of new array = max->rankOf)() + throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !"); + + if(!shape::shapeEquals(newShapeInfo, other.shapeInfo())) { + + NDArray 
result(newShapeInfo, true, getContext());
+            this->applyTrueBroadcast(op, other, result, false, extraArgs);
+            return std::move(result);
+        }
+
+        this->applyTrueBroadcast(op, other, other, false, extraArgs);
+        return std::move(other);
+    }
+
+//////////////////////////////////////////////////////////////////////////
+    NDArray NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple op, const NDArray& other, ExtraArguments *extraArgs) && {
+        if (isEmpty() || other.isEmpty()) {
+            if (isEmpty())
+                return NDArray(*this);
+            else
+                return NDArray(other);
+        }
+
+        const Nd4jLong* newShapeInfo = nullptr;
+        if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, newShapeInfo, getContext()->getWorkspace()))          // the rank of new array = max->rankOf()
+            throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !");
+
+        if(!shape::shapeEquals(newShapeInfo, shapeInfo())) {
+
+            NDArray result(newShapeInfo, true, getContext());
+            this->applyTrueBroadcast(op, other, result, false, extraArgs);
+            return std::move(result);
+        }
+
+        this->applyTrueBroadcast(op, other, *this, false, extraArgs);
+        return std::move(*this);
+    }
+
+//////////////////////////////////////////////////////////////////////////
+    NDArray NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple op, NDArray&& other, ExtraArguments *extraArgs) && {
+        if (isEmpty() || other.isEmpty()) {
+            if (isEmpty())
+                return NDArray(*this);
+            else
+                return NDArray(other);
+        }
+
+        const Nd4jLong* newShapeInfo = nullptr;
+        if(!ShapeUtils::evalBroadcastShapeInfo(*this, other, true, newShapeInfo, getContext()->getWorkspace()))          // the rank of new array = max->rankOf()
+            throw std::runtime_error("NDArray::applyTrueBroadcast method: the shapes of this and other arrays are not suitable for broadcast operation !");
+
+        const bool thisMove  = shape::shapeEquals(newShapeInfo, shapeInfo());
+        const bool otherMove = shape::shapeEquals(newShapeInfo, other.shapeInfo());
+
+        if(!thisMove && !otherMove) {
+
+            NDArray result(newShapeInfo, true, getContext());
+            this->applyTrueBroadcast(op, other, result, false, extraArgs);
+            return std::move(result);
+        }
+
+        if(thisMove) {
+            this->applyTrueBroadcast(op, other, *this, false, extraArgs);
+            return std::move(*this);
+        }
+
+        // otherMove
+        this->applyTrueBroadcast(op, other, other, false, extraArgs);
+        return std::move(other);
+    }
+
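The value-returning overloads above pick whichever operand (or a fresh array) can hold the broadcast result. A minimal call-site sketch, illustrative only; it assumes the sd::BroadcastOpsTuple::Add() factory that exists elsewhere in the codebase and the NDArrayFactory helper from the earlier sketch:

    void trueBroadcastSketch() {
        auto x   = sd::NDArrayFactory::create<float>('c', {2, 3}, {1, 2, 3, 4, 5, 6});
        auto row = sd::NDArrayFactory::create<float>('c', {1, 3}, {10, 20, 30});
        // both operands are broadcast to the common shape {2, 3}
        auto z = x.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), row, nullptr);
    }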
+//////////////////////////////////////////////////////////////////////////
+    void NDArray::applyBroadcast(sd::broadcast::Ops op, const std::vector<int>& dimensions, const NDArray& other, NDArray& target, ExtraArguments* extraArgs) {
+
+        if (dimensions.size() == 0)
+            return;
+
+        if (isS())
+            throw std::runtime_error("NDArray::applyBroadcast: you can't use this method on String array!");
+        if(((op == broadcast::Divide || op == broadcast::FloorDiv || op == broadcast::FloorMod) && other.isB()) || (op == broadcast::ReverseDivide && this->isB()))
+            throw std::runtime_error("NDArray::applyBroadcast: you can't divide by bool array!");
+        if(isEmpty() || other.isEmpty()) {
+            if(!target.isEmpty())
+                throw std::runtime_error("NDArray::applyBroadcast method: when one of the input arrays (or both) is empty, target array must be empty as well !");
+            return;
+        }
+
+        // if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
+        //     NDArray::prepareSpecialUse({&target}, {this, &other});
+        //     NativeOpExecutioner::execPairwiseTransform(getContext(), fromBroadcastToPairwise(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.special(), nullptr);
+        //     NDArray::registerSpecialUse({&target}, {this, &other});
+        //     return;
+        // }
+
+        if(target.dataType() != DataTypeUtils::pickPairwiseResultType(shapeInfo(), other.shapeInfo()))
+            throw std::invalid_argument("NDArray::applyBroadcast method: wrong type of target array !");
+        if(!target.isSameShape(this) && !target.isSameShape(other))
+            throw std::invalid_argument("NDArray::applyBroadcast method: one of the two input arrays (this or other) should have the same shape as target array!");
+
+        std::vector<int> copy(dimensions);
+
+        if (dimensions.size() > 1)
+            std::sort(copy.begin(), copy.end());
+
+        Nd4jLong const* xShapeInfoH = shapeInfo();
+        Nd4jLong const* yShapeInfoH = other.shapeInfo();
+        Nd4jLong const* xShapeInfoD = specialShapeInfo();
+        Nd4jLong const* yShapeInfoD = other.specialShapeInfo();
+
+        if(!isSameShape(target)) {
+            auto xPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), shapeInfo(), getContext()->getWorkspace(), copy);
+            xShapeInfoH = reinterpret_cast<Nd4jLong const*>(xPack.primary());
+            xShapeInfoD = reinterpret_cast<Nd4jLong const*>(xPack.special());
+        }
+        if(!other.isSameShape(target)) {
+            auto yPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), other.shapeInfo(), other.getContext()->getWorkspace(), copy);
+            yShapeInfoH = reinterpret_cast<Nd4jLong const*>(yPack.primary());
+            yShapeInfoD = reinterpret_cast<Nd4jLong const*>(yPack.special());
+        }
+
+        NDArray::prepareSpecialUse({&target}, {this, &other});
+        NativeOpExecutioner::execBroadcast(getContext(), op, buffer(), xShapeInfoH, specialBuffer(), xShapeInfoD, other.buffer(), yShapeInfoH, other.specialBuffer(), yShapeInfoD, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo());
+        registerSpecialUse({&target}, {this, &other});
+    }
+
+//////////////////////////////////////////////////////////////////////////
+    void NDArray::applyBroadcast(sd::broadcast::BoolOps op, const std::vector<int>& dimensions, const NDArray& other, NDArray& target, ExtraArguments* extraArgs) {
+
+        if (dimensions.size() == 0)
+            return;
+
+        if (isS())
+            throw std::runtime_error("NDArray::applyBroadcast BoolOps: you can't use this method on String array!");
+        if(isEmpty() || other.isEmpty()) {
+            if(!target.isEmpty())
+                throw std::runtime_error("NDArray::applyBroadcast BoolOps: when one of the input arrays (or both) is empty, target array must be empty as well !");
+            return;
+        }
+
+        // if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
+        //     NDArray::prepareSpecialUse({&target}, {this, &other});
+        //     NativeOpExecutioner::execPairwiseBoolTransform(getContext(), fromBroadcastToPairwiseBool(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.special(), nullptr);
+        //     NDArray::registerSpecialUse({&target}, {this, &other});
+        //     return;
+        // }
+
+        if(target.dataType() != DataType::BOOL)
+            throw std::invalid_argument("NDArray::applyBroadcast bool method: type of target array must be BOOL!");
+        if(!target.isSameShape(this) && !target.isSameShape(other))
+            throw std::invalid_argument("NDArray::applyBroadcast bool method: one of the two input arrays (this or other) should have the same shape as target array!");
+        if(_dataType != other._dataType)
+            throw std::invalid_argument("NDArray::applyBroadcast bool method: this and other arrays must have the same type !");
+
+        std::vector<int> copy(dimensions);
+
+        if (dimensions.size() > 1)
+            std::sort(copy.begin(), copy.end());
+
+        Nd4jLong const* xShapeInfoH = shapeInfo();
+        Nd4jLong const* yShapeInfoH = other.shapeInfo();
+        Nd4jLong const* xShapeInfoD = specialShapeInfo();
+        Nd4jLong const* yShapeInfoD = other.specialShapeInfo();
+
+        if(!isSameShape(target)) {
+            auto xPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), shapeInfo(), getContext()->getWorkspace(), copy);
+            xShapeInfoH = reinterpret_cast<Nd4jLong const*>(xPack.primary());
+            xShapeInfoD = reinterpret_cast<Nd4jLong const*>(xPack.special());
+        }
+        if(!other.isSameShape(target)) {
+            auto yPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), other.shapeInfo(), other.getContext()->getWorkspace(), copy);
+            yShapeInfoH = reinterpret_cast<Nd4jLong const*>(yPack.primary());
+            yShapeInfoD = reinterpret_cast<Nd4jLong const*>(yPack.special());
+        }
+
+        NDArray::prepareSpecialUse({&target}, {this, &other});
+        NativeOpExecutioner::execBroadcastBool(getContext(), op, buffer(), xShapeInfoH, specialBuffer(), xShapeInfoD, other.buffer(), yShapeInfoH, other.specialBuffer(), yShapeInfoD, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr);
+        registerSpecialUse({&target}, {this, &other});
+    }
+
+
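The TAD-based applyBroadcast overloads broadcast the second operand along the given dimensions rather than over full shapes. A short, illustrative sketch of a typical bias-add along dimension 1, using the initializer_list overload declared further below and the hypothetical factory calls from the earlier sketches:

    void applyBroadcastSketch() {
        auto m    = sd::NDArrayFactory::create<float>('c', {2, 3}, {1, 2, 3, 4, 5, 6});
        auto bias = sd::NDArrayFactory::create<float>('c', {3}, {10, 20, 30});
        auto out  = sd::NDArrayFactory::create<float>('c', {2, 3});
        // adds bias to every row of m; out must match the shape of m
        m.applyBroadcast(sd::broadcast::Add, {1}, bias, out, nullptr);
    }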
+//////////////////////////////////////////////////////////////////////////
+    void NDArray::applyBroadcast(sd::broadcast::IntOps op, const std::vector<int>& dimensions, const NDArray& other, NDArray& target, ExtraArguments* extraArgs) {
+
+        if (dimensions.empty())
+            return;
+
+        if (!isZ())
+            throw std::runtime_error("NDArray::applyBroadcast IntOps: you can't use this method on non-Integer array!");
+        if(isEmpty() || other.isEmpty()) {
+            if(!target.isEmpty())
+                throw std::runtime_error("NDArray::applyBroadcast IntOps: when one of the input arrays (or both) is empty, target array must be empty as well !");
+            return;
+        }
+
+        // if (other.lengthOf() == lengthOf() && this->rankOf() == other.rankOf()) {
+        //     NDArray::prepareSpecialUse({&target}, {this, &other});
+        //     NativeOpExecutioner::execPairwiseIntTransform(getContext(), fromBroadcastToPairwiseInt(op), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.special(), nullptr);
+        //     NDArray::registerSpecialUse({&target}, {this, &other});
+        //     return;
+        // }
+
+        if(target.dataType() != dataType())
+            throw std::invalid_argument("NDArray::applyBroadcast int method: type of target array must be the same as input!");
+        if(!target.isSameShape(this) && !target.isSameShape(other))
+            throw std::invalid_argument("NDArray::applyBroadcast int method: one of the two input arrays (this or other) should have the same shape as target array!");
+        if(_dataType != other._dataType)
+            throw std::invalid_argument("NDArray::applyBroadcast int method: this and other arrays must have the same type !");
+
+        std::vector<int> copy(dimensions);
+
+        if (dimensions.size() > 1)
+            std::sort(copy.begin(), copy.end());
+
+        Nd4jLong const* xShapeInfoH = shapeInfo();
+        Nd4jLong const* yShapeInfoH = other.shapeInfo();
+        Nd4jLong const* xShapeInfoD = specialShapeInfo();
+        Nd4jLong const* yShapeInfoD = other.specialShapeInfo();
+
+        if(!isSameShape(target)) {
+            auto xPack =
ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), shapeInfo(), getContext()->getWorkspace(), copy); + xShapeInfoH = reinterpret_cast(xPack.primary()); + xShapeInfoD = reinterpret_cast(xPack.special()); + } + if(!other.isSameShape(target)) { + auto yPack = ConstantShapeHelper::getInstance().createShapeInfoWithUnitiesForBroadcast(target.shapeInfo(), other.shapeInfo(), other.getContext()->getWorkspace(), copy); + yShapeInfoH = reinterpret_cast(yPack.primary()); + yShapeInfoD = reinterpret_cast(yPack.special()); + } + + NDArray::prepareSpecialUse({&target}, {this, &other}); + NativeOpExecutioner::execBroadcastInt(getContext(), op, buffer(), xShapeInfoH, specialBuffer(), xShapeInfoD, other.buffer(), yShapeInfoH, other.specialBuffer(), yShapeInfoD, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); + registerSpecialUse({&target}, {this, &other}); + } + +////////////////////////////////////////////////////////////////////////// + void NDArray::applyBroadcast(sd::broadcast::Ops op, const std::initializer_list dimensions, const NDArray& tadArray, NDArray& target, ExtraArguments* extraArgs) { + std::vector vec(dimensions); + applyBroadcast(op, vec, tadArray, target, extraArgs); + } + +//////////////////////////////////////////////////////////////////////// + void* NDArray::operator new(size_t i) { + if (sd::memory::MemoryRegistrator::getInstance().hasWorkspaceAttached()) { + sd::memory::Workspace* ws = sd::memory::MemoryRegistrator::getInstance().getWorkspace(); + return ws->allocateBytes((Nd4jLong) i); + } + else { + auto p = malloc(i); + CHECK_ALLOC(p, "Failed to allocate new NDArray", i); + return p; + } + } + +//////////////////////////////////////////////////////////////////////// + void NDArray::operator delete(void* p) { + if (!sd::memory::MemoryRegistrator::getInstance().hasWorkspaceAttached()) + free(p); + } + +//////////////////////////////////////////////////////////////////////// + template + std::vector NDArray::asVectorT() { + + std::vector result(this->lengthOf()); + + PRAGMA_OMP_SIMD + for (int e = 0; e < this->lengthOf(); e++) + result[e] = this->e(e); + + return result; + } + BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT std::vector, NDArray::asVectorT(), LIBND4J_TYPES); + +////////////////////////////////////////////////////////////////////////// +// set new order and shape in case of suitable array length + bool NDArray::reshapei(const char order, const std::vector& cshape, const bool copyToNewBuff) { + + // check firstly whether cshape is identical to shape of array, if yes then reshape is unnecessary + if(order == ordering() && shape::shapeEquals(rankOf(), shapeOf(), cshape.size(), cshape.data())) + return true; + + const bool isOutShapeEmpty = std::find(cshape.begin(), cshape.end(), 0) != cshape.end(); + + if(isEmpty() && !isOutShapeEmpty) + throw std::invalid_argument("NDArray::reshapei: can't reshape empty array to non-empty !"); + if(!isEmpty() && isOutShapeEmpty) + throw std::invalid_argument("NDArray::reshapei: can't reshape non-empty array to empty !"); + if(isEmpty() && isOutShapeEmpty) { + Nd4jLong* shapeInfoNew = ShapeBuilders::emptyShapeInfo(dataType(), order, cshape, getContext()->getWorkspace()); + setShapeInfo(shapeInfoNew); + RELEASE(shapeInfoNew, getContext()->getWorkspace()); + return true; + } + + std::vector shape(cshape); + int rank = shape.size(); + + // looking for negative in shape + + int numberNegativesOnes = 0; + + Nd4jLong* shape_ = shape.data(); + for (int i = 0; i < (int) 
shape.size(); i++) {
+            if (shape[i] < 0) {
+                if (numberNegativesOnes >= 1)
+                    throw std::runtime_error("NDArray::reshapei: only one dimension can be negative at once");
+
+                numberNegativesOnes++;
+
+                int shapeLength = 1;
+                for (int j = 0; j < (int) shape.size(); j++)
+                    if (i != j)
+                        shapeLength *= shape_[j];
+
+                Nd4jLong realShape = sd::math::nd4j_abs(lengthOf() / shapeLength);
+                auto thisNewShape = new Nd4jLong[shape.size()];
+
+                for (int j = 0; j < (int) shape.size(); j++)
+                    if (i != j)
+                        thisNewShape[j] = shape_[j];
+                    else
+                        thisNewShape[j] = realShape;
+
+                shape_ = thisNewShape;
+            }
+        }
+
+        for (int e = 0; e < (int) shape.size(); e++)
+            shape[e] = shape_[e];
+
+        if (numberNegativesOnes > 0)
+            delete[] shape_;
+
+        Nd4jLong arrLength = 1;
+        for(const auto& item : shape)
+            arrLength *= item;
+
+        if(platformBuffer() == nullptr || arrLength != this->lengthOf()) {
+            this->printShapeInfo("Mismatched shape");
+            sd::Logger::printv("Shape requested: ", shape);
+            nd4j_debug("Requested length in reshape: %lld; Existing length: %lld;\n", (long long) arrLength, (long long) this->lengthOf());
+            throw std::runtime_error("NDArray::reshapei: bad input shape!");
+        }
+
+        Nd4jLong *shapeInfoNew;
+        ALLOCATE(shapeInfoNew, getContext()->getWorkspace(), shape::shapeInfoLength(rank), Nd4jLong);
+
+        bool canReshape = shape::reshapeC(shapeInfo(), order, shape.size(), shape.data(), shapeInfoNew);
+
+        if (canReshape) {
+            setShapeInfo(shapeInfoNew);
+        }
+        else {
+            NDArray temp(order, shape, dataType(), getContext());
+            if(copyToNewBuff)
+                this->applyTransform(transform::Assign, temp, nullptr);
+            *this = std::move(temp);
+        }
+
+        RELEASE(shapeInfoNew, getContext()->getWorkspace());
+
+        return canReshape;
+    }
+
+//////////////////////////////////////////////////////////////////////////
+    void NDArray::nullify() {
+        if (isEmpty())
+            return;
+
+        if (isView() || ews() != 1)
+            assign(0);
+        else
+            _buffer->setToZeroBuffers();
+
+    }
+
+////////////////////////////////////////////////////////////////////////
+    template <typename T>
+    void NDArray::templatedSet(void *buffer, const Nd4jLong xOfsset, sd::DataType dtype, const void *value) {
+        BUILD_SINGLE_PARTIAL_SELECTOR(dtype, templatedSet< , T>(buffer, xOfsset, value), LIBND4J_TYPES);
+    }
+    BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void NDArray::templatedSet, (void *buffer, const Nd4jLong xOfsset, sd::DataType dtype, const void *value), LIBND4J_TYPES);
+
+////////////////////////////////////////////////////////////////////////
+    void NDArray::applyPairwiseTransform(sd::pairwise::Ops op, const NDArray& other, NDArray& target, ExtraArguments *extraParams) const {
+        if (isS())
+            throw std::runtime_error("NDArray::applyPairwiseTransform: you can't use this method on String array!");
+        if (other.lengthOf() != target.lengthOf())
+            throw std::invalid_argument("NDArray::applyPairwiseTransform method - lengths of arrays are mismatched");
+        if (target.dataType() != this->dataType() && target.dataType() != other.dataType())
+            throw std::invalid_argument("NDArray::applyPairwiseTransform method - type of target array must be the same as type of this or other array !");
+
+        NDArray::prepareSpecialUse({&target}, {this, &other});
+        NativeOpExecutioner::execPairwiseTransform(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) : nullptr);
+        NDArray::registerSpecialUse({&target}, {this, &other});
+
+        if (extraParams != nullptr)
+            synchronize("NDArray::applyPairwiseTransform");
+    }
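A pairwise transform requires equal lengths and a compatible target type, as the checks above enforce. A minimal sketch, illustrative only; sd::pairwise::Multiply is assumed to be the standard element-wise product op, and the factory helper is as in the earlier sketches:

    void pairwiseSketch() {
        auto a   = sd::NDArrayFactory::create<float>('c', {2, 2}, {1, 2, 3, 4});
        auto b   = sd::NDArrayFactory::create<float>('c', {2, 2}, {5, 6, 7, 8});
        auto out = sd::NDArrayFactory::create<float>('c', {2, 2});
        // element-wise product written into out
        a.applyPairwiseTransform(sd::pairwise::Multiply, b, out, nullptr);
    }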
+
+////////////////////////////////////////////////////////////////////////
+    void NDArray::applyPairwiseTransform(sd::pairwise::BoolOps op, const NDArray& other, NDArray& target, ExtraArguments *extraParams) const {
+        if (isS())
+            throw std::runtime_error("NDArray::applyPairwiseTransform BoolOps: you can't use this method on String array!");
+        if (other.lengthOf() != target.lengthOf())
+            throw std::invalid_argument("NDArray::applyPairwiseTransform BoolOps method - lengths of arrays are mismatched");
+        if (!target.isB())
+            throw std::invalid_argument("NDArray::applyPairwiseTransform BoolOps method - result must have bool type");
+        if (dataType() != other.dataType())
+            throw std::invalid_argument("NDArray::applyPairwiseTransform BoolOps method - this and other arrays must have the same type !");
+
+        NDArray::prepareSpecialUse({&target}, {this, &other});
+        NativeOpExecutioner::execPairwiseBoolTransform(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) : nullptr);
+        NDArray::registerSpecialUse({&target}, {this, &other});
+    }
+
+////////////////////////////////////////////////////////////////////////
+    void NDArray::applyPairwiseTransform(sd::pairwise::IntOps op, const NDArray& other, NDArray& target, ExtraArguments *extraParams) const {
+        if (isS())
+            throw std::runtime_error("NDArray::applyPairwiseTransform IntOps: you can't use this method on String array!");
+        if (other.lengthOf() != target.lengthOf())
+            throw std::invalid_argument("NDArray::applyPairwiseTransform IntOps method - lengths of arrays are mismatched");
+        if (!target.isZ())
+            throw std::invalid_argument("NDArray::applyPairwiseTransform IntOps method - result must have int type");
+        if (dataType() != other.dataType())
+            throw std::invalid_argument("NDArray::applyPairwiseTransform IntOps method - this and other arrays must have the same type !");
+
+        NDArray::prepareSpecialUse({&target}, {this, &other});
+        NativeOpExecutioner::execPairwiseIntTransform(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ?
extraParams->argumentsAsT(target.dataType()) : nullptr); + NDArray::registerSpecialUse({&target}, {this, &other}); + } + +////////////////////////////////////////////////////////////////////////// + void NDArray::applyPairwiseTransform(sd::pairwise::Ops op, const NDArray& other, ExtraArguments *extraParams) { + applyPairwiseTransform(op, other, *this, extraParams); + } + +//////////////////////////////////////////////////////////////////////// + template + void NDArray::templatedDoubleAssign(void *xBuffer, const Nd4jLong xOffset, const void *yBuffer, const Nd4jLong yOffset) const { + auto x = reinterpret_cast(xBuffer); + const auto y = reinterpret_cast(yBuffer); + x[xOffset] = static_cast(y[yOffset]); + } + BUILD_DOUBLE_TEMPLATE(template ND4J_EXPORT void NDArray::templatedDoubleAssign, (void *xBuffer, const Nd4jLong xOffset, const void *yBuffer, const Nd4jLong yOffset) const, LIBND4J_TYPES, LIBND4J_TYPES); + +//////////////////////////////////////////////////////////////////////// + void NDArray::varianceAlongDimension(sd::variance::Ops op, NDArray& target, const bool biasCorrected, const std::vector& dimensions) const { + + if (isS()) + throw std::runtime_error("NDArray::varianceAlongDimension: you can't use this method on String array!"); + + if (!target.isR()) + throw std::runtime_error("NDArray::varianceAlongDimension: target array must have FLOAT type"); + + NDArray::prepareSpecialUse({&target}, {this}); + + if(rankOf() == dimensions.size() || dimensions.empty()) + NativeOpExecutioner::execSummaryStatsScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), biasCorrected); + else { + std::vector copy(dimensions); + auto pDims = sd::Environment::getInstance().isCPU() ? 
copy.data() : nullptr; + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), dimensions); + NativeOpExecutioner::execSummaryStats(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), pDims, dimensions.size(), packX.platformShapeInfo(), packX.platformOffsets(), biasCorrected); + synchronize("NDArray::varianceAlongDimension"); + } + + NDArray::registerSpecialUse({&target}, {this}); + } + +//////////////////////////////////////////////////////////////////////// + NDArray NDArray::varianceAlongDimension(sd::variance::Ops op, const bool biasCorrected, const std::vector& dimensions) const { + if (isS()) + throw std::runtime_error("NDArray::varianceAlongDimension: you can't use this method on String array!"); + + std::vector copy(dimensions); + if (copy.size() > 1) + std::sort(copy.begin(), copy.end()); + + auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, DataTypeUtils::pickFloatingType(dataType()), false, false, getContext()->getWorkspace()); + NDArray result(newShape, true, getContext()); + + this->varianceAlongDimension(op, result, biasCorrected, dimensions); + + return result; + } + +//////////////////////////////////////////////////////////////////////// + NDArray NDArray::varianceAlongDimension(sd::variance::Ops op, const bool biasCorrected, const std::initializer_list& dimensions) const { + return varianceAlongDimension(op, biasCorrected, std::vector(dimensions)); + } + +//////////////////////////////////////////////////////////////////////// + void NDArray::varianceAlongDimension(sd::variance::Ops op, NDArray &target, const bool biasCorrected, const std::initializer_list& dimensions) const { + varianceAlongDimension(op, target, biasCorrected, std::vector(dimensions)); + } + +//////////////////////////////////////////////////////////////////////// +// This method returns new copy of this NDArray, optionally in different order + NDArray NDArray::dup(const char newOrder) const { + + if (isEmpty()) + return NDArrayFactory::empty(dataType(), getContext()); + + char order = newOrder == 'a' ? ordering() : newOrder; + + // for now string arrays require special treatment + if (isS()) { + if (dataType() == DataType::UTF8) { + std::vector strings(lengthOf()); + + auto func = PRAGMA_THREADS_FOR{ + for (auto i = start; i < stop; i++) { + strings[i] = std::move(this->e(i)); + } + }; + + samediff::Threads::parallel_for(func, 0, lengthOf(), 1); + + return NDArray(getShapeAsVector(), strings, dataType(), getContext()); + } + if (dataType() == DataType::UTF16) { + std::vector strings(lengthOf()); + + auto func = PRAGMA_THREADS_FOR{ + for (auto i = start; i < stop; i++) { + strings[i] = std::move(this->e(i)); + } + }; + + samediff::Threads::parallel_for(func, 0, lengthOf(), 1); + + return NDArray(getShapeAsVector(), strings, dataType(), getContext()); + } + + std::vector strings(lengthOf()); + auto func = PRAGMA_THREADS_FOR{ + for (auto i = start; i < stop; i++) { + strings[i] = std::move(this->e(i)); + } + }; + + samediff::Threads::parallel_for(func, 0, lengthOf(), 1); + + return NDArray(getShapeAsVector(), strings, dataType(), getContext()); + } + + NDArray result(order, isScalar() ? 
std::vector({0}) : getShapeAsVector(), dataType(), getContext()); + result.assign(*this); + + return result; + } //////////////////////////////////////////////////////////////////////// // This method returns true if two arrays are equal, with custom or default Eps value of 1e-5, false otherwise -bool NDArray::equalsTo(const NDArray *other, double eps) const { + bool NDArray::equalsTo(const NDArray *other, double eps) const { - if (dataType() != other->dataType() || lengthOf() != other->lengthOf()) - return false; - - // we need to be able to compare [1, len] to [len] - if ((rankOf() == 1 && other->rankOf() == 2) || (rankOf() == 2 && other->rankOf() == 1)) { - // FIXME: do something here? - } else if (!shape::equalsSoft(shapeInfo(), other->shapeInfo())) - return false; - - if (isS()) { - // string is special case, we'll compare them one by one, considering both arrays are guaranteed to have the same length - - if (dataType() == DataType::UTF8) { - for (Nd4jLong e = 0; e < this->lengthOf(); e++) { - auto s1 = this->e(e); - auto s2 = other->e(e); - - if (s1 != s2) - return false; - } - } - else if (dataType() == DataType::UTF16) { - for (Nd4jLong e = 0; e < this->lengthOf(); e++) { - auto s1 = this->e(e); - auto s2 = other->e(e); - - if (s1 != s2) - return false; - } - } - else { - for (Nd4jLong e = 0; e < this->lengthOf(); e++) { - auto s1 = this->e(e); - auto s2 = other->e(e); - - if (s1 != s2) - return false; - } - } - - return true; - } else { - // regular numeric types - NDArray tmp(sd::DataType::FLOAT32, getContext()); // scalar = 0 - - ExtraArguments extras({0.0, 0.0, eps}); - - NDArray::prepareSpecialUse({&tmp}, {this, other}); - NativeOpExecutioner::execReduce3Scalar(getContext(), reduce3::EqualsWithEps, buffer(), shapeInfo(), - specialBuffer(), specialShapeInfo(), - extras.argumentsAsT(DataType::FLOAT32), other->buffer(), - other->shapeInfo(), other->specialBuffer(), - other->specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), - tmp.specialBuffer(), tmp.specialShapeInfo()); - NDArray::registerSpecialUse({&tmp}, {this, other}); - - synchronize("NDArray::equalsTo"); - - if (tmp.e(0) != 0) + if (dataType() != other->dataType() || lengthOf() != other->lengthOf()) return false; - return true; + // we need to be able to compare [1, len] to [len] + if ((rankOf() == 1 && other->rankOf() == 2) || (rankOf() == 2 && other->rankOf() == 1)) { + // FIXME: do something here? 
+ } else if (!shape::equalsSoft(shapeInfo(), other->shapeInfo())) + return false; + + if (isS()) { + // string is special case, we'll compare them one by one, considering both arrays are guaranteed to have the same length + + if (dataType() == DataType::UTF8) { + for (Nd4jLong e = 0; e < this->lengthOf(); e++) { + auto s1 = this->e(e); + auto s2 = other->e(e); + + if (s1 != s2) + return false; + } + } + else if (dataType() == DataType::UTF16) { + for (Nd4jLong e = 0; e < this->lengthOf(); e++) { + auto s1 = this->e(e); + auto s2 = other->e(e); + + if (s1 != s2) + return false; + } + } + else { + for (Nd4jLong e = 0; e < this->lengthOf(); e++) { + auto s1 = this->e(e); + auto s2 = other->e(e); + + if (s1 != s2) + return false; + } + } + + return true; + } else { + // regular numeric types + NDArray tmp(sd::DataType::FLOAT32, getContext()); // scalar = 0 + + ExtraArguments extras({0.0, 0.0, eps}); + + NDArray::prepareSpecialUse({&tmp}, {this, other}); + NativeOpExecutioner::execReduce3Scalar(getContext(), reduce3::EqualsWithEps, buffer(), shapeInfo(), + specialBuffer(), specialShapeInfo(), + extras.argumentsAsT(DataType::FLOAT32), other->buffer(), + other->shapeInfo(), other->specialBuffer(), + other->specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), + tmp.specialBuffer(), tmp.specialShapeInfo()); + NDArray::registerSpecialUse({&tmp}, {this, other}); + + synchronize("NDArray::equalsTo"); + + if (tmp.e(0) != 0) + return false; + + return true; + } } -} ////////////////////////////////////////////////////////////////////////// -template <> -std::string NDArray::e(const Nd4jLong i) const { + template <> + std::string NDArray::e(const Nd4jLong i) const { - if (!isS()) - throw std::runtime_error("Can't get std::string out of non-string array"); + if (!isS()) + throw std::runtime_error("Can't get std::string out of non-string array"); - if (i == lengthOf()) - throw std::runtime_error("Can't get std::string for index out of range"); + if (i == lengthOf()) + throw std::runtime_error("Can't get std::string for index out of range"); - if (this->dataType() == DataType::UTF16) { - auto u16 = this->e(i); - std::string s; - StringUtils::u16StringToU8String(u16, s); - return s; + if (this->dataType() == DataType::UTF16) { + auto u16 = this->e(i); + std::string s; + StringUtils::u16StringToU8String(u16, s); + return s; + } + + if (this->dataType() == DataType::UTF32) { + auto u32 = this->e(i); + std::string s; + StringUtils::u32StringToU8String(u32, s); + return s; + } + + NDArray::preparePrimaryUse({}, {this}); + + auto offsets = bufferAsT(); + auto offsetsLength = ShapeUtils::stringBufferHeaderRequirements(lengthOf()); + auto start = offsets[i]; + auto end = offsets[i + 1]; + auto data = bufferAsT() + offsetsLength + start; + + std::string r(reinterpret_cast(data), (end - start)); + + registerPrimaryUse({}, {this}); + + return r; } - if (this->dataType() == DataType::UTF32) { - auto u32 = this->e(i); - std::string s; - StringUtils::u32StringToU8String(u32, s); - return s; + template <> + std::u16string NDArray::e(const Nd4jLong i) const { + + if (!isS()) + throw std::runtime_error("Can't get std::u16string out of non-string array"); + + if(i == lengthOf()) + throw std::runtime_error("Can't get std::u16string for index out of range"); + + if (this->dataType() == DataType::UTF8) { + auto u = this->e(i); + std::u16string s; + StringUtils::u8StringToU16String(u, s); + return s; + } + + if (this->dataType() == DataType::UTF32) { + auto u32 = this->e(i); + std::u16string s; + 
StringUtils::u32StringToU16String(u32, s); + return s; + } + + NDArray::preparePrimaryUse({}, { this }); + + auto offsets = bufferAsT(); + Nd4jLong offsetsLength = ShapeUtils::stringBufferHeaderRequirements(lengthOf()); + Nd4jLong start = offsets[i]; + Nd4jLong end = offsets[i + 1]; + auto data = bufferAsT() + offsetsLength + start; + + std::u16string r(reinterpret_cast(data), (end - start) / sizeof(char16_t)); + + registerPrimaryUse({}, { this }); + + return r; } - NDArray::preparePrimaryUse({}, {this}); + template <> + std::u32string NDArray::e(const Nd4jLong i) const { - auto offsets = bufferAsT(); - auto offsetsLength = ShapeUtils::stringBufferHeaderRequirements(lengthOf()); - auto start = offsets[i]; - auto end = offsets[i + 1]; - auto data = bufferAsT() + offsetsLength + start; + if (!isS()) + throw std::runtime_error("Can't get std::u32string out of non-string array"); - std::string r(reinterpret_cast(data), (end - start)); + if (i == lengthOf()) + throw std::runtime_error("Can't get std::u32string for index out of range"); - registerPrimaryUse({}, {this}); + if (this->dataType() == DataType::UTF8) { + auto u = this->e(i); + std::u32string s; + StringUtils::u8StringToU32String(u, s); + return s; + } - return r; -} + if (this->dataType() == DataType::UTF16) { + auto u16 = this->e(i); + std::u32string s; + StringUtils::u16StringToU32String(u16, s); + return s; + } -template <> -std::u16string NDArray::e(const Nd4jLong i) const { + NDArray::preparePrimaryUse({}, { this }); - if (!isS()) - throw std::runtime_error("Can't get std::u16string out of non-string array"); + auto offsets = bufferAsT(); + Nd4jLong offsetsLength = ShapeUtils::stringBufferHeaderRequirements(lengthOf()); + Nd4jLong start = offsets[i]; + Nd4jLong end = offsets[i + 1]; - if(i == lengthOf()) - throw std::runtime_error("Can't get std::u16string for index out of range"); + auto data = bufferAsT() + offsetsLength + start; - if (this->dataType() == DataType::UTF8) { - auto u = this->e(i); - std::u16string s; - StringUtils::u8StringToU16String(u, s); - return s; + std::u32string r(reinterpret_cast(data), (end - start) / sizeof(char32_t)); + + registerPrimaryUse({}, { this }); + + return r; } - if (this->dataType() == DataType::UTF32) { - auto u32 = this->e(i); - std::u16string s; - StringUtils::u32StringToU16String(u32, s); - return s; - } - - NDArray::preparePrimaryUse({}, { this }); - - auto offsets = bufferAsT(); - Nd4jLong offsetsLength = ShapeUtils::stringBufferHeaderRequirements(lengthOf()); - Nd4jLong start = offsets[i]; - Nd4jLong end = offsets[i + 1]; - auto data = bufferAsT() + offsetsLength + start; - - std::u16string r(reinterpret_cast(data), (end - start) / sizeof(char16_t)); - - registerPrimaryUse({}, { this }); - - return r; -} - -template <> -std::u32string NDArray::e(const Nd4jLong i) const { - - if (!isS()) - throw std::runtime_error("Can't get std::u32string out of non-string array"); - - if (i == lengthOf()) - throw std::runtime_error("Can't get std::u32string for index out of range"); - - if (this->dataType() == DataType::UTF8) { - auto u = this->e(i); - std::u32string s; - StringUtils::u8StringToU32String(u, s); - return s; - } - - if (this->dataType() == DataType::UTF16) { - auto u16 = this->e(i); - std::u32string s; - StringUtils::u16StringToU32String(u16, s); - return s; - } - - NDArray::preparePrimaryUse({}, { this }); - - auto offsets = bufferAsT(); - Nd4jLong offsetsLength = ShapeUtils::stringBufferHeaderRequirements(lengthOf()); - Nd4jLong start = offsets[i]; - Nd4jLong end = offsets[i + 1]; - 
- auto data = bufferAsT() + offsetsLength + start; - - std::u32string r(reinterpret_cast(data), (end - start) / sizeof(char32_t)); - - registerPrimaryUse({}, { this }); - - return r; -} - ////////////////////////////////////////////////////////////////////////// -template <> -utf8string NDArray::e(const Nd4jLong i) const { + template <> + utf8string NDArray::e(const Nd4jLong i) const { - if (!isS()) - throw std::runtime_error("This method is available for String arrays only"); + if (!isS()) + throw std::runtime_error("This method is available for String arrays only"); - auto rp = getOffset(i); + auto rp = getOffset(i); - syncToHost(); - tickReadHost(); + syncToHost(); + tickReadHost(); - return *(reinterpret_cast(buffer())[rp]); -} + return *(reinterpret_cast(buffer())[rp]); + } ///////////////////////////////////////////////////////////////////////// -template -T NDArray::e(const Nd4jLong i) const { + template + T NDArray::e(const Nd4jLong i) const { - const auto rp = getOffset(i); + const auto rp = getOffset(i); - NDArray::preparePrimaryUse({}, {this}); - NDArray::registerPrimaryUse({}, {this}); - BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), return templatedGet<, T>(buffer(), rp), LIBND4J_TYPES); + NDArray::preparePrimaryUse({}, {this}); + NDArray::registerPrimaryUse({}, {this}); + BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), return templatedGet<, T>(buffer(), rp), LIBND4J_TYPES); -} -BUILD_SINGLE_UNCHAINED_TEMPLATE(template ND4J_EXPORT , NDArray::e(const Nd4jLong) const, LIBND4J_TYPES); + } + BUILD_SINGLE_UNCHAINED_TEMPLATE(template ND4J_EXPORT , NDArray::e(const Nd4jLong) const, LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// // Returns value from 2D matrix by coordinates/indexes -template -T NDArray::e(const Nd4jLong i, const Nd4jLong j) const { + template + T NDArray::e(const Nd4jLong i, const Nd4jLong j) const { - if (rankOf() != 2 || i >= shapeOf()[0] || j >= shapeOf()[1]) - throw std::invalid_argument("NDArray::e(i,j): one of input indexes is out of array length or rank!=2 !"); + if (rankOf() != 2 || i >= shapeOf()[0] || j >= shapeOf()[1]) + throw std::invalid_argument("NDArray::e(i,j): one of input indexes is out of array length or rank!=2 !"); - const auto xOffset = i * strideAt(0) + j * strideAt(1); + const auto xOffset = i * strideAt(0) + j * strideAt(1); - NDArray::preparePrimaryUse({}, {this}); - NDArray::registerPrimaryUse({}, {this}); + NDArray::preparePrimaryUse({}, {this}); + NDArray::registerPrimaryUse({}, {this}); - BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), return templatedGet<, T>(buffer(), xOffset), LIBND4J_TYPES); + BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), return templatedGet<, T>(buffer(), xOffset), LIBND4J_TYPES); - return static_cast(119); -} -BUILD_SINGLE_UNCHAINED_TEMPLATE(template ND4J_EXPORT , NDArray::e(const Nd4jLong, const Nd4jLong) const, LIBND4J_TYPES); + return static_cast(119); + } + BUILD_SINGLE_UNCHAINED_TEMPLATE(template ND4J_EXPORT , NDArray::e(const Nd4jLong, const Nd4jLong) const, LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// // returns value from 3D tensor by coordinates -template -T NDArray::e(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const { + template + T NDArray::e(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k) const { - if (rankOf() != 3 || i >= shapeOf()[0] || j >= shapeOf()[1] || k >= shapeOf()[2]) - throw std::invalid_argument("NDArray::e(i,j,k): one of input indexes is out of array length or rank!=3 !"); + if (rankOf() != 3 || i >= 
shapeOf()[0] || j >= shapeOf()[1] || k >= shapeOf()[2]) + throw std::invalid_argument("NDArray::e(i,j,k): one of input indexes is out of array length or rank!=3 !"); - const auto xOffset = i * strideAt(0) + j * strideAt(1) + k * strideAt(2); + const auto xOffset = i * strideAt(0) + j * strideAt(1) + k * strideAt(2); - NDArray::preparePrimaryUse({}, {this}); - NDArray::registerPrimaryUse({}, {this}); + NDArray::preparePrimaryUse({}, {this}); + NDArray::registerPrimaryUse({}, {this}); - BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), return templatedGet<, T>(buffer(), xOffset), LIBND4J_TYPES); + BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), return templatedGet<, T>(buffer(), xOffset), LIBND4J_TYPES); - return static_cast(119); -} -BUILD_SINGLE_UNCHAINED_TEMPLATE(template ND4J_EXPORT , NDArray::e(const Nd4jLong, const Nd4jLong, const Nd4jLong) const, LIBND4J_TYPES); + return static_cast(119); + } + BUILD_SINGLE_UNCHAINED_TEMPLATE(template ND4J_EXPORT , NDArray::e(const Nd4jLong, const Nd4jLong, const Nd4jLong) const, LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// // returns value from 3D tensor by coordinates -template -T NDArray::e(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l) const { + template + T NDArray::e(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l) const { - if (rankOf() != 4 || i >= shapeOf()[0] || j >= shapeOf()[1] || k >= shapeOf()[2] || l >= shapeOf()[3]) - throw std::invalid_argument("NDArray::e(i,j,k,l): one of input indexes is out of array length or rank!=4 !"); + if (rankOf() != 4 || i >= shapeOf()[0] || j >= shapeOf()[1] || k >= shapeOf()[2] || l >= shapeOf()[3]) + throw std::invalid_argument("NDArray::e(i,j,k,l): one of input indexes is out of array length or rank!=4 !"); - const auto xOffset = i * strideAt(0) + j * strideAt(1) + k * strideAt(2) + l * strideAt(3); + const auto xOffset = i * strideAt(0) + j * strideAt(1) + k * strideAt(2) + l * strideAt(3); - NDArray::preparePrimaryUse({}, {this}); - NDArray::registerPrimaryUse({}, {this}); + NDArray::preparePrimaryUse({}, {this}); + NDArray::registerPrimaryUse({}, {this}); - BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), return templatedGet<, T>(buffer(), xOffset), LIBND4J_TYPES); + BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), return templatedGet<, T>(buffer(), xOffset), LIBND4J_TYPES); - return static_cast(119); -} -BUILD_SINGLE_UNCHAINED_TEMPLATE(template ND4J_EXPORT , NDArray::e(const Nd4jLong, const Nd4jLong, const Nd4jLong, const Nd4jLong) const, LIBND4J_TYPES); + return static_cast(119); + } + BUILD_SINGLE_UNCHAINED_TEMPLATE(template ND4J_EXPORT , NDArray::e(const Nd4jLong, const Nd4jLong, const Nd4jLong, const Nd4jLong) const, LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// -NDArray NDArray::e(const Nd4jLong i) const { + NDArray NDArray::e(const Nd4jLong i) const { - const auto offset = getOffset(i); + const auto offset = getOffset(i); - NDArray scalar(dataType(), getContext()); + NDArray scalar(dataType(), getContext()); - scalar.copyBuffersContinuouslyFrom(*this, sizeOfT(), 0, bufferOffset() + offset); + scalar.copyBuffersContinuouslyFrom(*this, sizeOfT(), 0, bufferOffset() + offset); - return scalar; -} + return scalar; + } ////////////////////////////////////////////////////////////////////////// // perform array transformation -void NDArray::applyTransform(sd::transform::FloatOps op, NDArray& target, ExtraArguments *extraParams) { + void NDArray::applyTransform(sd::transform::FloatOps op, 
NDArray& target, ExtraArguments *extraParams) { - if (isS()) - throw std::runtime_error("NDArray::applyTransform FloatOps: you can't use this method on String array!"); + if (isS()) + throw std::runtime_error("NDArray::applyTransform FloatOps: you can't use this method on String array!"); - if (!target.isR()) - throw std::runtime_error("NDArray::applyTransform FloatOps: target array must have one of FLOAT types"); + if (!target.isR()) + throw std::runtime_error("NDArray::applyTransform FloatOps: target array must have one of FLOAT types"); - NDArray::prepareSpecialUse({&target}, {this}); - NativeOpExecutioner::execTransformFloat(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) : nullptr, nullptr, nullptr); - NDArray::registerSpecialUse({&target}, {this}); -} + NDArray::prepareSpecialUse({&target}, {this}); + NativeOpExecutioner::execTransformFloat(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) : nullptr, nullptr, nullptr); + NDArray::registerSpecialUse({&target}, {this}); + } //////////////////////////////////////////////////////////////////////// -void NDArray::applyTransform(sd::transform::AnyOps op, NDArray& target, ExtraArguments *extraParams) { + void NDArray::applyTransform(sd::transform::AnyOps op, NDArray& target, ExtraArguments *extraParams) { - if (isS()) - throw std::runtime_error("NDArray::applyTransform AnyOps: you can't use this method on String array!"); + if (isS()) + throw std::runtime_error("NDArray::applyTransform AnyOps: you can't use this method on String array!"); - NDArray::prepareSpecialUse({&target}, {this}); - NativeOpExecutioner::execTransformAny(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) : nullptr, nullptr, nullptr); - NDArray::registerSpecialUse({&target}, {this}); -} + NDArray::prepareSpecialUse({&target}, {this}); + NativeOpExecutioner::execTransformAny(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? 
extraParams->argumentsAsT(target.dataType()) : nullptr, nullptr, nullptr); + NDArray::registerSpecialUse({&target}, {this}); + } //////////////////////////////////////////////////////////////////////// -void NDArray::applyTransform(sd::transform::SameOps op, NDArray& target, ExtraArguments *extraParams) { + void NDArray::applyTransform(sd::transform::SameOps op, NDArray& target, ExtraArguments *extraParams) { - if (isS()) - throw std::runtime_error("NDArray::applyTransform SameOps: you can't use this method on String array!"); + if (isS()) + throw std::runtime_error("NDArray::applyTransform SameOps: you can't use this method on String array!"); - if (target.dataType() != dataType()) - throw std::runtime_error("NDArray::applyTransform SameOps: target array must have the same data type as original array"); + if (target.dataType() != dataType()) + throw std::runtime_error("NDArray::applyTransform SameOps: target array must have the same data type as original array"); - NDArray::prepareSpecialUse({&target}, {this}); - NativeOpExecutioner::execTransformSame(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) : nullptr, nullptr, nullptr); - NDArray::registerSpecialUse({&target}, {this}); -} + NDArray::prepareSpecialUse({&target}, {this}); + NativeOpExecutioner::execTransformSame(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) : nullptr, nullptr, nullptr); + NDArray::registerSpecialUse({&target}, {this}); + } //////////////////////////////////////////////////////////////////////// -void NDArray::applyTransform(sd::transform::StrictOps op, NDArray& target, ExtraArguments *extraParams) { - if (isS()) - throw std::runtime_error("NDArray::applyTransform StrictOps: you can't use this method on String array!"); + void NDArray::applyTransform(sd::transform::StrictOps op, NDArray& target, ExtraArguments *extraParams) { + if (isS()) + throw std::runtime_error("NDArray::applyTransform StrictOps: you can't use this method on String array!"); - if (!this->isR() || !target.isR() || (this->dataType() != target.dataType())) - throw std::runtime_error("NDArray::applyTransform StrictOps: both Source and Target array must have same FLOAT type !"); + if (!this->isR() || !target.isR() || (this->dataType() != target.dataType())) + throw std::runtime_error("NDArray::applyTransform StrictOps: both Source and Target array must have same FLOAT type !"); - NDArray::prepareSpecialUse({&target}, {this}); - NativeOpExecutioner::execTransformStrict(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) : nullptr, nullptr, nullptr); - NDArray::registerSpecialUse({&target}, {this}); -} + NDArray::prepareSpecialUse({&target}, {this}); + NativeOpExecutioner::execTransformStrict(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? 
extraParams->argumentsAsT(target.dataType()) : nullptr, nullptr, nullptr);
+        NDArray::registerSpecialUse({&target}, {this});
+    }

////////////////////////////////////////////////////////////////////////
-void NDArray::applyTransform(sd::transform::BoolOps op, NDArray& target, ExtraArguments *extraParams) {
-    if (isS())
-        throw std::runtime_error("NDArray::applyTransform BoolOps: you can't use this method on String array!");
+    void NDArray::applyTransform(sd::transform::BoolOps op, NDArray& target, ExtraArguments *extraParams) {
+        if (isS())
+            throw std::runtime_error("NDArray::applyTransform BoolOps: you can't use this method on String array!");

-    if (!target.isB())
-        throw std::runtime_error("NDArray::applyTransform BoolOps: target array must have one of BOOL types");
+        if (!target.isB())
+            throw std::runtime_error("NDArray::applyTransform BoolOps: target array must have one of BOOL types");

-    NDArray::prepareSpecialUse({&target}, {this});
-    NativeOpExecutioner::execTransformBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) : nullptr, nullptr, nullptr);
-    NDArray::registerSpecialUse({&target}, {this});
-}
+        NDArray::prepareSpecialUse({&target}, {this});
+        NativeOpExecutioner::execTransformBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()) : nullptr, nullptr, nullptr);
+        NDArray::registerSpecialUse({&target}, {this});
+    }

////////////////////////////////////////////////////////////////////////
-NDArray NDArray::transform(sd::transform::FloatOps op, void *extraParams) const & {
-    if (isS())
-        throw std::runtime_error("NDArray::transform FloatOps: you can't use this method on String array!");
+    NDArray NDArray::transform(sd::transform::FloatOps op, void *extraParams) const & {
+        if (isS())
+            throw std::runtime_error("NDArray::transform FloatOps: you can't use this method on String array!");

-    NDArray result(ordering(), getShapeAsVector(), DataTypeUtils::pickFloatingType(dataType()), getContext());
+        NDArray result(ordering(), getShapeAsVector(), DataTypeUtils::pickFloatingType(dataType()), getContext());

-    NDArray::prepareSpecialUse({&result}, {this});
-    NativeOpExecutioner::execTransformFloat(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, nullptr);
-    NDArray::registerSpecialUse({&result}, {this});
+        NDArray::prepareSpecialUse({&result}, {this});
+        NativeOpExecutioner::execTransformFloat(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, nullptr);
+        NDArray::registerSpecialUse({&result}, {this});

-    return result;
-}
+        return result;
+    }
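The lvalue transform overload above allocates a fresh result, while the rvalue overloads below reuse the callee's buffer. A minimal call-site sketch using sd::transform::Neg, which this file already applies as a SameOps transform in operator-() (factory helper assumed as in the earlier sketches):

    void transformSketch() {
        auto x = sd::NDArrayFactory::create<float>('c', {2, 2}, {1, -2, 3, -4});
        // same-type transform: y holds the negated copy of x
        auto y = x.transform(sd::transform::Neg, nullptr);
    }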
you can't use this method on String array!"); - NDArray::prepareSpecialUse({this}, {this}); - NativeOpExecutioner::execTransformFloat(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, nullptr, nullptr); - NDArray::registerSpecialUse({this}, {this}); + NDArray::prepareSpecialUse({this}, {this}); + NativeOpExecutioner::execTransformFloat(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, nullptr, nullptr); + NDArray::registerSpecialUse({this}, {this}); - return std::move(*this); -} + return std::move(*this); + } //////////////////////////////////////////////////////////////////////// -NDArray NDArray::transform(sd::transform::SameOps op, void *extraParams) const & { - if (isS()) - throw std::runtime_error("NDArray::transform SameOps: you can't use this method on String array!"); + NDArray NDArray::transform(sd::transform::SameOps op, void *extraParams) const & { + if (isS()) + throw std::runtime_error("NDArray::transform SameOps: you can't use this method on String array!"); - NDArray result(shapeInfo(), false, getContext()); + NDArray result(shapeInfo(), false, getContext()); - NDArray::prepareSpecialUse({&result}, {this}); - NativeOpExecutioner::execTransformSame(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, nullptr); - NDArray::registerSpecialUse({&result}, {this}); + NDArray::prepareSpecialUse({&result}, {this}); + NativeOpExecutioner::execTransformSame(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, nullptr); + NDArray::registerSpecialUse({&result}, {this}); - return result; -} + return result; + } //////////////////////////////////////////////////////////////////////// -NDArray NDArray::transform(sd::transform::SameOps op, void *extraParams) && { - if (isS()) - throw std::runtime_error("NDArray::transform SameOps: you can't use this method on String array!"); + NDArray NDArray::transform(sd::transform::SameOps op, void *extraParams) && { + if (isS()) + throw std::runtime_error("NDArray::transform SameOps: you can't use this method on String array!"); - NDArray::prepareSpecialUse({this}, {this}); - NativeOpExecutioner::execTransformSame(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, nullptr, nullptr); - NDArray::registerSpecialUse({this}, {this}); + NDArray::prepareSpecialUse({this}, {this}); + NativeOpExecutioner::execTransformSame(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, nullptr, nullptr); + NDArray::registerSpecialUse({this}, {this}); - return std::move(*this); -} + return std::move(*this); + } //////////////////////////////////////////////////////////////////////// -NDArray NDArray::transform(sd::transform::StrictOps op, void *extraParams) const & { - if (!this->isR()) - throw std::runtime_error("Source array must have one of FLOAT types"); + NDArray NDArray::transform(sd::transform::StrictOps op, void *extraParams) const & { + if (!this->isR()) + throw std::runtime_error("Source array must have one 
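The transform() hunks above re-indent both overloads of each pair. The pairs are ref-qualified: the const & overload, chosen for lvalue callees, allocates a fresh result array (the FloatOps version picks a floating-point result type via DataTypeUtils::pickFloatingType), while the && overload, chosen for rvalues, runs the op through the array's own buffer and moves it out. A minimal usage sketch, assuming a built libnd4j with array/NDArray.h and array/NDArrayFactory.h on the include path and the usual enum spelling sd::transform::Abs; none of that is shown in this diff.

    #include <array/NDArray.h>
    #include <array/NDArrayFactory.h>

    static void transformValueCategories() {
        auto a = NDArrayFactory::create<float>('c', {2, 2}, {-1.f, 2.f, -3.f, 4.f});

        // lvalue call -> const & overload: a new result array is allocated
        auto b = a.transform(sd::transform::Abs, nullptr);

        // rvalue call -> && overload: the temporary from dup() is transformed
        // through its own buffer and moved into c, with no second allocation
        auto c = a.dup().transform(sd::transform::Abs, nullptr);
    }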
of FLOAT types"); - NDArray result(shapeInfo(), false, getContext()); + NDArray result(shapeInfo(), false, getContext()); - NDArray::prepareSpecialUse({&result}, {this}); - NativeOpExecutioner::execTransformStrict(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, nullptr); - NDArray::registerSpecialUse({&result}, {this}); + NDArray::prepareSpecialUse({&result}, {this}); + NativeOpExecutioner::execTransformStrict(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, nullptr); + NDArray::registerSpecialUse({&result}, {this}); - return result; -} + return result; + } //////////////////////////////////////////////////////////////////////// -NDArray NDArray::transform(sd::transform::StrictOps op, void *extraParams) && { - if (!this->isR()) - throw std::runtime_error("Source array must have one of FLOAT types"); + NDArray NDArray::transform(sd::transform::StrictOps op, void *extraParams) && { + if (!this->isR()) + throw std::runtime_error("Source array must have one of FLOAT types"); - NDArray::prepareSpecialUse({this}, {this}); - NativeOpExecutioner::execTransformStrict(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, nullptr, nullptr); - NDArray::registerSpecialUse({this}, {this}); + NDArray::prepareSpecialUse({this}, {this}); + NativeOpExecutioner::execTransformStrict(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, nullptr, nullptr); + NDArray::registerSpecialUse({this}, {this}); - return std::move(*this); -} + return std::move(*this); + } //////////////////////////////////////////////////////////////////////// -NDArray NDArray::transform(sd::transform::BoolOps op, void *extraParams) const & { - if (isS()) - throw std::runtime_error("NDArray::transform BoolOps: you can't use this method on String array!"); + NDArray NDArray::transform(sd::transform::BoolOps op, void *extraParams) const & { + if (isS()) + throw std::runtime_error("NDArray::transform BoolOps: you can't use this method on String array!"); - NDArray result(ordering(), getShapeAsVector(), sd::DataType::BOOL, getContext()); + NDArray result(ordering(), getShapeAsVector(), sd::DataType::BOOL, getContext()); - NDArray::prepareSpecialUse({&result}, {this}); - NativeOpExecutioner::execTransformBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, nullptr); - NDArray::registerSpecialUse({&result}, {this}); + NDArray::prepareSpecialUse({&result}, {this}); + NativeOpExecutioner::execTransformBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), extraParams, nullptr, nullptr); + NDArray::registerSpecialUse({&result}, {this}); - return result; -} + return result; + } //////////////////////////////////////////////////////////////////////// -NDArray NDArray::transform(sd::transform::BoolOps op, void *extraParams) && { - if (isS()) - throw std::runtime_error("NDArray::transform BoolOps: you can't use this method on String 
array!"); + NDArray NDArray::transform(sd::transform::BoolOps op, void *extraParams) && { + if (isS()) + throw std::runtime_error("NDArray::transform BoolOps: you can't use this method on String array!"); - NDArray::prepareSpecialUse({this}, {this}); - NativeOpExecutioner::execTransformBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, nullptr, nullptr); - NDArray::registerSpecialUse({this}, {this}); + NDArray::prepareSpecialUse({this}, {this}); + NativeOpExecutioner::execTransformBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), extraParams, nullptr, nullptr); + NDArray::registerSpecialUse({this}, {this}); - return std::move(*this); -} + return std::move(*this); + } ////////////////////////////////////////////////////////////////////////// -void NDArray::applyScalarArr(sd::scalar::Ops op, const NDArray& scalar, NDArray& target, ExtraArguments *extraParams) { - if (isS()) - throw std::runtime_error("NDArray::applyScalarArr: you can't use this method on String array!"); - if (scalar.lengthOf() != 1) - throw std::invalid_argument("NDArray::applyScalarArr method: operand is not a scalar!"); + void NDArray::applyScalarArr(sd::scalar::Ops op, const NDArray& scalar, NDArray& target, ExtraArguments *extraParams) { + if (isS()) + throw std::runtime_error("NDArray::applyScalarArr: you can't use this method on String array!"); + if (scalar.lengthOf() != 1) + throw std::invalid_argument("NDArray::applyScalarArr method: operand is not a scalar!"); - if(target.dataType() != DataTypeUtils::pickPairwiseResultType(shapeInfo(), scalar.shapeInfo()) && !(target.dataType() == dataType() || target.dataType() == scalar.dataType())) - throw std::invalid_argument("NDArray::applyScalarArr method: wrong type of target array!"); + if(target.dataType() != DataTypeUtils::pickPairwiseResultType(shapeInfo(), scalar.shapeInfo()) && !(target.dataType() == dataType() || target.dataType() == scalar.dataType())) + throw std::invalid_argument("NDArray::applyScalarArr method: wrong type of target array!"); - NDArray::prepareSpecialUse({&target}, {this, &scalar}); - NativeOpExecutioner::execScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), scalar.buffer(), scalar.shapeInfo(), scalar.specialBuffer(), scalar.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()): nullptr); - NDArray::registerSpecialUse({&target}, {this, &scalar}); -} + NDArray::prepareSpecialUse({&target}, {this, &scalar}); + NativeOpExecutioner::execScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), scalar.buffer(), scalar.shapeInfo(), scalar.specialBuffer(), scalar.specialShapeInfo(), extraParams != nullptr ? 
extraParams->argumentsAsT(target.dataType()): nullptr); + NDArray::registerSpecialUse({&target}, {this, &scalar}); + } ////////////////////////////////////////////////////////////////////////// -void NDArray::applyScalarArr(sd::scalar::BoolOps op, const NDArray& scalar, NDArray &target, ExtraArguments *extraParams) const { - if (isS()) - throw std::runtime_error("NDArray::applyScalarArr BoolOps: you can't use this method on String array!"); - if (!target.isB()) - throw std::invalid_argument("NDArray::applyScalarArr bool method: target has not bool type!"); - if (dataType() != scalar.dataType()) { - nd4j_printf("NDArray::applyScalarArr BoolOps: this dtype: [%i]; scalar dtype: [%i]\n", this->dataType(), scalar.dataType()); - throw std::invalid_argument("NDArray::applyScalarArr bool method: this and scalar arrays must have the same type!"); - } + void NDArray::applyScalarArr(sd::scalar::BoolOps op, const NDArray& scalar, NDArray &target, ExtraArguments *extraParams) const { + if (isS()) + throw std::runtime_error("NDArray::applyScalarArr BoolOps: you can't use this method on String array!"); + if (!target.isB()) + throw std::invalid_argument("NDArray::applyScalarArr bool method: target has not bool type!"); + if (dataType() != scalar.dataType()) { + nd4j_printf("NDArray::applyScalarArr BoolOps: this dtype: [%i]; scalar dtype: [%i]\n", this->dataType(), scalar.dataType()); + throw std::invalid_argument("NDArray::applyScalarArr bool method: this and scalar arrays must have the same type!"); + } - NDArray::prepareSpecialUse({&target}, {this, &scalar}); - NativeOpExecutioner::execScalarBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), scalar.buffer(), scalar.shapeInfo(), scalar.specialBuffer(), scalar.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()): nullptr); - NDArray::registerSpecialUse({&target}, {this, &scalar}); -} + NDArray::prepareSpecialUse({&target}, {this, &scalar}); + NativeOpExecutioner::execScalarBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), scalar.buffer(), scalar.shapeInfo(), scalar.specialBuffer(), scalar.specialShapeInfo(), extraParams != nullptr ? 
extraParams->argumentsAsT(target.dataType()): nullptr); + NDArray::registerSpecialUse({&target}, {this, &scalar}); + } ////////////////////////////////////////////////////////////////////////// -void NDArray::applyScalarArr(sd::scalar::IntOps op, const NDArray& scalar, NDArray &target, ExtraArguments *extraParams) const { - if (isS()) - throw std::runtime_error("NDArray::applyScalarArr IntOps: you can't use this method on String array!"); + void NDArray::applyScalarArr(sd::scalar::IntOps op, const NDArray& scalar, NDArray &target, ExtraArguments *extraParams) const { + if (isS()) + throw std::runtime_error("NDArray::applyScalarArr IntOps: you can't use this method on String array!"); - if (target.dataType() != this->dataType()) - throw std::invalid_argument("NDArray::applyScalarArr int method: target has not bool type!"); - if (dataType() != scalar.dataType()) { - nd4j_printf("NDArray::applyScalarArr IntOps: this dtype: [%i]; scalar dtype: [%i]\n", this->dataType(), scalar.dataType()); - throw std::invalid_argument("NDArray::applyScalarArr int method: this and scalar arrays must have the same type!"); + if (target.dataType() != this->dataType()) + throw std::invalid_argument("NDArray::applyScalarArr int method: target has not bool type!"); + if (dataType() != scalar.dataType()) { + nd4j_printf("NDArray::applyScalarArr IntOps: this dtype: [%i]; scalar dtype: [%i]\n", this->dataType(), scalar.dataType()); + throw std::invalid_argument("NDArray::applyScalarArr int method: this and scalar arrays must have the same type!"); + } + + NDArray::prepareSpecialUse({&target}, {this, &scalar}); + NativeOpExecutioner::execScalarInt(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), scalar.buffer(), scalar.shapeInfo(), scalar.specialBuffer(), scalar.specialShapeInfo(), extraParams != nullptr ? extraParams->argumentsAsT(target.dataType()): nullptr); + NDArray::registerSpecialUse({&target}, {this, &scalar}); } - NDArray::prepareSpecialUse({&target}, {this, &scalar}); - NativeOpExecutioner::execScalarInt(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), scalar.buffer(), scalar.shapeInfo(), scalar.specialBuffer(), scalar.specialShapeInfo(), extraParams != nullptr ? 
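applyScalarArr is the workhorse behind the scalar ops: the plain Ops variant requires the target dtype to be either the pairwise result type of the two operands or to match one of them, while the BoolOps variant insists on a BOOL target and same-typed operands. A usage sketch under the same assumptions as the earlier sketches; the enum value sd::scalar::Add and the scalar NDArrayFactory::create overload are assumptions, not shown in this diff.

    #include <array/NDArray.h>
    #include <array/NDArrayFactory.h>

    static void scalarArrDemo() {
        auto x      = NDArrayFactory::create<float>('c', {2, 2}, {1.f, 2.f, 3.f, 4.f});
        auto two    = NDArrayFactory::create<float>(2.f);          // length-1 array: passes the lengthOf() == 1 check
        auto target = NDArrayFactory::create<float>('c', {2, 2});  // same dtype as x: passes the type check

        // elementwise x + 2 into target; a non-scalar second operand
        // would throw std::invalid_argument
        x.applyScalarArr(sd::scalar::Add, two, target, nullptr);
    }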
extraParams->argumentsAsT(target.dataType()): nullptr); - NDArray::registerSpecialUse({&target}, {this, &scalar}); -} - //////////////////////////////////////////////////////////////////////// -template -void NDArray::applyScalar(sd::scalar::IntOps op, const T scalar, NDArray& target, ExtraArguments *extraParams) const { + template + void NDArray::applyScalar(sd::scalar::IntOps op, const T scalar, NDArray& target, ExtraArguments *extraParams) const { - NDArray scalarArr = NDArrayFactory::create(this->dataType(), scalar, getContext()); - applyScalarArr(op, scalarArr, target, extraParams); -} - -template <> ND4J_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const NDArray& scalar, NDArray &target, ExtraArguments *extraParams) const { throw std::runtime_error("NDArray::applyScalar method: do not use me!");} -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const double scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const float scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const float16 scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const bfloat16 scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const Nd4jLong scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const int scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const int16_t scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const int8_t scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const uint8_t scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const bool scalar, NDArray &target, ExtraArguments *extraParams) const; - -//////////////////////////////////////////////////////////////////////// -template -void NDArray::applyScalar(sd::scalar::Ops op, const T scalar, NDArray& target, ExtraArguments *extraParams) { - - auto scalarArr = NDArrayFactory::create(dataType(), scalar, this->getContext()); - applyScalarArr(op, scalarArr, target, extraParams); -} -template <> ND4J_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const NDArray& scalar, NDArray &target, ExtraArguments *extraParams) { throw std::runtime_error("NDArray::applyScalar method: do not use me!");} -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const double scalar, NDArray &target, ExtraArguments *extraParams); -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const float scalar, NDArray &target, ExtraArguments *extraParams); -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const float16 scalar, NDArray &target, ExtraArguments *extraParams); -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const bfloat16 scalar, NDArray &target, ExtraArguments *extraParams); -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const Nd4jLong scalar, NDArray &target, ExtraArguments *extraParams); -template ND4J_EXPORT void 
NDArray::applyScalar(sd::scalar::Ops op, const int scalar, NDArray &target, ExtraArguments *extraParams); -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const int16_t scalar, NDArray &target, ExtraArguments *extraParams); -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const int8_t scalar, NDArray &target, ExtraArguments *extraParams); -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const uint8_t scalar, NDArray &target, ExtraArguments *extraParams); -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const bool scalar, NDArray &target, ExtraArguments *extraParams); - -//////////////////////////////////////////////////////////////////////// -template -void NDArray::applyScalar(sd::scalar::BoolOps op, const T scalar, NDArray &target, ExtraArguments *extraParams) const { - - NDArray scalarArr = NDArrayFactory::create(scalar, getContext()); - applyScalarArr(op, scalarArr, target, extraParams); -} - -template <> ND4J_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const NDArray& scalar, NDArray &target, ExtraArguments *extraParams) const { throw std::runtime_error("NDArray::applyScalar method: do not use me!");} -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const double scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const float scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const float16 scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const bfloat16 scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const Nd4jLong scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const int scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const int16_t scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const int8_t scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const uint8_t scalar, NDArray &target, ExtraArguments *extraParams) const; -template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const bool scalar, NDArray &target, ExtraArguments *extraParams) const; - -//////////////////////////////////////////////////////////////////////// -void NDArray::applyIndexReduce(sd::indexreduce::Ops op, NDArray& target, const std::vector& dimensions, const ExtraArguments *extraParams) const { - if (isS()) - throw std::runtime_error("NDArray::applyIndexReduce: you can't use this method on String array!"); - - if (target.dataType() != sd::DataType::INT64 && target.dataType() != sd::DataType::INT32) - throw std::runtime_error("NDArray::applyIndexReduce operations return INT32/INT64"); - - void* params = extraParams != nullptr ? 
const_cast<ExtraArguments*>(extraParams)->argumentsAsT(this->dataType()) : nullptr;
-
-    NDArray::prepareSpecialUse({&target}, {this});
-
-    if (target.lengthOf() == 1) {
-        NativeOpExecutioner::execIndexReduceScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), params, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo());
-    }
-    else {
-        std::vector<int> copy = dimensions;
-        shape::checkDimensions(rankOf(), copy);
-        auto pDims = sd::Environment::getInstance().isCPU() ? copy.data() : nullptr;
-        auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(shapeInfo(), copy);
-        NativeOpExecutioner::execIndexReduce(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), params, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets());
-        synchronize("NDArray::applyIndexReduce");
+        NDArray scalarArr = NDArrayFactory::create(this->dataType(), scalar, getContext());
+        applyScalarArr(op, scalarArr, target, extraParams);
     }
-    registerSpecialUse({&target}, {this});
-}
+    template <> ND4J_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const NDArray& scalar, NDArray &target, ExtraArguments *extraParams) const { throw std::runtime_error("NDArray::applyScalar method: do not use me!");}
+    template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const double scalar, NDArray &target, ExtraArguments *extraParams) const;
+    template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const float scalar, NDArray &target, ExtraArguments *extraParams) const;
+    template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const float16 scalar, NDArray &target, ExtraArguments *extraParams) const;
+    template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const bfloat16 scalar, NDArray &target, ExtraArguments *extraParams) const;
+    template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const Nd4jLong scalar, NDArray &target, ExtraArguments *extraParams) const;
+    template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const int scalar, NDArray &target, ExtraArguments *extraParams) const;
+    template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const int16_t scalar, NDArray &target, ExtraArguments *extraParams) const;
+    template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const int8_t scalar, NDArray &target, ExtraArguments *extraParams) const;
+    template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const uint8_t scalar, NDArray &target, ExtraArguments *extraParams) const;
+    template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::IntOps op, const bool scalar, NDArray &target, ExtraArguments *extraParams) const;
+
+////////////////////////////////////////////////////////////////////////
+    template <typename T>
+    void NDArray::applyScalar(sd::scalar::Ops op, const T scalar, NDArray& target, ExtraArguments *extraParams) {
+
+        auto scalarArr = NDArrayFactory::create(dataType(), scalar, this->getContext());
+        applyScalarArr(op, scalarArr, target, extraParams);
+    }
+    template <> ND4J_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const NDArray& scalar, NDArray &target, ExtraArguments *extraParams) { throw std::runtime_error("NDArray::applyScalar method: do not use me!");}
+    template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const double scalar, NDArray &target, ExtraArguments *extraParams);
+    template ND4J_EXPORT void 
NDArray::applyScalar(sd::scalar::Ops op, const float scalar, NDArray &target, ExtraArguments *extraParams); + template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const float16 scalar, NDArray &target, ExtraArguments *extraParams); + template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const bfloat16 scalar, NDArray &target, ExtraArguments *extraParams); + template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const Nd4jLong scalar, NDArray &target, ExtraArguments *extraParams); + template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const int scalar, NDArray &target, ExtraArguments *extraParams); + template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const int16_t scalar, NDArray &target, ExtraArguments *extraParams); + template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const int8_t scalar, NDArray &target, ExtraArguments *extraParams); + template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const uint8_t scalar, NDArray &target, ExtraArguments *extraParams); + template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::Ops op, const bool scalar, NDArray &target, ExtraArguments *extraParams); + +//////////////////////////////////////////////////////////////////////// + template + void NDArray::applyScalar(sd::scalar::BoolOps op, const T scalar, NDArray &target, ExtraArguments *extraParams) const { + + NDArray scalarArr = NDArrayFactory::create(scalar, getContext()); + applyScalarArr(op, scalarArr, target, extraParams); + } + + template <> ND4J_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const NDArray& scalar, NDArray &target, ExtraArguments *extraParams) const { throw std::runtime_error("NDArray::applyScalar method: do not use me!");} + template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const double scalar, NDArray &target, ExtraArguments *extraParams) const; + template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const float scalar, NDArray &target, ExtraArguments *extraParams) const; + template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const float16 scalar, NDArray &target, ExtraArguments *extraParams) const; + template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const bfloat16 scalar, NDArray &target, ExtraArguments *extraParams) const; + template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const Nd4jLong scalar, NDArray &target, ExtraArguments *extraParams) const; + template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const int scalar, NDArray &target, ExtraArguments *extraParams) const; + template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const int16_t scalar, NDArray &target, ExtraArguments *extraParams) const; + template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const int8_t scalar, NDArray &target, ExtraArguments *extraParams) const; + template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const uint8_t scalar, NDArray &target, ExtraArguments *extraParams) const; + template ND4J_EXPORT void NDArray::applyScalar(sd::scalar::BoolOps op, const bool scalar, NDArray &target, ExtraArguments *extraParams) const; + +//////////////////////////////////////////////////////////////////////// + void NDArray::applyIndexReduce(sd::indexreduce::Ops op, NDArray& target, const std::vector& dimensions, const ExtraArguments *extraParams) const { + if (isS()) + throw std::runtime_error("NDArray::applyIndexReduce: you can't use this 
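The applyScalar<T> templates above wrap the primitive scalar in a temporary NDArray and forward to applyScalarArr. Two idioms are worth noting: the template <> specialization taking an NDArray exists only to throw, guarding against accidentally passing an array where a primitive scalar is expected, and the long run of explicit template instantiations keeps the template body in this translation unit while still letting every supported primitive type link against it. A usage sketch under the same assumptions as the earlier sketches; sd::scalar::Multiply is an assumed enum spelling.

    #include <array/NDArray.h>
    #include <array/NDArrayFactory.h>

    static void applyScalarDemo() {
        auto x      = NDArrayFactory::create<float>('c', {4}, {1.f, 2.f, 3.f, 4.f});
        auto target = NDArrayFactory::create<float>('c', {4});

        // T = float is deduced from the scalar argument; the explicit
        // instantiation above for "const float scalar" is what makes this link
        x.applyScalar(sd::scalar::Multiply, 2.f, target, nullptr);
    }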
method on String array!");
+
+        if (target.dataType() != sd::DataType::INT64 && target.dataType() != sd::DataType::INT32)
+            throw std::runtime_error("NDArray::applyIndexReduce operations return INT32/INT64");
+
+        void* params = extraParams != nullptr ? const_cast<ExtraArguments*>(extraParams)->argumentsAsT(this->dataType()) : nullptr;
+
+        NDArray::prepareSpecialUse({&target}, {this});
+
+        if (target.lengthOf() == 1) {
+            NativeOpExecutioner::execIndexReduceScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), params, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo());
+        }
+        else {
+            std::vector<int> copy = dimensions;
+            shape::checkDimensions(rankOf(), copy);
+            auto pDims = sd::Environment::getInstance().isCPU() ? copy.data() : nullptr;
+            auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(shapeInfo(), copy);
+            NativeOpExecutioner::execIndexReduce(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), params, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets());
+            synchronize("NDArray::applyIndexReduce");
+        }
+
+        registerSpecialUse({&target}, {this});
+    }

////////////////////////////////////////////////////////////////////////
// reduce dimensions in this array relying on index operations
-NDArray NDArray::applyIndexReduce(sd::indexreduce::Ops op, const std::vector<int>& dimensions, const ExtraArguments* extraParams ) const {
+    NDArray NDArray::applyIndexReduce(sd::indexreduce::Ops op, const std::vector<int>& dimensions, const ExtraArguments* extraParams ) const {

-    std::vector<int> copy = dimensions;
-    auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, DataType::INT64, false, false, getContext()->getWorkspace());
-    NDArray result(newShape, true, getContext());
+        std::vector<int> copy = dimensions;
+        auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, DataType::INT64, false, false, getContext()->getWorkspace());
+        NDArray result(newShape, true, getContext());

-    applyIndexReduce(op, result, copy, extraParams);
+        applyIndexReduce(op, result, copy, extraParams);

-    return result;
-}
+        return result;
+    }

////////////////////////////////////////////////////////////////////////
// apply reduce3 operations to this and other array, return result in new output array
-NDArray NDArray::applyReduce3(sd::reduce3::Ops op, const NDArray& other, const ExtraArguments* extraParams) const {
+    NDArray NDArray::applyReduce3(sd::reduce3::Ops op, const NDArray& other, const ExtraArguments* extraParams) const {

-    if (isS())
-        throw std::runtime_error("NDArray::applyReduce3 method: you can't use this method on String array!");
-    if(dataType() != other.dataType())
-        throw std::runtime_error("NDArray::applyReduce3 method: the types of this and other arrays must be the same !");
-    // check shapes consistency
-    if(!isSameShape(other))
-        throw std::runtime_error("NDArray::applyReduce3 method: the shapes of this and other arrays must be the same !");
-    // create shapeInfo for scalar
-    auto newShape = ShapeBuilders::createScalarShapeInfo(DataTypeUtils::pickFloatingType(dataType()), getContext()->getWorkspace());
-    // create output array (scalar)
-    NDArray result(newShape, true, getContext());
-    RELEASE(newShape, getContext()->getWorkspace());
-    // create dynamic array of extra parameters if array extraParams is empty (==nullptr)
-    void* params = extraParams != nullptr ? 
const_cast<ExtraArguments*>(extraParams)->argumentsAsT(dataType()) : nullptr;
+        if (isS())
+            throw std::runtime_error("NDArray::applyReduce3 method: you can't use this method on String array!");
+        if(dataType() != other.dataType())
+            throw std::runtime_error("NDArray::applyReduce3 method: the types of this and other arrays must be the same !");
+        // check shapes consistency
+        if(!isSameShape(other))
+            throw std::runtime_error("NDArray::applyReduce3 method: the shapes of this and other arrays must be the same !");
+        // create shapeInfo for scalar
+        auto newShape = ShapeBuilders::createScalarShapeInfo(DataTypeUtils::pickFloatingType(dataType()), getContext()->getWorkspace());
+        // create output array (scalar)
+        NDArray result(newShape, true, getContext());
+        RELEASE(newShape, getContext()->getWorkspace());
+        // create dynamic array of extra parameters if array extraParams is empty (==nullptr)
+        void* params = extraParams != nullptr ? const_cast<ExtraArguments*>(extraParams)->argumentsAsT(dataType()) : nullptr;

-    NDArray::prepareSpecialUse({&result}, {this, &other});
-    NativeOpExecutioner::execReduce3Scalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), params, other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo());
-    NDArray::registerSpecialUse({&result}, {this, &other});
+        NDArray::prepareSpecialUse({&result}, {this, &other});
+        NativeOpExecutioner::execReduce3Scalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), params, other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo());
+        NDArray::registerSpecialUse({&result}, {this, &other});

-    return result;
-}
+        return result;
+    }

////////////////////////////////////////////////////////////////////////
// apply reduce3 (exec) operations to this and other array, return result in new output array
-NDArray NDArray::applyReduce3(sd::reduce3::Ops op, const NDArray& other, const std::vector<int>& dimensions, const ExtraArguments* extraParams) const {
+    NDArray NDArray::applyReduce3(sd::reduce3::Ops op, const NDArray& other, const std::vector<int>& dimensions, const ExtraArguments* extraParams) const {

-    if (isS())
-        throw std::runtime_error("NDArray::applyReduce3: you can't use this method on String array!");
-    if(dataType() != other.dataType())
-        throw std::runtime_error("NDArray::applyReduce3 method: the types of this and other arrays must be the same !");
+        if (isS())
+            throw std::runtime_error("NDArray::applyReduce3: you can't use this method on String array!");
+        if(dataType() != other.dataType())
+            throw std::runtime_error("NDArray::applyReduce3 method: the types of this and other arrays must be the same !");

-    std::vector<int> copy(dimensions);
-    shape::checkDimensions(rankOf(), copy);
-    shape::checkDimensions(other.rankOf(), copy);
+        std::vector<int> copy(dimensions);
+        shape::checkDimensions(rankOf(), copy);
+        shape::checkDimensions(other.rankOf(), copy);

-    auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, DataTypeUtils::pickFloatingType(dataType()), false, false, getContext()->getWorkspace());
-    NDArray result(newShape, true, getContext());
-    // create temporary dynamic array of extra parameters if array extraParams is empty (==nullptr)
-    void* params = extraParams != nullptr ? 
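applyIndexReduce insists on an integer target (INT32/INT64) and dispatches either to the scalar executor or, for partial reductions, to the TAD-based execIndexReduce; the convenience overload allocates the INT64 result itself via evalReduceShapeInfo. A sketch of a per-row argmax, same assumptions as the earlier sketches; sd::indexreduce::IndexMax is an assumed enum spelling.

    #include <array/NDArray.h>
    #include <array/NDArrayFactory.h>

    static void argMaxRows() {
        auto x = NDArrayFactory::create<float>('c', {2, 3}, {1.f, 5.f, 2.f,
                                                             7.f, 0.f, 3.f});

        // index of the max element along dimension 1; the convenience
        // overload allocates an INT64 result of shape {2}
        auto idx = x.applyIndexReduce(sd::indexreduce::IndexMax, {1}, nullptr);
        // idx == [1, 0]
    }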
const_cast(extraParams)->argumentsAsT(dataType()) : nullptr; + auto newShape = ShapeUtils::evalReduceShapeInfo('c', copy, *this, DataTypeUtils::pickFloatingType(dataType()), false, false, getContext()->getWorkspace()); + NDArray result(newShape, true, getContext()); + // create temporary dynamic array of extra parameters if array extraParams is empty (==nullptr) + void* params = extraParams != nullptr ? const_cast(extraParams)->argumentsAsT(dataType()) : nullptr; - NDArray::prepareSpecialUse({&result}, {this, &other}); + NDArray::prepareSpecialUse({&result}, {this, &other}); - // perform calculations - if(rankOf() == copy.size() && other.rankOf() == copy.size()) { - NativeOpExecutioner::execReduce3Scalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), params, other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo()); + // perform calculations + if(rankOf() == copy.size() && other.rankOf() == copy.size()) { + NativeOpExecutioner::execReduce3Scalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), params, other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo()); + } + else { + + auto pDims = sd::Environment::getInstance().isCPU() ? copy.data() : nullptr; + + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(shapeInfo(), copy); + auto packY = sd::ConstantTadHelper::getInstance().tadForDimensions(other.shapeInfo(), copy); + + if(!shape::equalsSoft(packX.primaryShapeInfo(), packY.primaryShapeInfo()) || (packX.numberOfTads() != packY.numberOfTads() && packX.numberOfTads() != 1 && packY.numberOfTads() != 1)) + throw std::runtime_error("NDArray::applyReduce3 cuda method: arrays tads are inconsistent !"); + + NativeOpExecutioner::execReduce3(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), params, other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); + } + + registerSpecialUse({&result}, {this, &other}); + + return result; } - else { - - auto pDims = sd::Environment::getInstance().isCPU() ? 
copy.data() : nullptr; - - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(shapeInfo(), copy); - auto packY = sd::ConstantTadHelper::getInstance().tadForDimensions(other.shapeInfo(), copy); - - if(!shape::equalsSoft(packX.primaryShapeInfo(), packY.primaryShapeInfo()) || (packX.numberOfTads() != packY.numberOfTads() && packX.numberOfTads() != 1 && packY.numberOfTads() != 1)) - throw std::runtime_error("NDArray::applyReduce3 cuda method: arrays tads are inconsistent !"); - - NativeOpExecutioner::execReduce3(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), params, other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); - } - - registerSpecialUse({&result}, {this, &other}); - - return result; -} //////////////////////////////////////////////////////////////////////// // apply reduce3 (execAll) operations to this and other array, return result in new output array -NDArray NDArray::applyAllReduce3(sd::reduce3::Ops op, const NDArray& other, const std::vector& dimensions, const ExtraArguments* extraParams) const { - if (isS()) - throw std::runtime_error("NDArray::applyAllReduce3: you can't use this method on String array!"); - if(dataType() != other.dataType()) - throw std::runtime_error("NDArray::applyAllReduce3 method: the types of this and other arrays must be the same !"); + NDArray NDArray::applyAllReduce3(sd::reduce3::Ops op, const NDArray& other, const std::vector& dimensions, const ExtraArguments* extraParams) const { + if (isS()) + throw std::runtime_error("NDArray::applyAllReduce3: you can't use this method on String array!"); + if(dataType() != other.dataType()) + throw std::runtime_error("NDArray::applyAllReduce3 method: the types of this and other arrays must be the same !"); - // be careful, copy array may undergo changes (sort, transformation of negative dimensions to positive, duplicates removing ) - std::vector copy(dimensions); - shape::checkDimensions(rankOf(), copy); - shape::checkDimensions(other.rankOf(), copy); + // be careful, copy array may undergo changes (sort, transformation of negative dimensions to positive, duplicates removing ) + std::vector copy(dimensions); + shape::checkDimensions(rankOf(), copy); + shape::checkDimensions(other.rankOf(), copy); - auto packX = ConstantTadHelper::getInstance().tadForDimensions(shapeInfo(), copy); - auto packY = ConstantTadHelper::getInstance().tadForDimensions(other.shapeInfo(), copy); + auto packX = ConstantTadHelper::getInstance().tadForDimensions(shapeInfo(), copy); + auto packY = ConstantTadHelper::getInstance().tadForDimensions(other.shapeInfo(), copy); - // check tads shapes - if(!shape::equalsSoft(packX.primaryShapeInfo(), packY.primaryShapeInfo())) - throw std::runtime_error("NDArray::applyAllReduce3 method: the shapes of array tads are different !"); + // check tads shapes + if(!shape::equalsSoft(packX.primaryShapeInfo(), packY.primaryShapeInfo())) + throw std::runtime_error("NDArray::applyAllReduce3 method: the shapes of array tads are different !"); - // set newShape for output array - auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(DataTypeUtils::pickFloatingType(dataType()), 'c', {packX.numberOfTads(), packY.numberOfTads()}); + // set newShape for output array + auto newShape = 
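The dimensional applyReduce3 pairs up TADs from both inputs (falling back to the scalar executor when the reduction covers every axis) and checks that the two TAD shapes agree before calling execReduce3. A sketch of a per-row cosine similarity, same assumptions as the earlier sketches; sd::reduce3::CosineSimilarity is an assumed enum spelling.

    #include <array/NDArray.h>
    #include <array/NDArrayFactory.h>

    static void rowCosine() {
        auto x = NDArrayFactory::create<float>('c', {2, 3}, {1.f, 0.f, 0.f,  0.f, 1.f, 0.f});
        auto y = NDArrayFactory::create<float>('c', {2, 3}, {1.f, 0.f, 0.f,  0.f, 0.f, 1.f});

        // one reduce3 value per matching pair of row-TADs -> shape {2}
        auto sim = x.applyReduce3(sd::reduce3::CosineSimilarity, y, {1}, nullptr);
        // sim == [1.0, 0.0]
    }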
ConstantShapeHelper::getInstance().createShapeInfo(DataTypeUtils::pickFloatingType(dataType()), 'c', {packX.numberOfTads(), packY.numberOfTads()}); - // create output array - NDArray result(newShape, true, getContext()); + // create output array + NDArray result(newShape, true, getContext()); - // create dynamic array of extra parameters if array extraParams is empty (==nullptr) - void* params = extraParams != nullptr ? const_cast(extraParams)->argumentsAsT(dataType()) : nullptr; + // create dynamic array of extra parameters if array extraParams is empty (==nullptr) + void* params = extraParams != nullptr ? const_cast(extraParams)->argumentsAsT(dataType()) : nullptr; - auto pDims = sd::Environment::getInstance().isCPU() ? copy.data() : nullptr; + auto pDims = sd::Environment::getInstance().isCPU() ? copy.data() : nullptr; - NDArray::prepareSpecialUse({&result}, {this, &other}); - NativeOpExecutioner::execReduce3All(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), params, other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); - NDArray::registerSpecialUse({&result}, {this, &other}); + NDArray::prepareSpecialUse({&result}, {this, &other}); + NativeOpExecutioner::execReduce3All(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), params, other.buffer(), other.shapeInfo(), other.specialBuffer(), other.specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), pDims, copy.size(), packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); + NDArray::registerSpecialUse({&result}, {this, &other}); - return result; -} + return result; + } ////////////////////////////////////////////////////////////////////////// // method reduces array by excluding its shapes along axes present in dimensions vector -void NDArray::reduceAlongDimension(sd::reduce::FloatOps op, NDArray& target, const std::vector& dimensions, const bool keepDims, const bool checkTargetShape) const { + void NDArray::reduceAlongDimension(sd::reduce::FloatOps op, NDArray& target, const std::vector& dimensions, const bool keepDims, const bool checkTargetShape) const { - if (isS()) - throw std::runtime_error("NDArray::reduceAlongDimension FloatOps: you can't use this method on String array!"); - if (!target.isR()) - throw std::invalid_argument("NDArray::reduceAlongDimension FloatOps: requires target array to be present and have type form real space!"); + if (isS()) + throw std::runtime_error("NDArray::reduceAlongDimension FloatOps: you can't use this method on String array!"); + if (!target.isR()) + throw std::invalid_argument("NDArray::reduceAlongDimension FloatOps: requires target array to be present and have type form real space!"); - std::vector copy(dimensions); + std::vector copy(dimensions); - if(checkTargetShape) { - auto newShape = ShapeUtils::evalReduceShapeInfo(target.ordering(), copy, *this, keepDims, false, getContext()->getWorkspace()); - if(!shape::shapeEquals(newShape, target.shapeInfo())) - throw std::runtime_error("NDArray::reduceAlongDimension FloatOps: wrong target shape!"); - } - - NDArray::prepareSpecialUse({&target}, {this}); - - if(rankOf() == copy.size() || copy.empty()) { - NativeOpExecutioner::execReduceFloatScalar(getContext(), 
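Unlike applyReduce3, applyAllReduce3 evaluates the op for every pair of TADs across the two inputs, which is why the output shape above is built as {packX.numberOfTads(), packY.numberOfTads()}. A sketch, same assumptions as the earlier sketches; sd::reduce3::EuclideanDistance is an assumed enum spelling.

    #include <array/NDArray.h>
    #include <array/NDArrayFactory.h>

    static void allPairsDistances() {
        auto x = NDArrayFactory::create<float>('c', {2, 3});
        auto y = NDArrayFactory::create<float>('c', {4, 3});

        // every row-TAD of x against every row-TAD of y -> result shape {2, 4}
        auto d = x.applyAllReduce3(sd::reduce3::EuclideanDistance, y, {1}, nullptr);
    }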
op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(),nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); - } - else { - const Nd4jLong* zShapeInfoH = target.shapeInfo(); - const Nd4jLong* zShapeInfoD = target.specialShapeInfo(); - - if(rankOf() - dimensions.size() != target.rankOf()) { - auto zPack = ConstantShapeHelper::getInstance().createShapeInfoWithNoUnitiesForReduce(target.shapeInfo(), copy, target.getContext()->getWorkspace()); - zShapeInfoH = reinterpret_cast(zPack.primary()); - zShapeInfoD = reinterpret_cast(zPack.special()); + if(checkTargetShape) { + auto newShape = ShapeUtils::evalReduceShapeInfo(target.ordering(), copy, *this, keepDims, false, getContext()->getWorkspace()); + if(!shape::shapeEquals(newShape, target.shapeInfo())) + throw std::runtime_error("NDArray::reduceAlongDimension FloatOps: wrong target shape!"); } - std::vector dims = ShapeUtils::evalDimsForReduceOp(rankOf(), copy); - NativeOpExecutioner::execReduceFloat(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), zShapeInfoH, target.specialBuffer(), zShapeInfoD, dims.data(), dims.size()); + NDArray::prepareSpecialUse({&target}, {this}); + if(rankOf() == copy.size() || copy.empty()) { + NativeOpExecutioner::execReduceFloatScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(),nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); + } + else { + const Nd4jLong* zShapeInfoH = target.shapeInfo(); + const Nd4jLong* zShapeInfoD = target.specialShapeInfo(); + + if(rankOf() - dimensions.size() != target.rankOf()) { + auto zPack = ConstantShapeHelper::getInstance().createShapeInfoWithNoUnitiesForReduce(target.shapeInfo(), copy, target.getContext()->getWorkspace()); + zShapeInfoH = reinterpret_cast(zPack.primary()); + zShapeInfoD = reinterpret_cast(zPack.special()); + } + + std::vector dims = ShapeUtils::evalDimsForReduceOp(rankOf(), copy); + NativeOpExecutioner::execReduceFloat(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), zShapeInfoH, target.specialBuffer(), zShapeInfoD, dims.data(), dims.size()); + + } + synchronize("NDArray::reduceAlongDimension FloatOps"); + + NDArray::registerSpecialUse({&target}, {this}); } - synchronize("NDArray::reduceAlongDimension FloatOps"); - - NDArray::registerSpecialUse({&target}, {this}); -} ////////////////////////////////////////////////////////////////////////// // method reduces array by excluding its shapes along axes present in dimensions vector -void NDArray::reduceAlongDimension(sd::reduce::SameOps op, NDArray& target, const std::vector& dimensions, const bool keepDims, const bool checkTargetShape) const { + void NDArray::reduceAlongDimension(sd::reduce::SameOps op, NDArray& target, const std::vector& dimensions, const bool keepDims, const bool checkTargetShape) const { - if (isS()) - throw std::runtime_error("NDArray::reduceAlongDimension SameOps: you can't use this method on String array!"); - if (target.dataType() != dataType()) - throw std::runtime_error("NDArray::reduceAlongDimension SameOps: requires target array to be present and have same dtype as input"); + if (isS()) + throw std::runtime_error("NDArray::reduceAlongDimension SameOps: you can't use this method on String array!"); + if (target.dataType() != dataType()) + throw std::runtime_error("NDArray::reduceAlongDimension SameOps: requires target array to be present and have same dtype 
as input"); - std::vector copy(dimensions); + std::vector copy(dimensions); - if(checkTargetShape) { - auto newShape = ShapeUtils::evalReduceShapeInfo(target.ordering(), copy, *this, keepDims, false, getContext()->getWorkspace()); - if(!shape::shapeEquals(newShape, target.shapeInfo())) - throw std::runtime_error("NDArray::reduceAlongDimension SameOps: wrong target shape!"); - } - - NDArray::prepareSpecialUse({&target}, {this}); - - if(rankOf() == copy.size() || copy.empty()) { - NativeOpExecutioner::execReduceSameScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); - } - else { - - const Nd4jLong* zShapeInfoH = target.shapeInfo(); - const Nd4jLong* zShapeInfoD = target.specialShapeInfo(); - - if(rankOf() - dimensions.size() != target.rankOf()) { - auto zPack = ConstantShapeHelper::getInstance().createShapeInfoWithNoUnitiesForReduce(target.shapeInfo(), copy, target.getContext()->getWorkspace()); - zShapeInfoH = reinterpret_cast(zPack.primary()); - zShapeInfoD = reinterpret_cast(zPack.special()); + if(checkTargetShape) { + auto newShape = ShapeUtils::evalReduceShapeInfo(target.ordering(), copy, *this, keepDims, false, getContext()->getWorkspace()); + if(!shape::shapeEquals(newShape, target.shapeInfo())) + throw std::runtime_error("NDArray::reduceAlongDimension SameOps: wrong target shape!"); } - std::vector dims = ShapeUtils::evalDimsForReduceOp(rankOf(), copy); - NativeOpExecutioner::execReduceSame(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), zShapeInfoH, target.specialBuffer(), zShapeInfoD, dims.data(), dims.size()); - } - synchronize("NDArray::reduceAlongDimension SameOps"); + NDArray::prepareSpecialUse({&target}, {this}); - NDArray::registerSpecialUse({&target}, {this}); -} + if(rankOf() == copy.size() || copy.empty()) { + NativeOpExecutioner::execReduceSameScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); + } + else { + + const Nd4jLong* zShapeInfoH = target.shapeInfo(); + const Nd4jLong* zShapeInfoD = target.specialShapeInfo(); + + if(rankOf() - dimensions.size() != target.rankOf()) { + auto zPack = ConstantShapeHelper::getInstance().createShapeInfoWithNoUnitiesForReduce(target.shapeInfo(), copy, target.getContext()->getWorkspace()); + zShapeInfoH = reinterpret_cast(zPack.primary()); + zShapeInfoD = reinterpret_cast(zPack.special()); + } + + std::vector dims = ShapeUtils::evalDimsForReduceOp(rankOf(), copy); + NativeOpExecutioner::execReduceSame(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), zShapeInfoH, target.specialBuffer(), zShapeInfoD, dims.data(), dims.size()); + } + synchronize("NDArray::reduceAlongDimension SameOps"); + + NDArray::registerSpecialUse({&target}, {this}); + } ////////////////////////////////////////////////////////////////////////// // method reduces array by excluding its shapes along axes present in dimensions vector -void NDArray::reduceAlongDimension(sd::reduce::LongOps op, NDArray& target, const std::vector& dimensions, const bool keepDims, const bool checkTargetShape) const { + void NDArray::reduceAlongDimension(sd::reduce::LongOps op, NDArray& target, const std::vector& dimensions, const bool keepDims, const bool checkTargetShape) const { - if (isS()) - throw 
std::runtime_error("NDArray::reduceAlongDimension LongOps: you can't use this method on String array!"); - if (target.dataType() != DataType::INT64) - throw std::runtime_error("NDArray::reduceAlongDimension LongOps: requires target array to be present and have type of INT64"); + if (isS()) + throw std::runtime_error("NDArray::reduceAlongDimension LongOps: you can't use this method on String array!"); + if (target.dataType() != DataType::INT64) + throw std::runtime_error("NDArray::reduceAlongDimension LongOps: requires target array to be present and have type of INT64"); - std::vector copy(dimensions); + std::vector copy(dimensions); - if(checkTargetShape) { - auto newShape = ShapeUtils::evalReduceShapeInfo(target.ordering(), copy, *this, keepDims, false, getContext()->getWorkspace()); - if(!shape::shapeEquals(newShape, target.shapeInfo())) - throw std::runtime_error("NDArray::reduceAlongDimension LongOps: wrong target shape!"); - } - - NDArray::prepareSpecialUse({&target}, {this}); - - if(rankOf() == copy.size() || copy.empty()) { - NativeOpExecutioner::execReduceLongScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); - } - else { - const Nd4jLong* zShapeInfoH = target.shapeInfo(); - const Nd4jLong* zShapeInfoD = target.specialShapeInfo(); - - if(rankOf() - dimensions.size() != target.rankOf()) { - auto zPack = ConstantShapeHelper::getInstance().createShapeInfoWithNoUnitiesForReduce(target.shapeInfo(), copy, target.getContext()->getWorkspace()); - zShapeInfoH = reinterpret_cast(zPack.primary()); - zShapeInfoD = reinterpret_cast(zPack.special()); + if(checkTargetShape) { + auto newShape = ShapeUtils::evalReduceShapeInfo(target.ordering(), copy, *this, keepDims, false, getContext()->getWorkspace()); + if(!shape::shapeEquals(newShape, target.shapeInfo())) + throw std::runtime_error("NDArray::reduceAlongDimension LongOps: wrong target shape!"); } - std::vector dims = ShapeUtils::evalDimsForReduceOp(rankOf(), copy); - NativeOpExecutioner::execReduceLong(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), zShapeInfoH, target.specialBuffer(), zShapeInfoD, dims.data(), dims.size()); - } - synchronize("NDArray::reduceAlongDimension LongOps"); + NDArray::prepareSpecialUse({&target}, {this}); - NDArray::registerSpecialUse({&target}, {this}); -} + if(rankOf() == copy.size() || copy.empty()) { + NativeOpExecutioner::execReduceLongScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); + } + else { + const Nd4jLong* zShapeInfoH = target.shapeInfo(); + const Nd4jLong* zShapeInfoD = target.specialShapeInfo(); + + if(rankOf() - dimensions.size() != target.rankOf()) { + auto zPack = ConstantShapeHelper::getInstance().createShapeInfoWithNoUnitiesForReduce(target.shapeInfo(), copy, target.getContext()->getWorkspace()); + zShapeInfoH = reinterpret_cast(zPack.primary()); + zShapeInfoD = reinterpret_cast(zPack.special()); + } + + std::vector dims = ShapeUtils::evalDimsForReduceOp(rankOf(), copy); + NativeOpExecutioner::execReduceLong(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), zShapeInfoH, target.specialBuffer(), zShapeInfoD, dims.data(), dims.size()); + } + synchronize("NDArray::reduceAlongDimension LongOps"); + + 
NDArray::registerSpecialUse({&target}, {this}); + } ////////////////////////////////////////////////////////////////////////// // method reduces array by excluding its shapes along axes present in dimensions vector -void NDArray::reduceAlongDimension(sd::reduce::BoolOps op, NDArray& target, const std::vector& dimensions, const bool keepDims, const bool checkTargetShape) const { + void NDArray::reduceAlongDimension(sd::reduce::BoolOps op, NDArray& target, const std::vector& dimensions, const bool keepDims, const bool checkTargetShape) const { - if (isS()) - throw std::runtime_error("NDArray::reduceAlongDimension BoolOps cuda: you can't use this method on String array!"); - if (!target.isB()) - throw std::invalid_argument("NDArray::reduceAlongDimension BoolOps cuda: requires target array to be present and have BOOL type!"); + if (isS()) + throw std::runtime_error("NDArray::reduceAlongDimension BoolOps cuda: you can't use this method on String array!"); + if (!target.isB()) + throw std::invalid_argument("NDArray::reduceAlongDimension BoolOps cuda: requires target array to be present and have BOOL type!"); - std::vector copy(dimensions); + std::vector copy(dimensions); - if(checkTargetShape) { - auto newShape = ShapeUtils::evalReduceShapeInfo(target.ordering(), copy, *this, keepDims, false, getContext()->getWorkspace()); - if(!shape::shapeEquals(newShape, target.shapeInfo())) - throw std::runtime_error("NDArray::reduceAlongDimension BoolOps cuda: wrong target shape!"); - } - - NDArray::prepareSpecialUse({&target}, {this}); - - if(rankOf() == copy.size() || copy.empty()) { - NativeOpExecutioner::execReduceBoolScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); - } - else { - const Nd4jLong* zShapeInfoH = target.shapeInfo(); - const Nd4jLong* zShapeInfoD = target.specialShapeInfo(); - - if(rankOf() - dimensions.size() != target.rankOf()) { - auto zPack = ConstantShapeHelper::getInstance().createShapeInfoWithNoUnitiesForReduce(target.shapeInfo(), copy, target.getContext()->getWorkspace()); - zShapeInfoH = reinterpret_cast(zPack.primary()); - zShapeInfoD = reinterpret_cast(zPack.special()); + if(checkTargetShape) { + auto newShape = ShapeUtils::evalReduceShapeInfo(target.ordering(), copy, *this, keepDims, false, getContext()->getWorkspace()); + if(!shape::shapeEquals(newShape, target.shapeInfo())) + throw std::runtime_error("NDArray::reduceAlongDimension BoolOps cuda: wrong target shape!"); } - std::vector dims = ShapeUtils::evalDimsForReduceOp(rankOf(), copy); - NativeOpExecutioner::execReduceBool(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), zShapeInfoH, target.specialBuffer(), zShapeInfoD, dims.data(), dims.size()); - } - synchronize("NDArray::reduceAlongDimension LongOps"); + NDArray::prepareSpecialUse({&target}, {this}); - NDArray::registerSpecialUse({&target}, {this}); -} + if(rankOf() == copy.size() || copy.empty()) { + NativeOpExecutioner::execReduceBoolScalar(getContext(), op, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), nullptr, target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo()); + } + else { + const Nd4jLong* zShapeInfoH = target.shapeInfo(); + const Nd4jLong* zShapeInfoD = target.specialShapeInfo(); + + if(rankOf() - dimensions.size() != target.rankOf()) { + auto zPack = 
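The LongOps variant enforces an INT64 target and the BoolOps variant a BOOL one; otherwise both follow the same scalar-vs-TAD dispatch. A sketch of counting non-zeros per row, same assumptions as the earlier sketches; sd::reduce::CountNonZero is an assumed enum spelling.

    #include <array/NDArray.h>
    #include <array/NDArrayFactory.h>

    static void rowNonZeros() {
        auto x = NDArrayFactory::create<float>('c', {2, 3}, {0.f, 1.f, 2.f, 0.f, 0.f, 3.f});

        // LongOps variant: the target is required to be INT64
        NDArray nnz('c', {2}, sd::DataType::INT64, x.getContext());
        x.reduceAlongDimension(sd::reduce::CountNonZero, nnz, {1}, false, true);
        // nnz == [2, 1]
    }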
//////////////////////////////////////////////////////////////////////////
// This method sets value in linear buffer to position i
-template <typename T>
-void NDArray::p(const Nd4jLong i, const T value) {
+    template <typename T>
+    void NDArray::p(const Nd4jLong i, const T value) {

-    if (i >= lengthOf())
-        throw std::invalid_argument("NDArray::p(i, value): input index is out of array length !");
+        if (i >= lengthOf())
+            throw std::invalid_argument("NDArray::p(i, value): input index is out of array length !");

-    auto rp = getOffset(i);
-    const void *pV = reinterpret_cast<const void *>(const_cast<T *>(&value));
+        auto rp = getOffset(i);
+        const void *pV = reinterpret_cast<const void *>(const_cast<T *>(&value));

-    NDArray::preparePrimaryUse({this}, {}, true);
-    BUILD_SINGLE_PARTIAL_SELECTOR(this->dataType(), templatedSet<, T>(this->buffer(), rp, pV), LIBND4J_TYPES);
-    NDArray::registerPrimaryUse({this}, {});
-}
+        NDArray::preparePrimaryUse({this}, {}, true);
+        BUILD_SINGLE_PARTIAL_SELECTOR(this->dataType(), templatedSet<, T>(this->buffer(), rp, pV), LIBND4J_TYPES);
+        NDArray::registerPrimaryUse({this}, {});
+    }

-template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const double value);
-template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const float value);
-template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const float16 value);
-template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const bfloat16 value);
-template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong value);
-template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const int value);
-template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const int8_t value);
-template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const uint8_t value);
-template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const uint16_t value);
-template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const uint32_t value);
-template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const uint64_t value);
-template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const int16_t value);
-template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const bool value);
+    template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const double value);
+    template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const float value);
+    template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const float16 value);
+    template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const bfloat16 value);
+    template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong value);
+    template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const int value);
+    template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const int8_t value);
+    template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const uint8_t value);
+    template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const uint16_t value);
+    template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const uint32_t value);
+    template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const uint64_t value);
+    template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const int16_t value);
+    template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const bool value);

//////////////////////////////////////////////////////////////////////////
// This method sets value in 2D matrix to position i, j
-template <typename T>
-void NDArray::p(const Nd4jLong i, const Nd4jLong j, const T value) {
+    template <typename T>
+    void NDArray::p(const Nd4jLong i, const Nd4jLong j, const T value) {

-    if (rankOf() != 2 || i >= shapeOf()[0] || j >= shapeOf()[1])
-        throw std::invalid_argument("NDArray:pe(i,j, value): one of input indexes is out of array length or rank!=2 !");
+        if (rankOf() != 2 || i >= shapeOf()[0] || j >= shapeOf()[1])
+            throw std::invalid_argument("NDArray::p(i,j, value): one of input indexes is out of array length or rank!=2 !");

-    void *p = reinterpret_cast<void *>(const_cast<T *>(&value));
-    auto xOffset = i * strideAt(0) + j * strideAt(1);
+        void *p = reinterpret_cast<void *>(const_cast<T *>(&value));
+        auto xOffset = i * strideAt(0) + j * strideAt(1);

-    NDArray::preparePrimaryUse({this}, {}, true);
-    BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->buffer(), xOffset, p), LIBND4J_TYPES);
-    NDArray::registerPrimaryUse({this}, {});
-}
-template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const double value);
-template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const float value);
-template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const float16 value);
-template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const bfloat16 value);
-template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong value);
-template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const int value);
-template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const int8_t value);
-template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const uint8_t value);
-template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const uint16_t value);
-template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const uint32_t value);
-template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const uint64_t value);
-template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const int16_t value);
-template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const bool value);
+        NDArray::preparePrimaryUse({this}, {}, true);
+        BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->buffer(), xOffset, p), LIBND4J_TYPES);
+        NDArray::registerPrimaryUse({this}, {});
+    }
+    template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const double value);
+    template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const float value);
+    template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const float16 value);
+    template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const bfloat16 value);
+    template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong value);
+    template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const int value);
+    template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const int8_t value);
+    template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const uint8_t value);
+    template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const uint16_t value);
+    template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const uint32_t value);
+    template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const uint64_t value);
+    template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const int16_t value);
+    template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const bool value);

//////////////////////////////////////////////////////////////////////////
// This method sets value in 3D matrix to position i,j,k
-template <typename T>
-void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const T value) {
-    //(*this)(i,j,k) = value;
-    if (rankOf() != 3 || i >= shapeOf()[0] || j >= shapeOf()[1] || k >= shapeOf()[2])
-        throw std::invalid_argument("NDArray:pe(i,j,k, value): one of input indexes is out of array length or rank!=3 !");
+    template <typename T>
+    void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const T value) {
+        //(*this)(i,j,k) = value;
+        if (rankOf() != 3 || i >= shapeOf()[0] || j >= shapeOf()[1] || k >= shapeOf()[2])
+            throw std::invalid_argument("NDArray::p(i,j,k, value): one of input indexes is out of array length or rank!=3 !");

-    void *p = reinterpret_cast<void *>(const_cast<T *>(&value));
-    auto xOffset = i * strideAt(0) + j * strideAt(1) + k * strideAt(2);
+        void *p = reinterpret_cast<void *>(const_cast<T *>(&value));
+        auto xOffset = i * strideAt(0) + j * strideAt(1) + k * strideAt(2);

-    NDArray::preparePrimaryUse({this}, {}, true);
-    BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->buffer(), xOffset, p), LIBND4J_TYPES);
-    NDArray::registerPrimaryUse({this}, {});
-}
-template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const double value);
-template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const float value);
-template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const float16 value);
-template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const bfloat16 value);
-template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong value);
-template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const int value);
-template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const int8_t value);
-template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const uint8_t value);
-template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const uint16_t value);
-template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const uint32_t value);
-template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const uint64_t value);
-template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const int16_t value);
-template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const bool value);
+        NDArray::preparePrimaryUse({this}, {}, true);
+        BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->buffer(), xOffset, p), LIBND4J_TYPES);
+        NDArray::registerPrimaryUse({this}, {});
+    }
+    template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const double value);
+    template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const float value);
+    template ND4J_EXPORT 
void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const float16 value); + template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const bfloat16 value); + template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong value); + template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const int value); + template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const int8_t value); + template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const uint8_t value); + template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const uint16_t value); + template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const uint32_t value); + template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const uint64_t value); + template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const int16_t value); + template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const bool value); ////////////////////////////////////////////////////////////////////////// -template -void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const T value) { - //(*this)(i,j,k) = value; - if (rankOf() != 4 || i >= shapeOf()[0] || j >= shapeOf()[1] || k >= shapeOf()[2] || l >= shapeOf()[3]) - throw std::invalid_argument("NDArray::p(i,j,k,l, value): one of input indexes is out of array length or rank!=4 !"); + template + void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const T value) { + //(*this)(i,j,k) = value; + if (rankOf() != 4 || i >= shapeOf()[0] || j >= shapeOf()[1] || k >= shapeOf()[2] || l >= shapeOf()[3]) + throw std::invalid_argument("NDArray::p(i,j,k,l, value): one of input indexes is out of array length or rank!=4 !"); - void *p = reinterpret_cast(const_cast(&value)); - auto xOffset = i * strideAt(0) + j * strideAt(1) + k * strideAt(2) + l * strideAt(3); + void *p = reinterpret_cast(const_cast(&value)); + auto xOffset = i * strideAt(0) + j * strideAt(1) + k * strideAt(2) + l * strideAt(3); - NDArray::preparePrimaryUse({this}, {}, true); - BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->buffer(), xOffset, p), LIBND4J_TYPES); - NDArray::registerPrimaryUse({this}, {}); -} -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const double value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const float value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const float16 value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const bfloat16 value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const Nd4jLong value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const int value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const int8_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong 
l, const uint8_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const uint16_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const uint32_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const uint64_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const int16_t value); -template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const bool value); + NDArray::preparePrimaryUse({this}, {}, true); + BUILD_SINGLE_PARTIAL_SELECTOR(dataType(), templatedSet<, T>(this->buffer(), xOffset, p), LIBND4J_TYPES); + NDArray::registerPrimaryUse({this}, {}); + } + template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const double value); + template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const float value); + template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const float16 value); + template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const bfloat16 value); + template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const Nd4jLong value); + template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const int value); + template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const int8_t value); + template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const uint8_t value); + template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const uint16_t value); + template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const uint32_t value); + template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const uint64_t value); + template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const int16_t value); + template ND4J_EXPORT void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const bool value); //////////////////////////////////////////////////////////////////////// -void NDArray::p(const Nd4jLong i, const NDArray& scalar) { + void NDArray::p(const Nd4jLong i, const NDArray& scalar) { - if(scalar.lengthOf() != 1) - throw std::invalid_argument("NDArray::p method: input array must be scalar!"); - if (i >= _length) - throw std::invalid_argument("NDArray::p(i, NDArray_scalar): input index is out of array length !"); + if(scalar.lengthOf() != 1) + throw std::invalid_argument("NDArray::p method: input array must be scalar!"); + if (i >= _length) + throw std::invalid_argument("NDArray::p(i, NDArray_scalar): input index is out of array length !"); - NDArray::preparePrimaryUse({this}, {&scalar}, true); - auto rp = getOffset(i); - BUILD_SINGLE_SELECTOR(scalar.dataType(), templatedSet, (buffer(), rp, scalar.dataType(), scalar.buffer()), LIBND4J_TYPES); - NDArray::registerPrimaryUse({this}, {&scalar}); -} + 
NDArray::preparePrimaryUse({this}, {&scalar}, true); + auto rp = getOffset(i); + BUILD_SINGLE_SELECTOR(scalar.dataType(), templatedSet, (buffer(), rp, scalar.dataType(), scalar.buffer()), LIBND4J_TYPES); + NDArray::registerPrimaryUse({this}, {&scalar}); + } //////////////////////////////////////////////////////////////////////// void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const NDArray& scalar) { @@ -4552,428 +4552,488 @@ void NDArray::p(const Nd4jLong i, const NDArray& scalar) { } ////////////////////////////////////////////////////////////////////////// -void NDArray::addRowVector(const NDArray& row, NDArray& target) const { + void NDArray::addRowVector(const NDArray& row, NDArray& target) const { - if (isS()) - throw std::runtime_error("NDArray::addRowVector: you can't use this method on String array!"); - if (rankOf() != 2 || target.rankOf() != 2 || rows() != target.rows() || columns() != target.columns() || !row.isRowVector() || columns() != row.lengthOf()) - throw std::invalid_argument("NDArray::addRowVector: wrong arguments !"); - if(target.dataType() != DataTypeUtils::pickPairwiseResultType(dataType(), row.dataType()) && !(isR() && row.isR() && target.isR())) - throw std::invalid_argument("NDArray::addRowVector: wrong type of target array !"); + if (isS()) + throw std::runtime_error("NDArray::addRowVector: you can't use this method on String array!"); + if (rankOf() != 2 || target.rankOf() != 2 || rows() != target.rows() || columns() != target.columns() || !row.isRowVector() || columns() != row.lengthOf()) + throw std::invalid_argument("NDArray::addRowVector: wrong arguments !"); + if(target.dataType() != DataTypeUtils::pickPairwiseResultType(dataType(), row.dataType()) && !(isR() && row.isR() && target.isR())) + throw std::invalid_argument("NDArray::addRowVector: wrong type of target array !"); - int dimension = 1; + int dimension = 1; - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), dimension); + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), dimension); - NDArray::prepareSpecialUse({&target}, {this, &row}); - NativeOpExecutioner::execBroadcast(getContext(), sd::broadcast::Ops::Add, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), row.buffer(), row.shapeInfo(), row.specialBuffer(), row.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); - NDArray::registerSpecialUse({&target}, {this, &row}); -} + NDArray::prepareSpecialUse({&target}, {this, &row}); + NativeOpExecutioner::execBroadcast(getContext(), sd::broadcast::Ops::Add, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), row.buffer(), row.shapeInfo(), row.specialBuffer(), row.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); + NDArray::registerSpecialUse({&target}, {this, &row}); + } ////////////////////////////////////////////////////////////////////////// -void NDArray::subRowVector(const NDArray& row, NDArray& target) const { + void NDArray::subRowVector(const NDArray& row, NDArray& target) const { - if (isS()) - throw std::runtime_error("NDArray::addRowVector: you can't use this method on String array!"); - if (rankOf() != 2 || target.rankOf() != 2 || rows() != target.rows() || columns() != target.columns() || 
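// NOTE (editor): a small usage sketch, not part of the diff, for the row-vector
// broadcast helpers in this block; shapes and factory calls are assumed.
//
//     auto m   = NDArrayFactory::create<float>('c', {2, 3});                // 2x3 matrix
//     auto row = NDArrayFactory::create<float>('c', {1, 3}, {1.f, 2.f, 3.f});
//     auto out = NDArrayFactory::create<float>('c', {2, 3});
//     m.addRowVector(row, out);   // out(i, j) = m(i, j) + row(j)
//     m.addiRowVector(row);       // in-place variant, defined further below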
//////////////////////////////////////////////////////////////////////////
-void NDArray::subRowVector(const NDArray& row, NDArray& target) const {

-    if (isS())
-        throw std::runtime_error("NDArray::addRowVector: you can't use this method on String array!");
-    if (rankOf() != 2 || target.rankOf() != 2 || rows() != target.rows() || columns() != target.columns() || !row.isRowVector() || columns() != row.lengthOf())
-        throw std::invalid_argument("NDArray::addRowVector: wrong arguments !");
-    if(target.dataType() != DataTypeUtils::pickPairwiseResultType(dataType(), row.dataType()) && !(isR() && row.isR() && target.isR()))
-        throw std::invalid_argument("NDArray::addRowVector: wrong type of target array !");
+    void NDArray::subRowVector(const NDArray& row, NDArray& target) const {

+        if (isS())
+            throw std::runtime_error("NDArray::subRowVector: you can't use this method on String array!");
+        if (rankOf() != 2 || target.rankOf() != 2 || rows() != target.rows() || columns() != target.columns() || !row.isRowVector() || columns() != row.lengthOf())
+            throw std::invalid_argument("NDArray::subRowVector: wrong arguments !");
+        if(target.dataType() != DataTypeUtils::pickPairwiseResultType(dataType(), row.dataType()) && !(isR() && row.isR() && target.isR()))
+            throw std::invalid_argument("NDArray::subRowVector: wrong type of target array !");

-    int dimension = 1;
+        int dimension = 1;

-    auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), dimension);
+        auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), dimension);

-    NDArray::prepareSpecialUse({&target}, {this, &row});
-    NativeOpExecutioner::execBroadcast(getContext(), sd::broadcast::Ops::Subtract, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), row.buffer(), row.shapeInfo(), row.specialBuffer(), row.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), &dimension, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr);
-    NDArray::registerSpecialUse({&target}, {this, &row});
-}
+        NDArray::prepareSpecialUse({&target}, {this, &row});
+        NativeOpExecutioner::execBroadcast(getContext(), sd::broadcast::Ops::Subtract, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), row.buffer(), row.shapeInfo(), row.specialBuffer(), row.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), &dimension, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr);
+        NDArray::registerSpecialUse({&target}, {this, &row});
+    }

//////////////////////////////////////////////////////////////////////////
-void NDArray::mulRowVector(const NDArray &row, NDArray &target) const {

-    if (isS())
-        throw std::runtime_error("NDArray::mulRowVector: you can't use this method on String array!");
-    if (rankOf() != 2 || target.rankOf() != 2 || rows() != target.rows() || columns() != target.columns() || !row.isRowVector() || columns() != row.columns())
-        throw std::invalid_argument("NDArray::divRowVector: wrong arguments !");
-    if(target.dataType() != DataTypeUtils::pickPairwiseResultType(dataType(), row.dataType()))
-        throw std::invalid_argument("NDArray::mulRowVector: wrong type of target array !");
+    void NDArray::mulRowVector(const NDArray &row, NDArray &target) const {

+        if (isS())
+            throw std::runtime_error("NDArray::mulRowVector: you can't use this method on String array!");
+        if (rankOf() != 2 || target.rankOf() != 2 || rows() != target.rows() || columns() != target.columns() || !row.isRowVector() || columns() != row.columns())
+            throw std::invalid_argument("NDArray::mulRowVector: wrong arguments !");
+        if(target.dataType() != DataTypeUtils::pickPairwiseResultType(dataType(), row.dataType()))
+            throw std::invalid_argument("NDArray::mulRowVector: wrong type of target array !");

-    int dimension = 1;
+        int dimension = 1;

-    auto packX = 
sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), dimension); + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), dimension); - NDArray::prepareSpecialUse({&target}, {this, &row}); - NativeOpExecutioner::execBroadcast(getContext(), sd::broadcast::Ops::Multiply, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), row.buffer(), row.shapeInfo(), row.specialBuffer(), row.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); - NDArray::registerSpecialUse({&target}, {this, &row}); -} + NDArray::prepareSpecialUse({&target}, {this, &row}); + NativeOpExecutioner::execBroadcast(getContext(), sd::broadcast::Ops::Multiply, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), row.buffer(), row.shapeInfo(), row.specialBuffer(), row.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); + NDArray::registerSpecialUse({&target}, {this, &row}); + } ////////////////////////////////////////////////////////////////////////// -void NDArray::divRowVector(const NDArray &row, NDArray &target) const { + void NDArray::divRowVector(const NDArray &row, NDArray &target) const { - if (isS()) - throw std::runtime_error("NDArray::divRowVector: you can't use this method on String array!"); - if (row.isB()) - throw std::runtime_error("NDArray::divRowVector: you can't divide by bool row!"); - if (rankOf() != 2 || target.rankOf() != 2 || rows() != target.rows() || columns() != target.columns() || !row.isRowVector() || columns() != row.columns()) - throw std::invalid_argument("NDArray::divRowVector: wrong arguments !"); - if(target.dataType() != DataTypeUtils::pickPairwiseResultType(dataType(), row.dataType())) - throw std::invalid_argument("NDArray::divRowVector: wrong type of target array !"); + if (isS()) + throw std::runtime_error("NDArray::divRowVector: you can't use this method on String array!"); + if (row.isB()) + throw std::runtime_error("NDArray::divRowVector: you can't divide by bool row!"); + if (rankOf() != 2 || target.rankOf() != 2 || rows() != target.rows() || columns() != target.columns() || !row.isRowVector() || columns() != row.columns()) + throw std::invalid_argument("NDArray::divRowVector: wrong arguments !"); + if(target.dataType() != DataTypeUtils::pickPairwiseResultType(dataType(), row.dataType())) + throw std::invalid_argument("NDArray::divRowVector: wrong type of target array !"); - int dimension = 1; + int dimension = 1; - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), dimension); + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), dimension); - NDArray::prepareSpecialUse({&target}, {this, &row}); - NativeOpExecutioner::execBroadcast(getContext(), sd::broadcast::Divide, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), row.buffer(), row.shapeInfo(), row.specialBuffer(), row.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); - NDArray::registerSpecialUse({&target}, {this, &row}); -} + NDArray::prepareSpecialUse({&target}, {this, &row}); + NativeOpExecutioner::execBroadcast(getContext(), sd::broadcast::Divide, buffer(), shapeInfo(), 
specialBuffer(), specialShapeInfo(), row.buffer(), row.shapeInfo(), row.specialBuffer(), row.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); + NDArray::registerSpecialUse({&target}, {this, &row}); + } ////////////////////////////////////////////////////////////////////////// // This method adds given row to all rows in this NDArray, this array becomes affected -void NDArray::addiRowVector(const NDArray& row) { + void NDArray::addiRowVector(const NDArray& row) { - if (isS()) - throw std::runtime_error("NDArray::addiRowVector: you can't use this method on String array!"); - if (rankOf() != 2 || !row.isRowVector() || columns() != row.lengthOf()) - throw std::invalid_argument("NDArray::addiRowVector: wrong arguments !"); + if (isS()) + throw std::runtime_error("NDArray::addiRowVector: you can't use this method on String array!"); + if (rankOf() != 2 || !row.isRowVector() || columns() != row.lengthOf()) + throw std::invalid_argument("NDArray::addiRowVector: wrong arguments !"); - int dimension = 1; + int dimension = 1; - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), dimension); + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), dimension); - NDArray::prepareSpecialUse({this}, {&row}); - NativeOpExecutioner::execBroadcast(getContext(), sd::broadcast::Ops::Add, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), row.buffer(), row.shapeInfo(), row.specialBuffer(), row.specialShapeInfo(), this->buffer(), this->shapeInfo(), this->specialBuffer(), this->specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); - NDArray::registerSpecialUse({this}, {&row}); -} + NDArray::prepareSpecialUse({this}, {&row}); + NativeOpExecutioner::execBroadcast(getContext(), sd::broadcast::Ops::Add, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), row.buffer(), row.shapeInfo(), row.specialBuffer(), row.specialShapeInfo(), this->buffer(), this->shapeInfo(), this->specialBuffer(), this->specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); + NDArray::registerSpecialUse({this}, {&row}); + } ////////////////////////////////////////////////////////////////////////// -void NDArray::addColumnVector(const NDArray &column, NDArray &target) const { - if (isS()) - throw std::runtime_error("NDArray::addColumnVector: you can't use this method on String array!"); - if (rankOf() != 2 || target.rankOf() != 2 || rows() != target.rows() || columns() != target.columns() || !column.isColumnVector() || rows() != column.lengthOf()) - throw std::invalid_argument("NDArray::addColumnVector: wrong arguments !"); - if(target.dataType() != DataTypeUtils::pickPairwiseResultType(dataType(), column.dataType())) - throw std::invalid_argument("NDArray::addColumnVector: wrong type of target array !"); + void NDArray::addColumnVector(const NDArray &column, NDArray &target) const { + if (isS()) + throw std::runtime_error("NDArray::addColumnVector: you can't use this method on String array!"); + if (rankOf() != 2 || target.rankOf() != 2 || rows() != target.rows() || columns() != target.columns() || !column.isColumnVector() || rows() != column.lengthOf()) + throw std::invalid_argument("NDArray::addColumnVector: wrong arguments !"); + if(target.dataType() != DataTypeUtils::pickPairwiseResultType(dataType(), column.dataType())) + 
throw std::invalid_argument("NDArray::addColumnVector: wrong type of target array !"); - int dimension = 0; + int dimension = 0; - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), dimension); + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), dimension); - NDArray::prepareSpecialUse({&target}, {this, &column}); - NativeOpExecutioner::execBroadcast(getContext(), sd::broadcast::Ops::Add, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), column.buffer(), column.shapeInfo(), column.specialBuffer(), column.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); - NDArray::registerSpecialUse({&target}, {this, &column}); -} + NDArray::prepareSpecialUse({&target}, {this, &column}); + NativeOpExecutioner::execBroadcast(getContext(), sd::broadcast::Ops::Add, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), column.buffer(), column.shapeInfo(), column.specialBuffer(), column.specialShapeInfo(), target.buffer(), target.shapeInfo(), target.specialBuffer(), target.specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); + NDArray::registerSpecialUse({&target}, {this, &column}); + } ////////////////////////////////////////////////////////////////////////// // This method adds given column to all columns in this NDArray, this array becomes affected -void NDArray::addiColumnVector(const NDArray &column) { - if (isS()) - throw std::runtime_error("NDArray::addiColumnVector: you can't use this method on String array!"); - if (rankOf() != 2 || !column.isColumnVector() || rows() != column.lengthOf()) - throw std::invalid_argument("NDArray::addiColumnVector: wrong arguments !"); + void NDArray::addiColumnVector(const NDArray &column) { + if (isS()) + throw std::runtime_error("NDArray::addiColumnVector: you can't use this method on String array!"); + if (rankOf() != 2 || !column.isColumnVector() || rows() != column.lengthOf()) + throw std::invalid_argument("NDArray::addiColumnVector: wrong arguments !"); - int dimension = 0; + int dimension = 0; - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), dimension); + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), dimension); - NDArray::prepareSpecialUse({this}, {&column}); - NativeOpExecutioner::execBroadcast(getContext(), sd::broadcast::Ops::Add, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), column.buffer(), column.shapeInfo(), column.specialBuffer(), column.specialShapeInfo(), this->buffer(), this->shapeInfo(), this->specialBuffer(), this->specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); - NDArray::registerSpecialUse({this}, {&column}); -} + NDArray::prepareSpecialUse({this}, {&column}); + NativeOpExecutioner::execBroadcast(getContext(), sd::broadcast::Ops::Add, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), column.buffer(), column.shapeInfo(), column.specialBuffer(), column.specialShapeInfo(), this->buffer(), this->shapeInfo(), this->specialBuffer(), this->specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); + NDArray::registerSpecialUse({this}, {&column}); + } ////////////////////////////////////////////////////////////////////////// // This method multiplies each column of this array by given 
argument-column, this array becomes affected -void NDArray::muliColumnVector(const NDArray& column) { - if (isS()) - throw std::runtime_error("NDArray::muliColumnVector: you can't use this method on String array!"); - if (rankOf() != 2 || !column.isColumnVector() || rows() != column.lengthOf()) - throw std::invalid_argument("NDArray::muliColumnVector: wrong arguments !"); + void NDArray::muliColumnVector(const NDArray& column) { + if (isS()) + throw std::runtime_error("NDArray::muliColumnVector: you can't use this method on String array!"); + if (rankOf() != 2 || !column.isColumnVector() || rows() != column.lengthOf()) + throw std::invalid_argument("NDArray::muliColumnVector: wrong arguments !"); - int dimension = 0; + int dimension = 0; - auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), dimension); + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(this->shapeInfo(), dimension); - NDArray::prepareSpecialUse({this}, {&column}); - NativeOpExecutioner::execBroadcast(getContext(), sd::broadcast::Ops::Multiply, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), column.buffer(), column.shapeInfo(), column.specialBuffer(), column.specialShapeInfo(), this->buffer(), this->shapeInfo(), this->specialBuffer(), this->specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); - NDArray::registerSpecialUse({this}, {&column}); -} + NDArray::prepareSpecialUse({this}, {&column}); + NativeOpExecutioner::execBroadcast(getContext(), sd::broadcast::Ops::Multiply, buffer(), shapeInfo(), specialBuffer(), specialShapeInfo(), column.buffer(), column.shapeInfo(), column.specialBuffer(), column.specialShapeInfo(), this->buffer(), this->shapeInfo(), this->specialBuffer(), this->specialShapeInfo(), nullptr, 1, packX.platformShapeInfo(), packX.platformOffsets(), nullptr, nullptr); + NDArray::registerSpecialUse({this}, {&column}); + } ////////////////////////////////////////////////////////////////////////// -template -void NDArray::templatedAssign(void *xBuffer, Nd4jLong xOffset, const void *yBuffer, const Nd4jLong yOffset) const { - if (xBuffer != nullptr && yBuffer != nullptr) - *(reinterpret_cast(xBuffer) + xOffset) = *(reinterpret_cast(yBuffer) + yOffset); -} -BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void NDArray::templatedAssign, (void *xBuffer, const Nd4jLong xOffset, const void *yBuffer, const Nd4jLong yOffset) const, LIBND4J_TYPES); + template + void NDArray::templatedAssign(void *xBuffer, Nd4jLong xOffset, const void *yBuffer, const Nd4jLong yOffset) const { + if (xBuffer != nullptr && yBuffer != nullptr) + *(reinterpret_cast(xBuffer) + xOffset) = *(reinterpret_cast(yBuffer) + yOffset); + } + BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void NDArray::templatedAssign, (void *xBuffer, const Nd4jLong xOffset, const void *yBuffer, const Nd4jLong yOffset) const, LIBND4J_TYPES); ////////////////////////////////////////////////////////////////////////// -bool NDArray::permutei(const int* dimensions, const int rank) { + bool NDArray::permutei(const int* dimensions, const int rank) { - auto shapeInfo = ShapeUtils::evalPermShapeInfo(dimensions, rank, *this, getContext()->getWorkspace()); - setShapeInfo(shapeInfo); + auto shapeInfo = ShapeUtils::evalPermShapeInfo(dimensions, rank, *this, getContext()->getWorkspace()); + setShapeInfo(shapeInfo); - return true; -} + return true; + } ////////////////////////////////////////////////////////////////////////// -bool NDArray::permutei(const Nd4jLong* dimensions, const 
int rank) { + bool NDArray::permutei(const Nd4jLong* dimensions, const int rank) { - auto shapeInfo = ShapeUtils::evalPermShapeInfo(dimensions, rank, *this, getContext()->getWorkspace()); - setShapeInfo(shapeInfo); + auto shapeInfo = ShapeUtils::evalPermShapeInfo(dimensions, rank, *this, getContext()->getWorkspace()); + setShapeInfo(shapeInfo); - return true; -} + return true; + } //////////////////////////////////////////////////////////////////////// -ResultSet NDArray::multipleTensorsAlongDimension(const std::vector &indices, const std::vector &dimensions) const { - ResultSet result; + ResultSet NDArray::multipleTensorsAlongDimension(const std::vector &indices, const std::vector &dimensions) const { + ResultSet result; + + if (indices.size() == 0) + return result; + + auto pack = ConstantTadHelper::getInstance().tadForDimensions(shapeInfo(), const_cast(dimensions.data()), dimensions.size()); + + auto tadLength = shape::length(pack.primaryShapeInfo()); + auto numTads = lengthOf() / tadLength; + + for (auto idx: indices) { + if (idx >= numTads) { + nd4j_printf("NDArray::multipleTensorsAlongDimension: index %i is higher then number of TADs: %i\n", idx, numTads); + throw std::runtime_error("Bad index"); + } + + auto array = new NDArray(getDataBuffer(), ShapeDescriptor(pack.primaryShapeInfo()), getContext(), pack.primaryOffsets()[idx] + bufferOffset()); + result.push_back(array); + } - if (indices.size() == 0) return result; - - auto pack = ConstantTadHelper::getInstance().tadForDimensions(shapeInfo(), const_cast(dimensions.data()), dimensions.size()); - - auto tadLength = shape::length(pack.primaryShapeInfo()); - auto numTads = lengthOf() / tadLength; - - for (auto idx: indices) { - if (idx >= numTads) { - nd4j_printf("NDArray::multipleTensorsAlongDimension: index %i is higher then number of TADs: %i\n", idx, numTads); - throw std::runtime_error("Bad index"); - } - - auto array = new NDArray(getDataBuffer(), ShapeDescriptor(pack.primaryShapeInfo()), getContext(), pack.primaryOffsets()[idx] + bufferOffset()); - result.push_back(array); } - return result; -} - //////////////////////////////////////////////////////////////////////// -ResultSet NDArray::allTensorsAlongDimension(const std::initializer_list& dimensions) const { - return allTensorsAlongDimension(std::vector(dimensions)); -} - -//////////////////////////////////////////////////////////////////////// -ResultSet NDArray::allExamples() const { - std::vector dimensions(rankOf() - 1); - for (int e = 1; e < rankOf(); e++) - dimensions[e-1] = e; - - return allTensorsAlongDimension(dimensions); -} - -//////////////////////////////////////////////////////////////////////// -Nd4jLong NDArray::getOffset(const Nd4jLong i) const { - - if (i >= lengthOf()) - throw std::invalid_argument("NDArray::getOffset: input index is out of array length !"); - - return shape::getIndexOffset(i, _shapeInfo); -} - -//////////////////////////////////////////////////////////////////////// -NDArray NDArray::like() { - - return NDArray(shapeInfo(), this->dataType(), false, getContext()); -} - -//////////////////////////////////////////////////////////////////////// -NDArray NDArray::ulike() const{ - - return NDArray(this, false, getContext()); -} - -//////////////////////////////////////////////////////////////////////// -NDArray NDArray::diagonal(const char type) const { - - if (isS()) - throw std::runtime_error("NDArray::diagonal: you can't use this method on String array!"); - - const char order = ordering(); - const int rank = rankOf(); - Nd4jLong 
*outShapeInfo; - ALLOCATE(outShapeInfo, getContext()->getWorkspace(), 8, Nd4jLong); - outShapeInfo[0] = 2; - outShapeInfo[5] = 0; - - if(isVector() || isScalar()) { - - outShapeInfo[1] = outShapeInfo[2] = outShapeInfo[3] = outShapeInfo[4] = 1; - outShapeInfo[6] = 1; - outShapeInfo[7] = (int)order; + ResultSet NDArray::allTensorsAlongDimension(const std::initializer_list& dimensions) const { + return allTensorsAlongDimension(std::vector(dimensions)); } - else { - int diagSize = 100000000; - Nd4jLong indices[MAX_RANK]; +//////////////////////////////////////////////////////////////////////// + ResultSet NDArray::allExamples() const { + std::vector dimensions(rankOf() - 1); + for (int e = 1; e < rankOf(); e++) + dimensions[e-1] = e; - for(int i = 0; i < rank; ++i) { - if(diagSize > shapeOf()[i]) - diagSize = shapeOf()[i]; - indices[i] = 1; - } + return allTensorsAlongDimension(dimensions); + } - auto step = shape::getOffset(shapeInfo(), indices); +//////////////////////////////////////////////////////////////////////// + Nd4jLong NDArray::getOffset(const Nd4jLong i) const { - if(type == 'c') { - outShapeInfo[1] = diagSize; - outShapeInfo[2] = 1; + if (i >= lengthOf()) + throw std::invalid_argument("NDArray::getOffset: input index is out of array length !"); + + return shape::getIndexOffset(i, _shapeInfo); + } + +//////////////////////////////////////////////////////////////////////// + NDArray NDArray::like() { + + return NDArray(shapeInfo(), this->dataType(), false, getContext()); + } + +//////////////////////////////////////////////////////////////////////// + NDArray NDArray::ulike() const{ + + return NDArray(this, false, getContext()); + } + +//////////////////////////////////////////////////////////////////////// + NDArray NDArray::diagonal(const char type) const { + + if (isS()) + throw std::runtime_error("NDArray::diagonal: you can't use this method on String array!"); + + const char order = ordering(); + const int rank = rankOf(); + Nd4jLong *outShapeInfo; + ALLOCATE(outShapeInfo, getContext()->getWorkspace(), 8, Nd4jLong); + outShapeInfo[0] = 2; + outShapeInfo[5] = 0; + + if(isVector() || isScalar()) { + + outShapeInfo[1] = outShapeInfo[2] = outShapeInfo[3] = outShapeInfo[4] = 1; + outShapeInfo[6] = 1; + outShapeInfo[7] = (int)order; } else { - outShapeInfo[1] = 1; - outShapeInfo[2] = diagSize; + + int diagSize = 100000000; + Nd4jLong indices[MAX_RANK]; + + for(int i = 0; i < rank; ++i) { + if(diagSize > shapeOf()[i]) + diagSize = shapeOf()[i]; + indices[i] = 1; + } + + auto step = shape::getOffset(shapeInfo(), indices); + + if(type == 'c') { + outShapeInfo[1] = diagSize; + outShapeInfo[2] = 1; + } + else { + outShapeInfo[1] = 1; + outShapeInfo[2] = diagSize; + } + shape::updateStrides(outShapeInfo, order); + + outShapeInfo[3] *= step; + outShapeInfo[4] *= step; + outShapeInfo[6] = 0; } - shape::updateStrides(outShapeInfo, order); - outShapeInfo[3] *= step; - outShapeInfo[4] *= step; - outShapeInfo[6] = 0; + ArrayOptions::setDataType(outShapeInfo, this->dataType()); + + NDArray result(_buffer, ShapeDescriptor(outShapeInfo), getContext(), bufferOffset()); + + RELEASE(outShapeInfo, getContext()->getWorkspace()); + + return result; } - ArrayOptions::setDataType(outShapeInfo, this->dataType()); - - NDArray result(_buffer, ShapeDescriptor(outShapeInfo), getContext(), bufferOffset()); - - RELEASE(outShapeInfo, getContext()->getWorkspace()); - - return result; -} - //////////////////////////////////////////////////////////////////////// -ResultSet NDArray::allTensorsAlongDimension(const 
std::vector &dimensions) const { + ResultSet NDArray::allTensorsAlongDimension(const std::vector &dimensions) const { - ResultSet result; + ResultSet result; + + if(dimensions.size() == 0) + return result; + else if(dimensions.back() == rankOf()) { + auto array = new NDArray(_buffer, this->shapeInfo(), getContext(),bufferOffset()); + array->_isView = true; + result.push_back(array); + nd4j_debug("NDArray::allTensorsAlongDimension: Dimensions were equal %d with this rank of %d\n",dimensions.back(),rankOf()); + return result; + } + + + if(dimensions.back() >= rankOf()) { + nd4j_debug("Dimensions failure %d and rank %d\n",dimensions.back(),rankOf()); + throw std::runtime_error( + "NDArray::allTensorsAlongDimension static function: all input dimensions must be smaller than rank of input array !"); + } + + auto pack = ConstantTadHelper::getInstance().tadForDimensions(_shapeInfo, const_cast(dimensions.data()), dimensions.size()); + auto numTads = pack.numberOfTads(); + + for (Nd4jLong idx = 0; idx < numTads; idx++ ) { + auto array = new NDArray(_buffer, ShapeDescriptor(pack.primaryShapeInfo()), getContext(), pack.primaryOffsets()[idx] + bufferOffset()); + array->_isView = true; + result.push_back(array); + } - if(dimensions.size() == 0) return result; - - if(dimensions.back() >= rankOf()) - throw std::runtime_error("NDArray::allTensorsAlongDimension static function: all input dimensions must be smaller than rank of input array !"); - - - auto pack = ConstantTadHelper::getInstance().tadForDimensions(_shapeInfo, const_cast(dimensions.data()), dimensions.size()); - auto numTads = pack.numberOfTads(); - - for (Nd4jLong idx = 0; idx < numTads; idx++ ) { - auto array = new NDArray(_buffer, ShapeDescriptor(pack.primaryShapeInfo()), getContext(), pack.primaryOffsets()[idx] + bufferOffset()); - array->_isView = true; - result.push_back(array); } - return result; -} - //////////////////////////////////////////////////////////////////////// // operator returns sub-array with buffer pointing at this->_buffer + certain offset -NDArray NDArray::operator()(const std::vector& idx, const bool keepUnitiesInShape, const bool isStrided) const { + NDArray NDArray::operator()(const std::vector& idx, const bool keepUnitiesInShape, const bool isStrided) const { - if(isEmpty()) - throw std::invalid_argument("NDArray::operator(sub-arrays): array is empty !"); + if(isEmpty()) + throw std::invalid_argument("NDArray::operator(sub-arrays): array is empty !"); - // Nd4jLong *outShapeInfo = nullptr; - // ALLOCATE(outShapeInfo, workspace, shape::shapeInfoLength(inShapeInfo), Nd4jLong); + // Nd4jLong *outShapeInfo = nullptr; + // ALLOCATE(outShapeInfo, workspace, shape::shapeInfoLength(inShapeInfo), Nd4jLong); - int numOfUntiesInSubArrShape = 0; + int numOfUntiesInSubArrShape = 0; - Nd4jLong* subArrShapeInfo = nullptr; + Nd4jLong* subArrShapeInfo = nullptr; - if(!keepUnitiesInShape) { + if(!keepUnitiesInShape) { - int n(isStrided ? 3 : 2), first, last; + int n(isStrided ? 3 : 2), first, last; - // calculate the number of unities in shape - for (uint d = 0; d < rankOf(); ++d) { + // calculate the number of unities in shape + for (uint d = 0; d < rankOf(); ++d) { - if (idx[n * d] != idx[n * d + 1]) { + if (idx[n * d] != idx[n * d + 1]) { - first = idx[n * d] >= 0 ? idx[n * d] : idx[n * d] + sizeAt(d) + 1; - last = idx[n * d + 1] >= 0 ? idx[n * d + 1] : idx[n * d + 1] + sizeAt(d) + 1; - if(last - first == 1) - ++numOfUntiesInSubArrShape; + first = idx[n * d] >= 0 ? 
idx[n * d] : idx[n * d] + sizeAt(d) + 1; + last = idx[n * d + 1] >= 0 ? idx[n * d + 1] : idx[n * d + 1] + sizeAt(d) + 1; + if(last - first == 1) + ++numOfUntiesInSubArrShape; + } } } + + ALLOCATE(subArrShapeInfo, getContext()->getWorkspace(), shape::shapeInfoLength(rankOf() - numOfUntiesInSubArrShape), Nd4jLong); + + Nd4jLong offset; + + shape::calcSubArrShapeInfoAndOffset(idx.data(), shapeInfo(), subArrShapeInfo, offset, keepUnitiesInShape, isStrided, numOfUntiesInSubArrShape); + + NDArray result(_buffer, ShapeDescriptor(subArrShapeInfo), getContext(), offset + bufferOffset()); + result._isView = true; + + RELEASE(subArrShapeInfo, getContext()->getWorkspace()); + + return result; } - ALLOCATE(subArrShapeInfo, getContext()->getWorkspace(), shape::shapeInfoLength(rankOf() - numOfUntiesInSubArrShape), Nd4jLong); +//////////////////////////////////////////////////////////////////////// + NDArray NDArray::operator()(const Nd4jLong subArrIdx, const std::vector& dimsToExclude, bool keepUnitiesInShape) const { - Nd4jLong offset; + std::vector idxRanges(2 * rankOf()); - shape::calcSubArrShapeInfoAndOffset(idx.data(), shapeInfo(), subArrShapeInfo, offset, keepUnitiesInShape, isStrided, numOfUntiesInSubArrShape); + const auto rank = rankOf(); + const auto subArrRank = static_cast(dimsToExclude.size()); - NDArray result(_buffer, ShapeDescriptor(subArrShapeInfo), getContext(), offset + bufferOffset()); - result._isView = true; + if(subArrRank > rank) + throw std::invalid_argument("NDArray::operator(const Nd4jLong subArrIdx, const std::vector& dimsToExclude, bool keepUnitiesInShape): static method: dimsToExclude is empty or has size > rank of array !"); - RELEASE(subArrShapeInfo, getContext()->getWorkspace()); + memset(idxRanges.data(), 0, 2 * rank * sizeof(Nd4jLong)); - return result; -} + // subArrRank == 0 means whole array, idxRanges should contain zeros only + + if(subArrRank != 0) { + + std::vector shapeOfSubArr(subArrRank), indexes(subArrRank); + for(int i = 0; i < subArrRank; ++i) + shapeOfSubArr[i] = sizeAt(dimsToExclude[i]); + + shape::index2coords(subArrIdx, subArrRank, shapeOfSubArr.data(), indexes.data()); + + for(int i = 0; i < subArrRank; ++i) { + int currIdx = 2 * dimsToExclude[i]; + idxRanges[currIdx] = indexes[i]; + idxRanges[currIdx + 1] = indexes[i] + 1; + } + } + + return (*this)(idxRanges, keepUnitiesInShape); + } //////////////////////////////////////////////////////////////////////// -NDArray NDArray::operator()(const Nd4jLong subArrIdx, const std::vector& dimsToExclude, bool keepUnitiesInShape) const { + void NDArray::getSubArrShapeAndOffsets(const std::vector& dimsToExclude, Nd4jLong* &subArrShapeInfo, Nd4jLong* &subArrOffsets, bool keepUnitiesInShape) const { - std::vector idxRanges(2 * rankOf()); + if(isEmpty()) + throw std::invalid_argument("NDArray::getSubArrShapeAndOffsets: array is empty !"); - const auto rank = rankOf(); - const auto subArrRank = static_cast(dimsToExclude.size()); + const int rank = rankOf(); + const int subArrRank = (rank == dimsToExclude.size() || keepUnitiesInShape) ? 
rank : rank - dimsToExclude.size(); + const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(_shapeInfo, dimsToExclude); - if(subArrRank > rank) - throw std::invalid_argument("NDArray::operator(const Nd4jLong subArrIdx, const std::vector& dimsToExclude, bool keepUnitiesInShape): static method: dimsToExclude is empty or has size > rank of array !"); + // allocate memory + ALLOCATE(subArrShapeInfo, getContext()->getWorkspace(), shape::shapeInfoLength(subArrRank), Nd4jLong); + ALLOCATE(subArrOffsets, getContext()->getWorkspace(), numOfSubArrs, Nd4jLong); - memset(idxRanges.data(), 0, 2 * rank * sizeof(Nd4jLong)); + shape::calcSubArrsShapeInfoAndOffsets(_shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo, subArrOffsets, keepUnitiesInShape); + } - // subArrRank == 0 means whole array, idxRanges should contain zeros only +////////////////////////////////////////////////////////////////////////// + void NDArray::setShapeInfo(const Nd4jLong *shapeInfo) { - if(subArrRank != 0) { + if (shapeInfo != nullptr) { - std::vector shapeOfSubArr(subArrRank), indexes(subArrRank); - for(int i = 0; i < subArrRank; ++i) - shapeOfSubArr[i] = sizeAt(dimsToExclude[i]); + ShapeDescriptor descriptor(shapeInfo); + auto shapeBuffer = ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor); - shape::index2coords(subArrIdx, subArrRank, shapeOfSubArr.data(), indexes.data()); + _shapeInfo = shapeBuffer.primary(); +#ifdef __CUDABLAS__ + _shapeInfoD = shapeBuffer.special(); +#endif - for(int i = 0; i < subArrRank; ++i) { - int currIdx = 2 * dimsToExclude[i]; - idxRanges[currIdx] = indexes[i]; - idxRanges[currIdx + 1] = indexes[i] + 1; + if(ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY) + _length = 0; + else + _length = shape::length(_shapeInfo); + + _dataType = ArrayOptions::dataType(_shapeInfo); + } + else { + _dataType = sd::DataType::INHERIT; + _shapeInfoD = _shapeInfo = nullptr; } } - return (*this)(idxRanges, keepUnitiesInShape); -} - //////////////////////////////////////////////////////////////////////// -void NDArray::getSubArrShapeAndOffsets(const std::vector& dimsToExclude, Nd4jLong* &subArrShapeInfo, Nd4jLong* &subArrOffsets, bool keepUnitiesInShape) const { + void NDArray::setShapeInfo(const Nd4jLong *shapeInfo, const sd::DataType dtype) { - if(isEmpty()) - throw std::invalid_argument("NDArray::getSubArrShapeAndOffsets: array is empty !"); + if (shapeInfo != nullptr) { - const int rank = rankOf(); - const int subArrRank = (rank == dimsToExclude.size() || keepUnitiesInShape) ? 
rank : rank - dimsToExclude.size(); - const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(_shapeInfo, dimsToExclude); + Nd4jLong* shapeInfoTemp = ShapeBuilders::copyShapeInfoAndType(shapeInfo, dtype, true, getContext()->getWorkspace()); + ShapeDescriptor descriptor(shapeInfoTemp); + auto shapeBuffer = ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor); - // allocate memory - ALLOCATE(subArrShapeInfo, getContext()->getWorkspace(), shape::shapeInfoLength(subArrRank), Nd4jLong); - ALLOCATE(subArrOffsets, getContext()->getWorkspace(), numOfSubArrs, Nd4jLong); + _shapeInfo = shapeBuffer.primary(); +#ifdef __CUDABLAS__ + _shapeInfoD = shapeBuffer.special(); +#endif - shape::calcSubArrsShapeInfoAndOffsets(_shapeInfo, numOfSubArrs, dimsToExclude.size(), dimsToExclude.data(), subArrShapeInfo, subArrOffsets, keepUnitiesInShape); -} + if(ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY) + _length = 0; + else + _length = shape::length(_shapeInfo); + + _dataType = dtype; + } + else { + _dataType = sd::DataType::INHERIT; + _shapeInfoD = _shapeInfo = nullptr; + } + } ////////////////////////////////////////////////////////////////////////// -void NDArray::setShapeInfo(const Nd4jLong *shapeInfo) { + void NDArray::setShapeInfo(const ShapeDescriptor& descriptor) { - if (shapeInfo != nullptr) { - - ShapeDescriptor descriptor(shapeInfo); - auto shapeBuffer = ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor); + auto shapeBuffer = ConstantShapeHelper::getInstance().bufferForShapeInfo(const_cast(descriptor)); _shapeInfo = shapeBuffer.primary(); - #ifdef __CUDABLAS__ - _shapeInfoD = shapeBuffer.special(); - #endif +#ifdef __CUDABLAS__ + _shapeInfoD = shapeBuffer.special(); +#endif if(ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY) _length = 0; @@ -4982,605 +5042,555 @@ void NDArray::setShapeInfo(const Nd4jLong *shapeInfo) { _dataType = ArrayOptions::dataType(_shapeInfo); } - else { - _dataType = sd::DataType::INHERIT; - _shapeInfoD = _shapeInfo = nullptr; - } -} -//////////////////////////////////////////////////////////////////////// -void NDArray::setShapeInfo(const Nd4jLong *shapeInfo, const sd::DataType dtype) { - - if (shapeInfo != nullptr) { - - Nd4jLong* shapeInfoTemp = ShapeBuilders::copyShapeInfoAndType(shapeInfo, dtype, true, getContext()->getWorkspace()); - ShapeDescriptor descriptor(shapeInfoTemp); - auto shapeBuffer = ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor); +////////////////////////////////////////////////////////////////////////// + void NDArray::setShapeInfo(const ConstantShapeBuffer& shapeBuffer) { _shapeInfo = shapeBuffer.primary(); - #ifdef __CUDABLAS__ - _shapeInfoD = shapeBuffer.special(); - #endif +#ifdef __CUDABLAS__ + _shapeInfoD = shapeBuffer.special(); +#endif if(ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY) _length = 0; else _length = shape::length(_shapeInfo); - _dataType = dtype; + _dataType = ArrayOptions::dataType(_shapeInfo); } - else { - _dataType = sd::DataType::INHERIT; - _shapeInfoD = _shapeInfo = nullptr; - } -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::setShapeInfo(const ShapeDescriptor& descriptor) { - - auto shapeBuffer = ConstantShapeHelper::getInstance().bufferForShapeInfo(const_cast(descriptor)); - - _shapeInfo = shapeBuffer.primary(); - #ifdef __CUDABLAS__ - _shapeInfoD = shapeBuffer.special(); - #endif - - if(ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY) - _length = 0; - else - _length = shape::length(_shapeInfo); - - 
_dataType = ArrayOptions::dataType(_shapeInfo); -} - -////////////////////////////////////////////////////////////////////////// -void NDArray::setShapeInfo(const ConstantShapeBuffer& shapeBuffer) { - - _shapeInfo = shapeBuffer.primary(); - #ifdef __CUDABLAS__ - _shapeInfoD = shapeBuffer.special(); - #endif - - if(ArrayOptions::arrayType(_shapeInfo) == ArrayType::EMPTY) - _length = 0; - else - _length = shape::length(_shapeInfo); - - _dataType = ArrayOptions::dataType(_shapeInfo); -} /////////////////////////////////////////////////////////////////////// // addition operator array + scalar -template -NDArray operator+(NDArray&& arr, const T& scalar) { + template + NDArray operator+(NDArray&& arr, const T& scalar) { - if(arr.isView()) // do not use resources of arrays which use buffers of other original arrays - return std::move(arr + scalar); // arr is lvalue inside function body + if(arr.isView()) // do not use resources of arrays which use buffers of other original arrays + return std::move(arr + scalar); // arr is lvalue inside function body - if (arr.isS()) - throw std::runtime_error("operator+(NDArray&& arr, const T& scalar): you can't use this method on String array!"); - if (arr.dataType() != DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT())) - throw std::runtime_error("operator+(NDArray&& arr, const T& scalar): you can't use this method on String array!"); + if (arr.isS()) + throw std::runtime_error("operator+(NDArray&& arr, const T& scalar): you can't use this method on String array!"); + if (arr.dataType() != DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT())) + throw std::runtime_error("operator+(NDArray&& arr, const T& scalar): you can't use this method on String array!"); - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - NDArray::prepareSpecialUse({&arr}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::Add, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&arr}, {&arr, &tmp}); + NDArray::prepareSpecialUse({&arr}, {&arr, &tmp}); + NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::Add, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&arr}, {&arr, &tmp}); - return std::move(arr); -} -template ND4J_EXPORT NDArray operator+(NDArray&& arr, const double& scalar); -template ND4J_EXPORT NDArray operator+(NDArray&& arr, const float& scalar); -template ND4J_EXPORT NDArray operator+(NDArray&& arr, const float16& scalar); -template ND4J_EXPORT NDArray operator+(NDArray&& arr, const bfloat16& scalar); -template ND4J_EXPORT NDArray operator+(NDArray&& arr, const int& scalar); + return std::move(arr); + } + template ND4J_EXPORT NDArray operator+(NDArray&& arr, const double& scalar); + template ND4J_EXPORT NDArray operator+(NDArray&& arr, const float& scalar); + template ND4J_EXPORT NDArray operator+(NDArray&& arr, const float16& scalar); + template ND4J_EXPORT NDArray operator+(NDArray&& arr, const bfloat16& scalar); + template ND4J_EXPORT NDArray 
operator+(NDArray&& arr, const int& scalar); //////////////////////////////////////////////////////////////////////// -template -NDArray operator+(const NDArray& arr, const T& scalar) { + template + NDArray operator+(const NDArray& arr, const T& scalar) { - if (arr.isS()) - throw std::runtime_error("operator+(const NDArray& arr, const T& scalar): you can't use this method on String array!"); + if (arr.isS()) + throw std::runtime_error("operator+(const NDArray& arr, const T& scalar): you can't use this method on String array!"); - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - NDArray result(arr.shapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + NDArray result(arr.shapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); - NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::Add, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {&arr, &tmp}); + NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); + NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::Add, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&result}, {&arr, &tmp}); - return result; -} -template ND4J_EXPORT NDArray operator+(const NDArray& arr, const double& scalar); -template ND4J_EXPORT NDArray operator+(const NDArray& arr, const float& scalar); -template ND4J_EXPORT NDArray operator+(const NDArray& arr, const float16& scalar); -template ND4J_EXPORT NDArray operator+(const NDArray& arr, const bfloat16& scalar); -template ND4J_EXPORT NDArray operator+(const NDArray& arr, const int& scalar); + return result; + } + template ND4J_EXPORT NDArray operator+(const NDArray& arr, const double& scalar); + template ND4J_EXPORT NDArray operator+(const NDArray& arr, const float& scalar); + template ND4J_EXPORT NDArray operator+(const NDArray& arr, const float16& scalar); + template ND4J_EXPORT NDArray operator+(const NDArray& arr, const bfloat16& scalar); + template ND4J_EXPORT NDArray operator+(const NDArray& arr, const int& scalar); //////////////////////////////////////////////////////////////////////// -template -NDArray operator+(const T& scalar, NDArray&& arr) { - return std::move(arr) + scalar; -} -template ND4J_EXPORT NDArray operator+(const double& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator+(const float& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator+(const float16& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator+(const bfloat16& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator+(const int& scalar, NDArray&& arr); + template + NDArray operator+(const T& scalar, NDArray&& arr) { + return std::move(arr) + scalar; + } + template ND4J_EXPORT NDArray operator+(const double& scalar, NDArray&& arr); + template ND4J_EXPORT NDArray operator+(const float& scalar, NDArray&& arr); + template ND4J_EXPORT 
NDArray operator+(const float16& scalar, NDArray&& arr);
+ template ND4J_EXPORT NDArray operator+(const bfloat16& scalar, NDArray&& arr);
+ template ND4J_EXPORT NDArray operator+(const int& scalar, NDArray&& arr);

////////////////////////////////////////////////////////////////////////
-template <typename T>
-NDArray operator+(const T& scalar, const NDArray& arr) {
- return arr + scalar;
-}
-template ND4J_EXPORT NDArray operator+(const double& scalar, const NDArray& arr);
-template ND4J_EXPORT NDArray operator+(const float& scalar, const NDArray& arr);
-template ND4J_EXPORT NDArray operator+(const int& scalar, const NDArray& arr);
+ template <typename T>
+ NDArray operator+(const T& scalar, const NDArray& arr) {
+ return arr + scalar;
+ }
+ template ND4J_EXPORT NDArray operator+(const double& scalar, const NDArray& arr);
+ template ND4J_EXPORT NDArray operator+(const float& scalar, const NDArray& arr);
+ template ND4J_EXPORT NDArray operator+(const int& scalar, const NDArray& arr);

///////////////////////////////////////////////////////////////////////
// subtraction operator array - scalar
-template <typename T>
-NDArray operator-(NDArray&& arr, const T& scalar) {
+ template <typename T>
+ NDArray operator-(NDArray&& arr, const T& scalar) {
- if(arr.isView()) // do not use resources of arrays which use buffers of other original arrays
- return std::move(arr - scalar); // arr is lvalue inside function body
+ if(arr.isView()) // do not use resources of arrays which use buffers of other original arrays
+ return std::move(arr - scalar); // arr is lvalue inside function body
- if (arr.isS())
- throw std::runtime_error("operator-(NDArray&& arr, const T& scalar): you can't use this method on String array!");
- if (arr.dataType() != DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT<T>()))
- throw std::runtime_error("operator-(NDArray&& arr, const T& scalar): you can't use this method on String array!");
+ if (arr.isS())
+ throw std::runtime_error("operator-(NDArray&& arr, const T& scalar): you can't use this method on String array!");
+ if (arr.dataType() != DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT<T>()))
+ throw std::runtime_error("operator-(NDArray&& arr, const T& scalar): you can't use this method on String array!");
- auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext());
+ auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext());
- NDArray::prepareSpecialUse({&arr}, {&arr, &tmp});
- NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::Subtract, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr);
- NDArray::registerSpecialUse({&arr}, {&arr, &tmp});
+ NDArray::prepareSpecialUse({&arr}, {&arr, &tmp});
+ NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::Subtract, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr);
+ NDArray::registerSpecialUse({&arr}, {&arr, &tmp});
- return std::move(arr);
-}
-template ND4J_EXPORT NDArray operator-(NDArray&& arr, const double& scalar);
-template ND4J_EXPORT NDArray operator-(NDArray&& arr, const float& scalar);
+ return std::move(arr);
+ }
+ template ND4J_EXPORT NDArray operator-(NDArray&& arr, const double& scalar);
+ template
ND4J_EXPORT NDArray operator-(NDArray&& arr, const float& scalar); //////////////////////////////////////////////////////////////////////// -template -NDArray operator-(const NDArray& arr, const T& scalar) { + template + NDArray operator-(const NDArray& arr, const T& scalar) { - if (arr.isS()) - throw std::runtime_error("operator-(const NDArray& arr, const T& scalar): you can't use this method on String array!"); + if (arr.isS()) + throw std::runtime_error("operator-(const NDArray& arr, const T& scalar): you can't use this method on String array!"); - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - NDArray result(arr.shapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + NDArray result(arr.shapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); - NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::Subtract, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {&arr, &tmp}); + NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); + NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::Subtract, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&result}, {&arr, &tmp}); - return result; -} -template ND4J_EXPORT NDArray operator-(const NDArray& arr, const double& scalar); -template ND4J_EXPORT NDArray operator-(const NDArray& arr, const float& scalar); -template ND4J_EXPORT NDArray operator-(const NDArray& arr, const float16& scalar); -template ND4J_EXPORT NDArray operator-(const NDArray& arr, const bfloat16& scalar); -template ND4J_EXPORT NDArray operator-(const NDArray& arr, const int& scalar); + return result; + } + template ND4J_EXPORT NDArray operator-(const NDArray& arr, const double& scalar); + template ND4J_EXPORT NDArray operator-(const NDArray& arr, const float& scalar); + template ND4J_EXPORT NDArray operator-(const NDArray& arr, const float16& scalar); + template ND4J_EXPORT NDArray operator-(const NDArray& arr, const bfloat16& scalar); + template ND4J_EXPORT NDArray operator-(const NDArray& arr, const int& scalar); //////////////////////////////////////////////////////////////////////// -template -NDArray operator-(const T& scalar, NDArray&& arr) { + template + NDArray operator-(const T& scalar, NDArray&& arr) { - if(arr.isView()) // do not use resources of arrays which use buffers of other original arrays - return std::move(scalar - arr); // arr is lvalue inside function body + if(arr.isView()) // do not use resources of arrays which use buffers of other original arrays + return std::move(scalar - arr); // arr is lvalue inside function body - if (arr.isS()) - throw std::runtime_error("operator-(const T& scalar, NDArray&& arr): you can't use this method on String array!"); + if (arr.isS()) + throw std::runtime_error("operator-(const T& scalar, NDArray&& arr): you can't use this method on String array!"); - auto tmp = 
NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - NDArray::prepareSpecialUse({&arr}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::ReverseSubtract, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&arr}, {&arr, &tmp}); + NDArray::prepareSpecialUse({&arr}, {&arr, &tmp}); + NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::ReverseSubtract, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&arr}, {&arr, &tmp}); - return std::move(arr); + return std::move(arr); -} -template ND4J_EXPORT NDArray operator-(const double& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator-(const float& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator-(const float16& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator-(const bfloat16& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator-(const int& scalar, NDArray&& arr); + } + template ND4J_EXPORT NDArray operator-(const double& scalar, NDArray&& arr); + template ND4J_EXPORT NDArray operator-(const float& scalar, NDArray&& arr); + template ND4J_EXPORT NDArray operator-(const float16& scalar, NDArray&& arr); + template ND4J_EXPORT NDArray operator-(const bfloat16& scalar, NDArray&& arr); + template ND4J_EXPORT NDArray operator-(const int& scalar, NDArray&& arr); //////////////////////////////////////////////////////////////////////// -template -NDArray operator-(const T& scalar, const NDArray& arr) { + template + NDArray operator-(const T& scalar, const NDArray& arr) { - if (arr.isS()) - throw std::runtime_error("operator-(const T& scalar, const NDArray& arr): you can't use this method on String array!"); + if (arr.isS()) + throw std::runtime_error("operator-(const T& scalar, const NDArray& arr): you can't use this method on String array!"); - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - NDArray result(arr.shapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + NDArray result(arr.shapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); - NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::ReverseSubtract, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {&arr, &tmp}); + NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); + NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::ReverseSubtract, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), 
tmp.specialShapeInfo(), nullptr);
+ NDArray::registerSpecialUse({&result}, {&arr, &tmp});
- return result;
-}
-template ND4J_EXPORT NDArray operator-(const double& scalar, const NDArray& arr);
-template ND4J_EXPORT NDArray operator-(const float& scalar, const NDArray& arr);
-template ND4J_EXPORT NDArray operator-(const int& scalar, const NDArray& arr);
+ return result;
+ }
+ template ND4J_EXPORT NDArray operator-(const double& scalar, const NDArray& arr);
+ template ND4J_EXPORT NDArray operator-(const float& scalar, const NDArray& arr);
+ template ND4J_EXPORT NDArray operator-(const int& scalar, const NDArray& arr);

///////////////////////////////////////////////////////////////////////
// multiplication operator array * scalar
-template <typename T>
-NDArray operator*(NDArray&& arr, const T& scalar) {
+ template <typename T>
+ NDArray operator*(NDArray&& arr, const T& scalar) {
- if(arr.isView()) // do not use resources of arrays which use buffers of other original arrays
- return std::move(arr * scalar); // arr is lvalue inside function body
+ if(arr.isView()) // do not use resources of arrays which use buffers of other original arrays
+ return std::move(arr * scalar); // arr is lvalue inside function body
- if (arr.isS())
- throw std::runtime_error("operator*(NDArray&& arr, const T& scalar): you can't use this method on String array!");
- if (arr.dataType() != DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT<T>()))
- throw std::runtime_error("operator*(NDArray&& arr, const T& scalar): you can't use this method on String array!");
+ if (arr.isS())
+ throw std::runtime_error("operator*(NDArray&& arr, const T& scalar): you can't use this method on String array!");
+ if (arr.dataType() != DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT<T>()))
+ throw std::runtime_error("operator*(NDArray&& arr, const T& scalar): you can't use this method on String array!");
- auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext());
+ auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext());
- NDArray::prepareSpecialUse({&arr}, {&arr, &tmp});
- NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::Multiply, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr);
- NDArray::registerSpecialUse({&arr}, {&arr, &tmp});
+ NDArray::prepareSpecialUse({&arr}, {&arr, &tmp});
+ NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::Multiply, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr);
+ NDArray::registerSpecialUse({&arr}, {&arr, &tmp});
- return std::move(arr);
-}
-template ND4J_EXPORT NDArray operator*(NDArray&& arr, const double& scalar);
-template ND4J_EXPORT NDArray operator*(NDArray&& arr, const float& scalar);
-template ND4J_EXPORT NDArray operator*(NDArray&& arr, const float16& scalar);
-template ND4J_EXPORT NDArray operator*(NDArray&& arr, const bfloat16& scalar);
-template ND4J_EXPORT NDArray operator*(NDArray&& arr, const int& scalar);
-template ND4J_EXPORT NDArray operator*(NDArray&& arr, const long long& scalar);
+ return std::move(arr);
+ }
+ template ND4J_EXPORT NDArray operator*(NDArray&& arr, const double& scalar);
+ template ND4J_EXPORT NDArray operator*(NDArray&&
arr, const float& scalar); + template ND4J_EXPORT NDArray operator*(NDArray&& arr, const float16& scalar); + template ND4J_EXPORT NDArray operator*(NDArray&& arr, const bfloat16& scalar); + template ND4J_EXPORT NDArray operator*(NDArray&& arr, const int& scalar); + template ND4J_EXPORT NDArray operator*(NDArray&& arr, const long long& scalar); //////////////////////////////////////////////////////////////////////// -template -NDArray operator*(const NDArray& arr, const T& scalar) { + template + NDArray operator*(const NDArray& arr, const T& scalar) { - if (arr.isS()) - throw std::runtime_error("operator*(const NDArray& arr, const T& scalar): you can't use this method on String array!"); + if (arr.isS()) + throw std::runtime_error("operator*(const NDArray& arr, const T& scalar): you can't use this method on String array!"); - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - NDArray result(arr.shapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + NDArray result(arr.shapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); - NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::Multiply, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {&arr, &tmp}); + NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); + NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::Multiply, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&result}, {&arr, &tmp}); - return result; -} + return result; + } -template ND4J_EXPORT NDArray operator*(const NDArray& arr, const double& scalar); -template ND4J_EXPORT NDArray operator*(const NDArray& arr, const float& scalar); -template ND4J_EXPORT NDArray operator*(const NDArray& arr, const float16& scalar); -template ND4J_EXPORT NDArray operator*(const NDArray& arr, const bfloat16& scalar); -template ND4J_EXPORT NDArray operator*(const NDArray& arr, const int& scalar); -template ND4J_EXPORT NDArray operator*(const NDArray& arr, const long long& scalar); + template ND4J_EXPORT NDArray operator*(const NDArray& arr, const double& scalar); + template ND4J_EXPORT NDArray operator*(const NDArray& arr, const float& scalar); + template ND4J_EXPORT NDArray operator*(const NDArray& arr, const float16& scalar); + template ND4J_EXPORT NDArray operator*(const NDArray& arr, const bfloat16& scalar); + template ND4J_EXPORT NDArray operator*(const NDArray& arr, const int& scalar); + template ND4J_EXPORT NDArray operator*(const NDArray& arr, const long long& scalar); //////////////////////////////////////////////////////////////////////// -template -NDArray operator*(const T& scalar, NDArray&& arr) { - return std::move(arr) * scalar; -} -template ND4J_EXPORT NDArray operator*(const double& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator*(const float& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray 
operator*(const float16& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator*(const bfloat16& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator*(const int& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator*(const long long& scalar, NDArray&& arr); + template + NDArray operator*(const T& scalar, NDArray&& arr) { + return std::move(arr) * scalar; + } + template ND4J_EXPORT NDArray operator*(const double& scalar, NDArray&& arr); + template ND4J_EXPORT NDArray operator*(const float& scalar, NDArray&& arr); + template ND4J_EXPORT NDArray operator*(const float16& scalar, NDArray&& arr); + template ND4J_EXPORT NDArray operator*(const bfloat16& scalar, NDArray&& arr); + template ND4J_EXPORT NDArray operator*(const int& scalar, NDArray&& arr); + template ND4J_EXPORT NDArray operator*(const long long& scalar, NDArray&& arr); //////////////////////////////////////////////////////////////////////// -template -NDArray operator*(const T& scalar, const NDArray& arr) { - return arr * scalar; -} -template ND4J_EXPORT NDArray operator*(const double& scalar, const NDArray& arr); -template ND4J_EXPORT NDArray operator*(const float& scalar, const NDArray& arr); -template ND4J_EXPORT NDArray operator*(const float16& scalar, const NDArray& arr); -template ND4J_EXPORT NDArray operator*(const bfloat16& scalar, const NDArray& arr); -template ND4J_EXPORT NDArray operator*(const int& scalar, const NDArray& arr); -template ND4J_EXPORT NDArray operator*(const long long& scalar, const NDArray& arr); + template + NDArray operator*(const T& scalar, const NDArray& arr) { + return arr * scalar; + } + template ND4J_EXPORT NDArray operator*(const double& scalar, const NDArray& arr); + template ND4J_EXPORT NDArray operator*(const float& scalar, const NDArray& arr); + template ND4J_EXPORT NDArray operator*(const float16& scalar, const NDArray& arr); + template ND4J_EXPORT NDArray operator*(const bfloat16& scalar, const NDArray& arr); + template ND4J_EXPORT NDArray operator*(const int& scalar, const NDArray& arr); + template ND4J_EXPORT NDArray operator*(const long long& scalar, const NDArray& arr); /////////////////////////////////////////////////////////////////////// -template -NDArray operator/(NDArray&& arr, const T& scalar) { + template + NDArray operator/(NDArray&& arr, const T& scalar) { - if(arr.isView()) // do not use resources of arrays which use buffers of other original arrays - return std::move(arr / scalar); // arr is lvalue inside function body + if(arr.isView()) // do not use resources of arrays which use buffers of other original arrays + return std::move(arr / scalar); // arr is lvalue inside function body - if (arr.isS()) - throw std::runtime_error("operator/(NDArray&& arr, const T& scalar): you can't use this method on String array!"); - if (arr.dataType() != DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT())) - throw std::runtime_error("operator/(NDArray&& arr, const T& scalar): you can't use this method on String array!"); + if (arr.isS()) + throw std::runtime_error("operator/(NDArray&& arr, const T& scalar): you can't use this method on String array!"); + if (arr.dataType() != DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT())) + throw std::runtime_error("operator/(NDArray&& arr, const T& scalar): you can't use this method on String array!"); - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - 
NDArray::prepareSpecialUse({&arr}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::Divide, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&arr}, {&arr, &tmp}); + NDArray::prepareSpecialUse({&arr}, {&arr, &tmp}); + NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::Divide, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&arr}, {&arr, &tmp}); - return std::move(arr); -} -template ND4J_EXPORT NDArray operator/(NDArray&& arr, const double& scalar); -template ND4J_EXPORT NDArray operator/(NDArray&& arr, const float& scalar); -template ND4J_EXPORT NDArray operator/(NDArray&& arr, const float16& scalar); -template ND4J_EXPORT NDArray operator/(NDArray&& arr, const bfloat16& scalar); -template ND4J_EXPORT NDArray operator/(NDArray&& arr, const long long& scalar); + return std::move(arr); + } + template ND4J_EXPORT NDArray operator/(NDArray&& arr, const double& scalar); + template ND4J_EXPORT NDArray operator/(NDArray&& arr, const float& scalar); + template ND4J_EXPORT NDArray operator/(NDArray&& arr, const float16& scalar); + template ND4J_EXPORT NDArray operator/(NDArray&& arr, const bfloat16& scalar); + template ND4J_EXPORT NDArray operator/(NDArray&& arr, const long long& scalar); //////////////////////////////////////////////////////////////////////// -template -NDArray operator/(const NDArray& arr, const T& scalar) { + template + NDArray operator/(const NDArray& arr, const T& scalar) { - if (arr.isS()) - throw std::runtime_error("operator/(const NDArray& arr, const T& scalar): you can't use this method on String array!"); + if (arr.isS()) + throw std::runtime_error("operator/(const NDArray& arr, const T& scalar): you can't use this method on String array!"); - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - NDArray result(arr.shapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + NDArray result(arr.shapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT()), false, arr.getContext()); - NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::Divide, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&result}, {&arr, &tmp}); + NDArray::prepareSpecialUse({&result}, {&arr, &tmp}); + NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::Divide, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&result}, {&arr, &tmp}); - return result; -} -template ND4J_EXPORT NDArray operator/(const NDArray& arr, 
const double& scalar); -template ND4J_EXPORT NDArray operator/(const NDArray& arr, const float& scalar); -template ND4J_EXPORT NDArray operator/(const NDArray& arr, const float16& scalar); -template ND4J_EXPORT NDArray operator/(const NDArray& arr, const bfloat16& scalar); -template ND4J_EXPORT NDArray operator/(const NDArray& arr, const int& scalar); -template ND4J_EXPORT NDArray operator/(const NDArray& arr, const long long& scalar); + return result; + } + template ND4J_EXPORT NDArray operator/(const NDArray& arr, const double& scalar); + template ND4J_EXPORT NDArray operator/(const NDArray& arr, const float& scalar); + template ND4J_EXPORT NDArray operator/(const NDArray& arr, const float16& scalar); + template ND4J_EXPORT NDArray operator/(const NDArray& arr, const bfloat16& scalar); + template ND4J_EXPORT NDArray operator/(const NDArray& arr, const int& scalar); + template ND4J_EXPORT NDArray operator/(const NDArray& arr, const long long& scalar); //////////////////////////////////////////////////////////////////////// -template -NDArray operator/(const T& scalar, NDArray&& arr) { + template + NDArray operator/(const T& scalar, NDArray&& arr) { - if(arr.isView()) // do not use resources of arrays which use buffers of other original arrays - return std::move(scalar / arr); // arr is lvalue inside function body + if(arr.isView()) // do not use resources of arrays which use buffers of other original arrays + return std::move(scalar / arr); // arr is lvalue inside function body - if (arr.isS()) - throw std::runtime_error("operator/(const T& scalar, NDArray&& arr): you can't use this method on String array!"); + if (arr.isS()) + throw std::runtime_error("operator/(const T& scalar, NDArray&& arr): you can't use this method on String array!"); - auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); + auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext()); - NDArray::prepareSpecialUse({&arr}, {&arr, &tmp}); - NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::ReverseDivide, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({&arr}, {&arr, &tmp}); + NDArray::prepareSpecialUse({&arr}, {&arr, &tmp}); + NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::ReverseDivide, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({&arr}, {&arr, &tmp}); - return std::move(arr); + return std::move(arr); -} -template ND4J_EXPORT NDArray operator/(const double& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator/(const float& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator/(const float16& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator/(const bfloat16& scalar, NDArray&& arr); -template ND4J_EXPORT NDArray operator/(const int& scalar, NDArray&& arr); + } + template ND4J_EXPORT NDArray operator/(const double& scalar, NDArray&& arr); + template ND4J_EXPORT NDArray operator/(const float& scalar, NDArray&& arr); + template ND4J_EXPORT NDArray operator/(const float16& scalar, NDArray&& arr); + template ND4J_EXPORT NDArray operator/(const bfloat16& scalar, NDArray&& arr); + template 
ND4J_EXPORT NDArray operator/(const int& scalar, NDArray&& arr);

////////////////////////////////////////////////////////////////////////
-template <typename T>
-NDArray operator/(const T& scalar, const NDArray& arr) {
+ template <typename T>
+ NDArray operator/(const T& scalar, const NDArray& arr) {
- if (arr.isS())
- throw std::runtime_error("operator/(const T& scalar, const NDArray& arr): you can't use this method on String array!");
+ if (arr.isS())
+ throw std::runtime_error("operator/(const T& scalar, const NDArray& arr): you can't use this method on String array!");
- auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext());
- NDArray result(arr.shapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT<T>()), false, arr.getContext());
+ auto tmp = NDArrayFactory::create(arr.dataType(), scalar, arr.getContext());
+ NDArray result(arr.shapeInfo(), DataTypeUtils::pickPairwiseResultType(arr.dataType(), DataTypeUtils::fromT<T>()), false, arr.getContext());
- NDArray::prepareSpecialUse({&result}, {&arr, &tmp});
- NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::ReverseDivide, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr);
- NDArray::registerSpecialUse({&result}, {&arr, &tmp});
+ NDArray::prepareSpecialUse({&result}, {&arr, &tmp});
+ NativeOpExecutioner::execScalar(arr.getContext(), sd::scalar::ReverseDivide, arr.buffer(), arr.shapeInfo(), arr.specialBuffer(), arr.specialShapeInfo(), result.buffer(), result.shapeInfo(), result.specialBuffer(), result.specialShapeInfo(), tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo(), nullptr);
+ NDArray::registerSpecialUse({&result}, {&arr, &tmp});
- return result;
-}
-template ND4J_EXPORT NDArray operator/(const double& scalar, const NDArray& arr);
-template ND4J_EXPORT NDArray operator/(const float& scalar, const NDArray& arr);
-template ND4J_EXPORT NDArray operator/(const int& scalar, const NDArray& arr);
+ return result;
+ }
+ template ND4J_EXPORT NDArray operator/(const double& scalar, const NDArray& arr);
+ template ND4J_EXPORT NDArray operator/(const float& scalar, const NDArray& arr);
+ template ND4J_EXPORT NDArray operator/(const int& scalar, const NDArray& arr);

////////////////////////////////////////////////////////////////////////
// addition operator array + array
-template <typename T1, typename T2>
-NDArray operator+(T1&& arr1, T2&& arr2) {
+ template <typename T1, typename T2>
+ NDArray operator+(T1&& arr1, T2&& arr2) {
- if (arr1.isS() || arr2.isS())
- throw std::runtime_error("operator+(T&& arr1, T&& arr2): you can't use this method on String arrays!");
- if (!Environment::getInstance().isExperimentalBuild() && arr1.dataType() != arr2.dataType() && (arr1.dataType() != DataType::BOOL || arr2.dataType() != BOOL))
- throw sd::datatype_exception::build("operator+(T&& arr1, T&& arr2): Cannot multiply different types", arr1.dataType(), arr2.dataType());
+ if (arr1.isS() || arr2.isS())
+ throw std::runtime_error("operator+(T&& arr1, T&& arr2): you can't use this method on String arrays!");
+ if (!Environment::getInstance().isExperimentalBuild() && arr1.dataType() != arr2.dataType() && (arr1.dataType() != DataType::BOOL || arr2.dataType() != BOOL))
+ throw sd::datatype_exception::build("operator+(T&& arr1, T&& arr2): Cannot add different types", arr1.dataType(), arr2.dataType());
- PointersManager pointersManager(arr1.getContext(),
"operator+(T&& arr1, T&& arr2)"); + PointersManager pointersManager(arr1.getContext(), "operator+(T&& arr1, T&& arr2)"); - if (arr1.lengthOf() == arr2.lengthOf() && arr1.rankOf() == arr2.rankOf()) { + if (arr1.lengthOf() == arr2.lengthOf() && arr1.rankOf() == arr2.rankOf()) { - const bool isArr1Rvalue = !std::is_reference::value && !arr1.isView(); - const bool isArr2Rvalue = !std::is_reference::value && !arr2.isView(); + const bool isArr1Rvalue = !std::is_reference::value && !arr1.isView(); + const bool isArr2Rvalue = !std::is_reference::value && !arr2.isView(); - NDArray* result = nullptr; - if(isArr1Rvalue) - result = const_cast(&arr1); - else if(isArr2Rvalue) - result = const_cast(&arr2); - else - result = new NDArray(arr1.shapeInfo(), DataTypeUtils::pickPairwiseResultType(arr1.shapeInfo(), arr2.shapeInfo()), false, arr1.getContext()); + NDArray* result = nullptr; + if(isArr1Rvalue) + result = const_cast(&arr1); + else if(isArr2Rvalue) + result = const_cast(&arr2); + else + result = new NDArray(arr1.shapeInfo(), DataTypeUtils::pickPairwiseResultType(arr1.shapeInfo(), arr2.shapeInfo()), false, arr1.getContext()); - NDArray::prepareSpecialUse({result}, {&arr1, &arr2}); - NativeOpExecutioner::execPairwiseTransform(arr1.getContext(), sd::pairwise::Add, arr1.buffer(), arr1.shapeInfo(), arr1.specialBuffer(), arr1.specialShapeInfo(), arr2.buffer(), arr2.shapeInfo(), arr2.specialBuffer(), arr2.specialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({result}, {&arr1, &arr2}); + NDArray::prepareSpecialUse({result}, {&arr1, &arr2}); + NativeOpExecutioner::execPairwiseTransform(arr1.getContext(), sd::pairwise::Add, arr1.buffer(), arr1.shapeInfo(), arr1.specialBuffer(), arr1.specialShapeInfo(), arr2.buffer(), arr2.shapeInfo(), arr2.specialBuffer(), arr2.specialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({result}, {&arr1, &arr2}); - if(!isArr1Rvalue && !isArr2Rvalue) { - NDArray res = std::move(*result); - delete result; - return std::move(res); + if(!isArr1Rvalue && !isArr2Rvalue) { + NDArray res = std::move(*result); + delete result; + return std::move(res); + } + + return std::move(*result); } - return std::move(*result); + return std::forward(arr1).applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), std::forward(arr2)); } - - return std::forward(arr1).applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), std::forward(arr2)); -} -template ND4J_EXPORT NDArray operator+(NDArray& arr1, NDArray& arr2); -template ND4J_EXPORT NDArray operator+(NDArray& arr1, NDArray&& arr2); -template ND4J_EXPORT NDArray operator+(NDArray&& arr1, NDArray& arr2); -template ND4J_EXPORT NDArray operator+(NDArray& arr1, const NDArray& arr2); -template ND4J_EXPORT NDArray operator+(const NDArray& arr1, NDArray& arr2); -template ND4J_EXPORT NDArray operator+(const NDArray& arr1, NDArray&& arr2); -template ND4J_EXPORT NDArray operator+(const NDArray& arr1, const NDArray& arr2); -template ND4J_EXPORT NDArray operator+(NDArray&& arr1, const NDArray& arr2); -template ND4J_EXPORT NDArray operator+(NDArray&& arr1, NDArray&& arr2); + template ND4J_EXPORT NDArray operator+(NDArray& arr1, NDArray& arr2); + template ND4J_EXPORT NDArray operator+(NDArray& arr1, NDArray&& arr2); + template ND4J_EXPORT NDArray operator+(NDArray&& arr1, NDArray& arr2); + template ND4J_EXPORT NDArray operator+(NDArray& arr1, const NDArray& arr2); + template ND4J_EXPORT 
NDArray operator+(const NDArray& arr1, NDArray& arr2);
+ template ND4J_EXPORT NDArray operator+(const NDArray& arr1, NDArray&& arr2);
+ template ND4J_EXPORT NDArray operator+(const NDArray& arr1, const NDArray& arr2);
+ template ND4J_EXPORT NDArray operator+(NDArray&& arr1, const NDArray& arr2);
+ template ND4J_EXPORT NDArray operator+(NDArray&& arr1, NDArray&& arr2);

////////////////////////////////////////////////////////////////////////
// subtraction operator array - array
-template <typename T1, typename T2>
-NDArray operator-(T1&& arr1, T2&& arr2) {
+ template <typename T1, typename T2>
+ NDArray operator-(T1&& arr1, T2&& arr2) {
- if (arr1.isS() || arr2.isS())
- throw std::runtime_error("operator-(T&& arr1, T&& arr2): you can't use this method on String arrays!");
- if (!Environment::getInstance().isExperimentalBuild() && arr1.dataType() != arr2.dataType() && (arr1.dataType() != DataType::BOOL || arr2.dataType() != BOOL))
- throw sd::datatype_exception::build("operator-(T&& arr1, T&& arr2): Cannot multiply different types", arr1.dataType(), arr2.dataType());
+ if (arr1.isS() || arr2.isS())
+ throw std::runtime_error("operator-(T&& arr1, T&& arr2): you can't use this method on String arrays!");
+ if (!Environment::getInstance().isExperimentalBuild() && arr1.dataType() != arr2.dataType() && (arr1.dataType() != DataType::BOOL || arr2.dataType() != BOOL))
+ throw sd::datatype_exception::build("operator-(T&& arr1, T&& arr2): Cannot subtract different types", arr1.dataType(), arr2.dataType());
- PointersManager pointersManager(arr1.getContext(), "operator-(T&& arr1, T&& arr2)");
+ PointersManager pointersManager(arr1.getContext(), "operator-(T&& arr1, T&& arr2)");
- if (arr1.lengthOf() == arr2.lengthOf() && arr1.rankOf() == arr2.rankOf()) {
+ if (arr1.lengthOf() == arr2.lengthOf() && arr1.rankOf() == arr2.rankOf()) {
- const bool isArr1Rvalue = !std::is_reference<T1>::value && !arr1.isView();
- const bool isArr2Rvalue = !std::is_reference<T2>::value && !arr2.isView();
+ const bool isArr1Rvalue = !std::is_reference<T1>::value && !arr1.isView();
+ const bool isArr2Rvalue = !std::is_reference<T2>::value && !arr2.isView();
- NDArray* result = nullptr;
- if(isArr1Rvalue)
- result = const_cast<NDArray*>(&arr1);
- else if(isArr2Rvalue)
- result = const_cast<NDArray*>(&arr2);
- else
- result = new NDArray(arr1.shapeInfo(), DataTypeUtils::pickPairwiseResultType(arr1.shapeInfo(), arr2.shapeInfo()), false, arr1.getContext());
+ NDArray* result = nullptr;
+ if(isArr1Rvalue)
+ result = const_cast<NDArray*>(&arr1);
+ else if(isArr2Rvalue)
+ result = const_cast<NDArray*>(&arr2);
+ else
+ result = new NDArray(arr1.shapeInfo(), DataTypeUtils::pickPairwiseResultType(arr1.shapeInfo(), arr2.shapeInfo()), false, arr1.getContext());
- NDArray::prepareSpecialUse({result}, {&arr1, &arr2});
- NativeOpExecutioner::execPairwiseTransform(arr1.getContext(), sd::pairwise::Subtract, arr1.buffer(), arr1.shapeInfo(), arr1.specialBuffer(), arr1.specialShapeInfo(), arr2.buffer(), arr2.shapeInfo(), arr2.specialBuffer(), arr2.specialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr);
- NDArray::registerSpecialUse({result}, {&arr1, &arr2});
+ NDArray::prepareSpecialUse({result}, {&arr1, &arr2});
+ NativeOpExecutioner::execPairwiseTransform(arr1.getContext(), sd::pairwise::Subtract, arr1.buffer(), arr1.shapeInfo(), arr1.specialBuffer(), arr1.specialShapeInfo(), arr2.buffer(), arr2.shapeInfo(), arr2.specialBuffer(), arr2.specialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr);
+
NDArray::registerSpecialUse({result}, {&arr1, &arr2}); - if(!isArr1Rvalue && !isArr2Rvalue) { - NDArray res = std::move(*result); - delete result; - return std::move(res); + if(!isArr1Rvalue && !isArr2Rvalue) { + NDArray res = std::move(*result); + delete result; + return std::move(res); + } + + return std::move(*result); } - return std::move(*result); + return std::forward(arr1).applyTrueBroadcast(sd::BroadcastOpsTuple::Subtract(), std::forward(arr2)); } - - return std::forward(arr1).applyTrueBroadcast(sd::BroadcastOpsTuple::Subtract(), std::forward(arr2)); -} -template ND4J_EXPORT NDArray operator-(NDArray& arr1, NDArray& arr2); -template ND4J_EXPORT NDArray operator-(NDArray& arr1, NDArray&& arr2); -template ND4J_EXPORT NDArray operator-(NDArray&& arr1, NDArray& arr2); -template ND4J_EXPORT NDArray operator-(NDArray& arr1, const NDArray& arr2); -template ND4J_EXPORT NDArray operator-(const NDArray& arr1, NDArray& arr2); -template ND4J_EXPORT NDArray operator-(const NDArray& arr1, NDArray&& arr2); -template ND4J_EXPORT NDArray operator-(const NDArray& arr1, const NDArray& arr2); -template ND4J_EXPORT NDArray operator-(NDArray&& arr1, const NDArray& arr2); -template ND4J_EXPORT NDArray operator-(NDArray&& arr1, NDArray&& arr2); + template ND4J_EXPORT NDArray operator-(NDArray& arr1, NDArray& arr2); + template ND4J_EXPORT NDArray operator-(NDArray& arr1, NDArray&& arr2); + template ND4J_EXPORT NDArray operator-(NDArray&& arr1, NDArray& arr2); + template ND4J_EXPORT NDArray operator-(NDArray& arr1, const NDArray& arr2); + template ND4J_EXPORT NDArray operator-(const NDArray& arr1, NDArray& arr2); + template ND4J_EXPORT NDArray operator-(const NDArray& arr1, NDArray&& arr2); + template ND4J_EXPORT NDArray operator-(const NDArray& arr1, const NDArray& arr2); + template ND4J_EXPORT NDArray operator-(NDArray&& arr1, const NDArray& arr2); + template ND4J_EXPORT NDArray operator-(NDArray&& arr1, NDArray&& arr2); //////////////////////////////////////////////////////////////////////// // multiplication operator array*array -template -NDArray operator*(T1&& arr1, T2&& arr2) { + template + NDArray operator*(T1&& arr1, T2&& arr2) { - if (arr1.isS() || arr2.isS()) - throw std::runtime_error("operator*(T&& arr1, T&& arr2): you can't use this method on String arrays!"); - if (!Environment::getInstance().isExperimentalBuild() && arr1.dataType() != arr2.dataType() && (arr1.dataType() != DataType::BOOL || arr2.dataType() != BOOL)) - throw sd::datatype_exception::build("operator*(T&& arr1, T&& arr2): Cannot multiply different types", arr1.dataType(), arr2.dataType()); + if (arr1.isS() || arr2.isS()) + throw std::runtime_error("operator*(T&& arr1, T&& arr2): you can't use this method on String arrays!"); + if (!Environment::getInstance().isExperimentalBuild() && arr1.dataType() != arr2.dataType() && (arr1.dataType() != DataType::BOOL || arr2.dataType() != BOOL)) + throw sd::datatype_exception::build("operator*(T&& arr1, T&& arr2): Cannot multiply different types", arr1.dataType(), arr2.dataType()); - PointersManager pointersManager(arr1.getContext(), "operator*(T&& arr1, T&& arr2)"); + PointersManager pointersManager(arr1.getContext(), "operator*(T&& arr1, T&& arr2)"); - if (arr1.lengthOf() == arr2.lengthOf() && arr1.rankOf() == arr2.rankOf()) { + if (arr1.lengthOf() == arr2.lengthOf() && arr1.rankOf() == arr2.rankOf()) { - const bool isArr1Rvalue = !std::is_reference::value && !arr1.isView(); - const bool isArr2Rvalue = !std::is_reference::value && !arr2.isView(); + const bool isArr1Rvalue = 
!std::is_reference::value && !arr1.isView(); + const bool isArr2Rvalue = !std::is_reference::value && !arr2.isView(); - NDArray* result = nullptr; - if(isArr1Rvalue) - result = const_cast(&arr1); - else if(isArr2Rvalue) - result = const_cast(&arr2); - else - result = new NDArray(arr1.shapeInfo(), DataTypeUtils::pickPairwiseResultType(arr1.shapeInfo(), arr2.shapeInfo()), false, arr1.getContext()); + NDArray* result = nullptr; + if(isArr1Rvalue) + result = const_cast(&arr1); + else if(isArr2Rvalue) + result = const_cast(&arr2); + else + result = new NDArray(arr1.shapeInfo(), DataTypeUtils::pickPairwiseResultType(arr1.shapeInfo(), arr2.shapeInfo()), false, arr1.getContext()); - NDArray::prepareSpecialUse({result}, {&arr1, &arr2}); - NativeOpExecutioner::execPairwiseTransform(arr1.getContext(), sd::pairwise::Multiply, arr1.buffer(), arr1.shapeInfo(), arr1.specialBuffer(), arr1.specialShapeInfo(), arr2.buffer(), arr2.shapeInfo(), arr2.specialBuffer(), arr2.specialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr); - NDArray::registerSpecialUse({result}, {&arr1, &arr2}); + NDArray::prepareSpecialUse({result}, {&arr1, &arr2}); + NativeOpExecutioner::execPairwiseTransform(arr1.getContext(), sd::pairwise::Multiply, arr1.buffer(), arr1.shapeInfo(), arr1.specialBuffer(), arr1.specialShapeInfo(), arr2.buffer(), arr2.shapeInfo(), arr2.specialBuffer(), arr2.specialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr); + NDArray::registerSpecialUse({result}, {&arr1, &arr2}); - if(!isArr1Rvalue && !isArr2Rvalue) { - NDArray res = std::move(*result); - delete result; - return std::move(res); + if(!isArr1Rvalue && !isArr2Rvalue) { + NDArray res = std::move(*result); + delete result; + return std::move(res); + } + + return std::move(*result); } - return std::move(*result); + return std::forward(arr1).applyTrueBroadcast(sd::BroadcastOpsTuple::Multiply(), std::forward(arr2)); } - - return std::forward(arr1).applyTrueBroadcast(sd::BroadcastOpsTuple::Multiply(), std::forward(arr2)); -} -template ND4J_EXPORT NDArray operator*(NDArray& arr1, NDArray& arr2); -template ND4J_EXPORT NDArray operator*(NDArray& arr1, NDArray&& arr2); -template ND4J_EXPORT NDArray operator*(NDArray&& arr1, NDArray& arr2); -template ND4J_EXPORT NDArray operator*(NDArray& arr1, const NDArray& arr2); -template ND4J_EXPORT NDArray operator*(const NDArray& arr1, NDArray& arr2); -template ND4J_EXPORT NDArray operator*(const NDArray& arr1, NDArray&& arr2); -template ND4J_EXPORT NDArray operator*(const NDArray& arr1, const NDArray& arr2); -template ND4J_EXPORT NDArray operator*(NDArray&& arr1, const NDArray& arr2); -template ND4J_EXPORT NDArray operator*(NDArray&& arr1, NDArray&& arr2); + template ND4J_EXPORT NDArray operator*(NDArray& arr1, NDArray& arr2); + template ND4J_EXPORT NDArray operator*(NDArray& arr1, NDArray&& arr2); + template ND4J_EXPORT NDArray operator*(NDArray&& arr1, NDArray& arr2); + template ND4J_EXPORT NDArray operator*(NDArray& arr1, const NDArray& arr2); + template ND4J_EXPORT NDArray operator*(const NDArray& arr1, NDArray& arr2); + template ND4J_EXPORT NDArray operator*(const NDArray& arr1, NDArray&& arr2); + template ND4J_EXPORT NDArray operator*(const NDArray& arr1, const NDArray& arr2); + template ND4J_EXPORT NDArray operator*(NDArray&& arr1, const NDArray& arr2); + template ND4J_EXPORT NDArray operator*(NDArray&& arr1, NDArray&& arr2); 
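For orientation, the scalar overloads above come in pairs: a const-lvalue form that allocates a fresh result, and an NDArray&& form that, when the argument is an expiring non-view temporary, writes the result back into that temporary's buffer. A minimal usage sketch of what this buys in a chained expression (illustrative only, not part of the patch; header paths and the array shape are assumptions):

    #include <array/NDArray.h>
    #include <array/NDArrayFactory.h>

    using namespace sd;

    int main() {
        // 'a' is an owning array, so (a + 2.0f) creates a temporary result.
        auto a = NDArrayFactory::create<float>('c', {2, 3});

        // (a + 2.0f) uses the copying const-lvalue overload once; the temporary
        // it returns is an rvalue, so "* 0.5f" binds to operator*(NDArray&&, T)
        // and runs in place on that buffer instead of allocating a second array.
        NDArray b = (a + 2.0f) * 0.5f;

        // Views never take the in-place path: their buffers belong to a parent
        // array, which is exactly what the isView() checks above guard against.
        return 0;
    }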
////////////////////////////////////////////////////////////////////////
// division operator array / array
-template <typename T1, typename T2>
-NDArray operator/(T1&& arr1, T2&& arr2) {
+ template <typename T1, typename T2>
+ NDArray operator/(T1&& arr1, T2&& arr2) {
- if (arr1.isS() || arr2.isS())
- throw std::runtime_error("operator/(T&& arr1, T&& arr2): you can't use this method on String arrays!");
- if (!Environment::getInstance().isExperimentalBuild() && arr1.dataType() != arr2.dataType() && (arr1.dataType() != DataType::BOOL || arr2.dataType() != BOOL))
- throw sd::datatype_exception::build("operator/(T&& arr1, T&& arr2): Cannot multiply different types", arr1.dataType(), arr2.dataType());
+ if (arr1.isS() || arr2.isS())
+ throw std::runtime_error("operator/(T&& arr1, T&& arr2): you can't use this method on String arrays!");
+ if (!Environment::getInstance().isExperimentalBuild() && arr1.dataType() != arr2.dataType() && (arr1.dataType() != DataType::BOOL || arr2.dataType() != BOOL))
+ throw sd::datatype_exception::build("operator/(T&& arr1, T&& arr2): Cannot divide different types", arr1.dataType(), arr2.dataType());
- PointersManager pointersManager(arr1.getContext(), "operator/(T&& arr1, T&& arr2)");
+ PointersManager pointersManager(arr1.getContext(), "operator/(T&& arr1, T&& arr2)");
- if (arr1.lengthOf() == arr2.lengthOf() && arr1.rankOf() == arr2.rankOf()) {
+ if (arr1.lengthOf() == arr2.lengthOf() && arr1.rankOf() == arr2.rankOf()) {
- const bool isArr1Rvalue = !std::is_reference<T1>::value && !arr1.isView();
- const bool isArr2Rvalue = !std::is_reference<T2>::value && !arr2.isView();
+ const bool isArr1Rvalue = !std::is_reference<T1>::value && !arr1.isView();
+ const bool isArr2Rvalue = !std::is_reference<T2>::value && !arr2.isView();
- NDArray* result = nullptr;
- if(isArr1Rvalue)
- result = const_cast<NDArray*>(&arr1);
- else if(isArr2Rvalue)
- result = const_cast<NDArray*>(&arr2);
- else
- result = new NDArray(arr1.shapeInfo(), DataTypeUtils::pickPairwiseResultType(arr1.shapeInfo(), arr2.shapeInfo()), false, arr1.getContext());
+ NDArray* result = nullptr;
+ if(isArr1Rvalue)
+ result = const_cast<NDArray*>(&arr1);
+ else if(isArr2Rvalue)
+ result = const_cast<NDArray*>(&arr2);
+ else
+ result = new NDArray(arr1.shapeInfo(), DataTypeUtils::pickPairwiseResultType(arr1.shapeInfo(), arr2.shapeInfo()), false, arr1.getContext());
- NDArray::prepareSpecialUse({result}, {&arr1, &arr2});
- NativeOpExecutioner::execPairwiseTransform(arr1.getContext(), sd::pairwise::Divide, arr1.buffer(), arr1.shapeInfo(), arr1.specialBuffer(), arr1.specialShapeInfo(), arr2.buffer(), arr2.shapeInfo(), arr2.specialBuffer(), arr2.specialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr);
- NDArray::registerSpecialUse({result}, {&arr1, &arr2});
+ NDArray::prepareSpecialUse({result}, {&arr1, &arr2});
+ NativeOpExecutioner::execPairwiseTransform(arr1.getContext(), sd::pairwise::Divide, arr1.buffer(), arr1.shapeInfo(), arr1.specialBuffer(), arr1.specialShapeInfo(), arr2.buffer(), arr2.shapeInfo(), arr2.specialBuffer(), arr2.specialShapeInfo(), result->buffer(), result->shapeInfo(), result->specialBuffer(), result->specialShapeInfo(), nullptr);
+ NDArray::registerSpecialUse({result}, {&arr1, &arr2});
- if(!isArr1Rvalue && !isArr2Rvalue) {
- NDArray res = std::move(*result);
- delete result;
- return std::move(res);
+ if(!isArr1Rvalue && !isArr2Rvalue) {
+ NDArray res = std::move(*result);
+ delete result;
+ return std::move(res);
+ }
+
+ return std::move(*result);
}
- return std::move(*result);
+ return std::forward<T1>(arr1).applyTrueBroadcast(sd::BroadcastOpsTuple::Divide(), std::forward<T2>(arr2));
}
-
- return std::forward<T1>(arr1).applyTrueBroadcast(sd::BroadcastOpsTuple::Divide(), std::forward<T2>(arr2));
-}
-template ND4J_EXPORT NDArray operator/(NDArray& arr1, NDArray& arr2);
-template ND4J_EXPORT NDArray operator/(NDArray& arr1, NDArray&& arr2);
-template ND4J_EXPORT NDArray operator/(NDArray&& arr1, NDArray& arr2);
-template ND4J_EXPORT NDArray operator/(NDArray& arr1, const NDArray& arr2);
-template ND4J_EXPORT NDArray operator/(const NDArray& arr1, NDArray& arr2);
-template ND4J_EXPORT NDArray operator/(const NDArray& arr1, NDArray&& arr2);
-template ND4J_EXPORT NDArray operator/(const NDArray& arr1, const NDArray& arr2);
-template ND4J_EXPORT NDArray operator/(NDArray&& arr1, const NDArray& arr2);
-template ND4J_EXPORT NDArray operator/(NDArray&& arr1, NDArray&& arr2);
+ template ND4J_EXPORT NDArray operator/(NDArray& arr1, NDArray& arr2);
+ template ND4J_EXPORT NDArray operator/(NDArray& arr1, NDArray&& arr2);
+ template ND4J_EXPORT NDArray operator/(NDArray&& arr1, NDArray& arr2);
+ template ND4J_EXPORT NDArray operator/(NDArray& arr1, const NDArray& arr2);
+ template ND4J_EXPORT NDArray operator/(const NDArray& arr1, NDArray& arr2);
+ template ND4J_EXPORT NDArray operator/(const NDArray& arr1, NDArray&& arr2);
+ template ND4J_EXPORT NDArray operator/(const NDArray& arr1, const NDArray& arr2);
+ template ND4J_EXPORT NDArray operator/(NDArray&& arr1, const NDArray& arr2);
+ template ND4J_EXPORT NDArray operator/(NDArray&& arr1, NDArray&& arr2);

/*
diff --git a/libnd4j/include/array/impl/InteropDataBuffer.cpp b/libnd4j/include/array/impl/InteropDataBuffer.cpp
index e041f0079..2b80294d3 100644
--- a/libnd4j/include/array/impl/InteropDataBuffer.cpp
+++ b/libnd4j/include/array/impl/InteropDataBuffer.cpp
@@ -33,7 +33,8 @@ namespace sd {
_offset = offset;
if (_offset + length > _dataBuffer->getLenInBytes()) {
- throw std::runtime_error("offset + length is higher than original length");
+ this->expand(length);
+ nd4j_debug("Expanding data buffer length by %lld\n", length);
}
}
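The InteropDataBuffer change above is behavioral, not cosmetic: a view whose offset + length runs past the backing buffer used to be rejected with a runtime_error and is now accommodated by growing the backing storage. A toy sketch of the new policy (illustrative only; ToyInteropBuffer and its members are hypothetical stand-ins, not the real class):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    struct ToyInteropBuffer {
        std::vector<uint8_t> backing;   // stands in for the underlying DataBuffer

        void setView(size_t offset, size_t lengthInBytes) {
            // old behavior: throw if the requested view did not fit
            // new behavior: enlarge the backing storage so that it does
            if (offset + lengthInBytes > backing.size())
                backing.resize(offset + lengthInBytes);
        }
    };

Callers that previously relied on the exception to flag bad offsets will now silently receive a larger buffer.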
INPUT_VARIABLE(2)->e(0) : INT_ARG(0); - REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_max: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); - REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_max: segment indexes array length should be equal to the input first dimension, but %ld != %ild.", idxSegments->lengthOf(), input->sizeAt(0)); - - Nd4jLong wrong; - - REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_max: segment indices should be in range [0, %ld), but %ld != %ld", - numOfClasses, wrong, numOfClasses); - - helpers::unsortedSegmentMaxFunctor(block.launchContext(), input, idxSegments, numOfClasses, segmentedOutput); + REQUIRE_TRUE(reshapedSegments.isVector(), 0, "unsorted_segment_max: segment indexes array should be a vector, but its rank is %i.", idxSegments->rankOf()); + helpers::unsortedSegmentMaxFunctor(block.launchContext(), input, &reshapedSegments, numOfClasses, segmentedOutput); return ND4J_STATUS_OK; } DECLARE_TYPES(unsorted_segment_max) { getOpDescriptor() - ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) - ->setAllowedInputTypes(1, {ALL_INTS}) - ->setSameMode(true); + ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) + ->setAllowedInputTypes(1, {ALL_INTS}) + ->setSameMode(true); } DECLARE_SHAPE_FN(unsorted_segment_max) { + auto in = inputShape->at(0); int outRank = shape::rank(in); + Nd4jLong* outputShape = nullptr; Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); - Nd4jLong* outputShape; - ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); + if(INPUT_VARIABLE(0)->rankOf() >= 2) { + ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); + outputShape[0] = outRank; + outputShape[1] = numOfClasses; + for(int i = 1; i < outRank; i++) + outputShape[i + 1] = shape::sizeAt(in, i); - outputShape[0] = outRank; - outputShape[1] = numOfClasses; - for(int i = 1; i < outRank; ++i) - outputShape[i + 1] = shape::sizeAt(in, i); + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + + } else { + ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(1), Nd4jLong); + outputShape[0] = 1; + outputShape[1] = numOfClasses; + shape::printShapeInfo(outputShape); + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + } - ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); return SHAPELIST(CONSTANT(outputShape)); } @@ -75,7 +83,7 @@ namespace sd { DECLARE_TYPES(unsorted_segment_max_bp) { getOpDescriptor() ->setAllowedOutputTypes(0, {ALL_FLOATS}) - ->setAllowedOutputTypes(1, {ALL_INTS}) + ->setAllowedOutputTypes(1, {ALL_INTS}) ->setAllowedInputTypes(0, {ALL_FLOATS}) ->setAllowedInputTypes(1, {ALL_INTS}) ->setAllowedInputTypes(2, {ALL_FLOATS}) diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_mean.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_mean.cpp index 5f2f5ff02..c78c7a8a5 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_mean.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_mean.cpp @@ -27,19 +27,21 @@ namespace sd { namespace ops { CUSTOM_OP_IMPL(unsorted_segment_mean, 2, 1, false, 0, 0) { auto input = INPUT_VARIABLE(0); + auto reshapedInput = *input; + /* 
if(!input->isVector()) { + reshapedInput = input->reshape('c',{input->lengthOf()},false); + }*/ + auto idxSegments = INPUT_VARIABLE(1); + auto reshapedSegments = *idxSegments; + if(!idxSegments->isVector() && idxSegments->rankOf() > 1) { + reshapedSegments = idxSegments->reshape('c',{idxSegments->lengthOf()},false); + } + auto segmentedOutput = OUTPUT_NULLIFIED(0); Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); - - REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_mean: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); - REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_mean: segment indexes array length should be equal to the input first dimension, but %ld != %ld.", idxSegments->lengthOf(), input->sizeAt(0)); - - Nd4jLong wrong; - - REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_mean: segment indices should be in range [0, %ld), but %ld != %ld", - numOfClasses, wrong, numOfClasses); - - helpers::unsortedSegmentMeanFunctor(block.launchContext(), input, idxSegments, numOfClasses, segmentedOutput); + REQUIRE_TRUE(reshapedSegments.isVector(), 0, "unsorted_segment_mean: segment indexes array should be a vector, but its rank is %i.", idxSegments->rankOf()); + helpers::unsortedSegmentMeanFunctor(block.launchContext(), &reshapedInput, &reshapedSegments, numOfClasses, segmentedOutput); return ND4J_STATUS_OK; } @@ -58,14 +60,23 @@ namespace sd { Nd4jLong* outputShape = nullptr; Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); - ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); + if(INPUT_VARIABLE(0)->rankOf() >= 2) { + ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); + outputShape[0] = outRank; + outputShape[1] = numOfClasses; + for(int i = 1; i < outRank; i++) + outputShape[i + 1] = shape::sizeAt(in, i); - outputShape[0] = outRank; - outputShape[1] = numOfClasses; - for(int i = 1; i < outRank; ++i) - outputShape[i + 1] = shape::sizeAt(in, i); + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + + } else { + ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(1), Nd4jLong); + outputShape[0] = 1; + outputShape[1] = numOfClasses; + shape::printShapeInfo(outputShape); + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + } - ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); return SHAPELIST(CONSTANT(outputShape)); } diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_min.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_min.cpp index e0c95b7c7..93fc959f2 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_min.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_min.cpp @@ -27,37 +27,49 @@ namespace sd { namespace ops { CUSTOM_OP_IMPL(unsorted_segment_min, 2, 1, false, 0, 0) { auto input = INPUT_VARIABLE(0); + auto reshapedInput = *input; + + auto idxSegments = INPUT_VARIABLE(1); + auto reshapedSegments = *idxSegments; + if(!idxSegments->isVector() && idxSegments->rankOf() > 1) { + reshapedSegments = idxSegments->reshape('c',{idxSegments->lengthOf()},false); + } + auto segmentedOutput = OUTPUT_NULLIFIED(0); Nd4jLong numOfClasses = block.width() == 3 ? 
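// [editor's note, not part of this diff] The reworked shape functions in these
// unsorted segment ops all follow one pattern: for rank >= 2 inputs the output keeps
// the trailing dimensions and swaps dimension 0 for numOfClasses (e.g. a [5, 3]
// input with 4 classes yields a [4, 3] output), while rank-0/rank-1 inputs now get a
// plain [numOfClasses] vector instead of going through the generic branch.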
INPUT_VARIABLE(2)->e(0) : INT_ARG(0); - REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_min: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); - REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_min: segment indexes array length should be equal to the input first dimension, but %ld != %ld.", idxSegments->lengthOf(), input->sizeAt(0)); - - Nd4jLong wrong; - - REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_min: segment indices should be in range [0, %ld), but %ld > %ld", - numOfClasses, wrong, numOfClasses); - - helpers::unsortedSegmentMinFunctor(block.launchContext(), input, idxSegments, numOfClasses, segmentedOutput); + REQUIRE_TRUE(reshapedSegments.isVector(), 0, "unsorted_segment_min: segment indexes array should be a vector, but its rank is %i.", idxSegments->rankOf()); + helpers::unsortedSegmentMinFunctor(block.launchContext(), &reshapedInput, &reshapedSegments, numOfClasses, segmentedOutput); return ND4J_STATUS_OK; + } DECLARE_SHAPE_FN(unsorted_segment_min) { + auto in = inputShape->at(0); int outRank = shape::rank(in); Nd4jLong* outputShape = nullptr; Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); - ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); + if(INPUT_VARIABLE(0)->rankOf() >= 2) { + ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); + outputShape[0] = outRank; + outputShape[1] = numOfClasses; + for(int i = 1; i < outRank; i++) + outputShape[i + 1] = shape::sizeAt(in, i); - outputShape[0] = outRank; - outputShape[1] = numOfClasses; - for(int i = 1; i < outRank; ++i) - outputShape[i + 1] = shape::sizeAt(in, i); + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + + } else { + ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(1), Nd4jLong); + outputShape[0] = 1; + outputShape[1] = numOfClasses; + shape::printShapeInfo(outputShape); + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + } - ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); return SHAPELIST(CONSTANT(outputShape)); } @@ -77,7 +89,7 @@ namespace sd { DECLARE_TYPES(unsorted_segment_min_bp) { getOpDescriptor() ->setAllowedOutputTypes(0, {ALL_FLOATS, ALL_INTS}) - ->setAllowedOutputTypes(1, {ALL_INTS}) + ->setAllowedOutputTypes(1, {ALL_INTS}) ->setAllowedInputTypes(0, {ALL_FLOATS, ALL_INTS}) ->setAllowedInputTypes(1, {ALL_INTS}) ->setAllowedInputTypes(2, {ALL_FLOATS, ALL_INTS}) diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_prod.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_prod.cpp index 43f90f699..ff778372a 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_prod.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_prod.cpp @@ -27,18 +27,21 @@ namespace sd { namespace ops { CUSTOM_OP_IMPL(unsorted_segment_prod, 2, 1, false, 0, 0) { auto input = INPUT_VARIABLE(0); + auto reshapedInput = *input; + /* if(!input->isVector()) { + reshapedInput = input->reshape('c',{input->lengthOf()},false); + }*/ + auto idxSegments = INPUT_VARIABLE(1); + auto reshapedSegments = *idxSegments; + if(!idxSegments->isVector() && idxSegments->rankOf() > 1) { + reshapedSegments = idxSegments->reshape('c',{idxSegments->lengthOf()},false); + } + auto segmentedOutput = OUTPUT_NULLIFIED(0); Nd4jLong 
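// [editor's note, not part of this diff] unsorted_segment_prod multiplies all
// entries that share a segment id: input [2, 3, 4] with segment ids [0, 0, 1] and
// numOfClasses = 2 produces [6, 4]. Classes that receive no entries keep the
// identity value 1 from the initial output->assign(1.f) in the helper.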
numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); - REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_prod: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); - REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_prod: segment indexes array length should be equal to the input first dimension, but %ld != %ld.", idxSegments->lengthOf(), input->sizeAt(0)); - - Nd4jLong wrong = 0; - - REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_prod: segment indices should be in range [0, %ld), but %ld != %ld", - numOfClasses, wrong, numOfClasses); - - helpers::unsortedSegmentProdFunctor(block.launchContext(), input, idxSegments, numOfClasses, segmentedOutput); + REQUIRE_TRUE(reshapedSegments.isVector(), 0, "unsorted_segment_prod: segment indexes array should be a vector, but its rank is %i.", idxSegments->rankOf()); + helpers::unsortedSegmentProdFunctor(block.launchContext(), &reshapedInput, &reshapedSegments, numOfClasses, segmentedOutput); return ND4J_STATUS_OK; } @@ -50,14 +53,23 @@ namespace sd { Nd4jLong* outputShape = nullptr; Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); - ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); + if(INPUT_VARIABLE(0)->rankOf() >= 2) { + ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); + outputShape[0] = outRank; + outputShape[1] = numOfClasses; + for(int i = 1; i < outRank; i++) + outputShape[i + 1] = shape::sizeAt(in, i); - outputShape[0] = outRank; - outputShape[1] = numOfClasses; - for(int i = 1; i < outRank; ++i) - outputShape[i + 1] = shape::sizeAt(in, i); + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + + } else { + ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(1), Nd4jLong); + outputShape[0] = 1; + outputShape[1] = numOfClasses; + shape::printShapeInfo(outputShape); + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + } - ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); return SHAPELIST(CONSTANT(outputShape)); } @@ -90,7 +102,7 @@ namespace sd { DECLARE_TYPES(unsorted_segment_prod_bp) { getOpDescriptor() ->setAllowedOutputTypes(0, {ALL_FLOATS}) - ->setAllowedOutputTypes(1, {ALL_INDICES}) + ->setAllowedOutputTypes(1, {ALL_INDICES}) ->setAllowedInputTypes(0, {ALL_FLOATS}) ->setAllowedInputTypes(1, {ALL_INDICES}) ->setAllowedInputTypes(2,{ALL_FLOATS, ALL_INTS}) diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sqrt_n.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sqrt_n.cpp index b4963304d..05358a17d 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sqrt_n.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sqrt_n.cpp @@ -27,18 +27,18 @@ namespace sd { namespace ops { CUSTOM_OP_IMPL(unsorted_segment_sqrt_n, 2, 1, false, 0, 0) { auto input = INPUT_VARIABLE(0); + auto reshapedInput = *input; + auto idxSegments = INPUT_VARIABLE(1); + auto reshapedSegments = *idxSegments; + if(!idxSegments->isVector() && idxSegments->rankOf() > 1) { + reshapedSegments = idxSegments->reshape('c',{idxSegments->lengthOf()},false); + } + auto segmentedOutput = OUTPUT_NULLIFIED(0); Nd4jLong numOfClasses = block.width() == 3 ? 
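// [editor's note, not part of this diff] unsorted_segment_sqrt_n divides each class
// sum by sqrt(count) rather than by count: segment ids [0, 0, 1] with input
// [1, 3, 5] give [4 / sqrt(2), 5], i.e. roughly [2.83, 5].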
INPUT_VARIABLE(2)->e(0) : INT_ARG(0); - REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_sqrt_n: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); - REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_sqrt_n: segment indexes array length should be equal to the input first dimension, but %ld != %ld.", idxSegments->lengthOf(), input->sizeAt(0)); - - Nd4jLong wrong; - - REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_sqrt_n: segment indices should be in range [0, %ld), but %ld != %ld", - numOfClasses, wrong, numOfClasses); - - helpers::unsortedSegmentSqrtNFunctor(block.launchContext(), input, idxSegments, numOfClasses, segmentedOutput); + REQUIRE_TRUE(reshapedSegments.isVector(), 0, "unsorted_segment_sqrt_n: segment indexes array should be a vector, but its rank is %i.", idxSegments->rankOf()); + helpers::unsortedSegmentSqrtNFunctor(block.launchContext(), &reshapedInput, &reshapedSegments, numOfClasses, segmentedOutput); return ND4J_STATUS_OK; } @@ -50,14 +50,23 @@ namespace sd { Nd4jLong* outputShape = nullptr; Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); - ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); + if(INPUT_VARIABLE(0)->rankOf() >= 2) { + ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); + outputShape[0] = outRank; + outputShape[1] = numOfClasses; + for(int i = 1; i < outRank; i++) + outputShape[i + 1] = shape::sizeAt(in, i); - outputShape[0] = outRank; - outputShape[1] = numOfClasses; - for(int i = 1; i < outRank; ++i) - outputShape[i + 1] = shape::sizeAt(in, i); + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + + } else { + ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(1), Nd4jLong); + outputShape[0] = 1; + outputShape[1] = numOfClasses; + shape::printShapeInfo(outputShape); + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + } - ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); return SHAPELIST(CONSTANT(outputShape)); } @@ -75,7 +84,7 @@ namespace sd { DECLARE_TYPES(unsorted_segment_sqrt_n_bp) { getOpDescriptor() ->setAllowedOutputTypes(0, {ALL_FLOATS}) - ->setAllowedOutputTypes(1, {ALL_INTS}) + ->setAllowedOutputTypes(1, {ALL_INTS}) ->setAllowedInputTypes(0, {ALL_FLOATS}) ->setAllowedInputTypes(1, {ALL_INTS}) ->setAllowedInputTypes(2, {ALL_FLOATS}) diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sum.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sum.cpp index 3b948c0d8..0f0f758c3 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sum.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/unsorted_segment_sum.cpp @@ -27,18 +27,19 @@ namespace sd { namespace ops { CUSTOM_OP_IMPL(unsorted_segment_sum, 2, 1, false, 0, 0) { auto input = INPUT_VARIABLE(0); + auto reshapedInput = *input; + + auto idxSegments = INPUT_VARIABLE(1); + auto reshapedSegments = *idxSegments; + if(!idxSegments->isVector() && idxSegments->rankOf() > 1) { + reshapedSegments = idxSegments->reshape('c',{idxSegments->lengthOf()},false); + } + auto segmentedOutput = OUTPUT_NULLIFIED(0); Nd4jLong numOfClasses = block.width() == 3 ? 
INPUT_VARIABLE(2)->e(0) : INT_ARG(0); - REQUIRE_TRUE(idxSegments->isVector(), 0, "unsorted_segment_sum: segment indexes array should be a vector, but it rank is %i.", idxSegments->rankOf()); - REQUIRE_TRUE(idxSegments->lengthOf() == input->sizeAt(0), 0, "unsorted_segment_sum: segment indexes array length should be equal to the input first dimension, but %ld != %ld", idxSegments->lengthOf(), input->sizeAt(0)); - - Nd4jLong wrong; - - REQUIRE_TRUE(helpers::unsortedSegmentIndicesValidate(block.launchContext(), idxSegments, numOfClasses, wrong), 0, "unsorted_segment_sum: segment indices should be in range [0, %ld), but %ld > %ld", - numOfClasses, wrong, numOfClasses); - - helpers::unsortedSegmentSumFunctor(block.launchContext(), input, idxSegments, numOfClasses, segmentedOutput); + REQUIRE_TRUE(reshapedSegments.isVector(), 0, "unsorted_segment_sum: segment indexes array should be a vector, but its rank is %i.", idxSegments->rankOf()); + helpers::unsortedSegmentSumFunctor(block.launchContext(), &reshapedInput, &reshapedSegments, numOfClasses, segmentedOutput); return ND4J_STATUS_OK; } @@ -57,14 +58,23 @@ namespace sd { Nd4jLong* outputShape = nullptr; Nd4jLong numOfClasses = block.width() == 3 ? INPUT_VARIABLE(2)->e(0) : INT_ARG(0); - ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); + if(INPUT_VARIABLE(0)->rankOf() >= 2) { + ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(outRank), Nd4jLong); + outputShape[0] = outRank; + outputShape[1] = numOfClasses; + for(int i = 1; i < outRank; i++) + outputShape[i + 1] = shape::sizeAt(in, i); - outputShape[0] = outRank; - outputShape[1] = numOfClasses; - for(int i = 1; i < outRank; ++i) - outputShape[i + 1] = shape::sizeAt(in, i); + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + + } else { + ALLOCATE(outputShape, block.getWorkspace(), shape::shapeInfoLength(1), Nd4jLong); + outputShape[0] = 1; + outputShape[1] = numOfClasses; + shape::printShapeInfo(outputShape); + ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); + } - ShapeUtils::updateStridesAndType(outputShape, in, shape::order(in)); return SHAPELIST(CONSTANT(outputShape)); } @@ -86,7 +96,7 @@ namespace sd { DECLARE_TYPES(unsorted_segment_sum_bp) { getOpDescriptor() ->setAllowedOutputTypes(0, {ALL_FLOATS}) - ->setAllowedOutputTypes(1, {ALL_INTS}) + ->setAllowedOutputTypes(1, {ALL_INTS}) ->setAllowedInputTypes(sd::DataType::ANY) ->setSameMode(false); } diff --git a/libnd4j/include/ops/declarable/generic/random/uniform.cpp b/libnd4j/include/ops/declarable/generic/random/uniform.cpp index ffc02c204..cb7f146da 100644 --- a/libnd4j/include/ops/declarable/generic/random/uniform.cpp +++ b/libnd4j/include/ops/declarable/generic/random/uniform.cpp @@ -36,17 +36,24 @@ namespace sd { * uniform distribution * takes 1 ndarray * - * T argumens map: + * T arguments map: * TArgs[0] - min for rng * TArgs[1] - max for rng */ - CUSTOM_OP_IMPL(randomuniform, 1, 1, true, 0, 0) { + CUSTOM_OP_IMPL(randomuniform, -1, 1, true, 0, -1) { // uniform distribution auto rng = block.randomGenerator(); auto dtype = DataType::FLOAT32; if (block.getIArguments()->size()) dtype = (DataType)INT_ARG(0); + if(block.getIArguments()->size() > 1) { + auto seed = INT_ARG(1); + rng.setStates(seed,seed ^ 0xdeadbeef); + nd4j_debug("randomuniform: Setting seed %d\n",seed); + //rng.setSeed(seed); + } + auto min = block.width() > 1 ? INPUT_VARIABLE(1) : (NDArray*) nullptr; auto max = block.width() > 2 ? 
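// [editor's note, not part of this diff] With the change above, randomuniform reads
// an optional second integer argument as an RNG seed and feeds it to rng.setStates().
// A hedged usage sketch (iArgs = {dtype, seed}; the min/max arrays are optional
// inputs after the shape tensor, and minArr/maxArr are illustrative names only):
//
//   sd::ops::randomuniform op;
//   auto result = op.evaluate({&shape, &minArr, &maxArr}, {},
//                             {(int) sd::DataType::FLOAT32, 119});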
INPUT_VARIABLE(2) : (NDArray*) nullptr; bool disposable = false; diff --git a/libnd4j/include/ops/declarable/headers/random.h b/libnd4j/include/ops/declarable/headers/random.h index e156dae66..b4ea89969 100644 --- a/libnd4j/include/ops/declarable/headers/random.h +++ b/libnd4j/include/ops/declarable/headers/random.h @@ -50,7 +50,7 @@ namespace sd { * 0 - uniformly distributed values of given type (between min and max) */ #if NOT_EXCLUDED(OP_randomuniform) - DECLARE_CUSTOM_OP(randomuniform, 1, 1, false, 0, 0); + DECLARE_CUSTOM_OP(randomuniform, 1, 1, false, 0, -1); #endif /* * multinomial (categorical) random generator draws samples from a multinomial distribution diff --git a/libnd4j/include/ops/declarable/helpers/cpu/segment.cpp b/libnd4j/include/ops/declarable/helpers/cpu/segment.cpp index 0b3ab7847..0f8d1a80e 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/segment.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/segment.cpp @@ -28,1067 +28,1076 @@ #include namespace sd { -namespace ops { -namespace helpers { + namespace ops { + namespace helpers { - // segment max - template - static void segmentMaxFunctor_(NDArray* input, NDArray* indices, NDArray* output) { - //int numClasses = output->sizeAt(0); - // if input is a vector: (as if in doc sample) - Nd4jLong idx = indices->e(0); - if (input->isVector()) { - T val = input->e(0); + // segment max + template + static void segmentMaxFunctor_(NDArray* input, NDArray* indices, NDArray* output) { + //int numClasses = output->sizeAt(0); + // if input is a vector: (as if in doc sample) + Nd4jLong idx = indices->e(0); + if (input->isVector() || input->isScalar()) { + T val = input->e(0); - for (Nd4jLong e = 1; e < indices->lengthOf(); e++) { - if (idx == indices->e(e)) { - // max - val = sd::math::nd4j_max(val, input->t(e)); - } - else { - idx = indices->e(e); - val = input->t(e); - } - output->r(idx) = val; - } - } - else { - std::vector restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - auto listOfTensors = input->allTensorsAlongDimension(restDims); - auto listOfOutTensors = output->allTensorsAlongDimension(restDims); - - auto numOfClasses = output->sizeAt(0); // number of classes - std::vector> outputs(numOfClasses); - auto maxT = listOfOutTensors.at(idx); - - //int pos = 0; - maxT->assign(listOfTensors.at(0)); - - for (Nd4jLong i = 1; i < indices->lengthOf(); i++) { - if (indices->e(i) == idx) { - - for (Nd4jLong e = 0; e < maxT->lengthOf(); e++) { - maxT->r(e) = sd::math::nd4j_max(maxT->t(e), listOfTensors.at(i)->t(e)); + for (Nd4jLong e = 1; e < indices->lengthOf(); e++) { + if (idx == indices->e(e)) { + // max + val = sd::math::nd4j_max(val, input->t(e)); + } + else { + idx = indices->e(e); + val = input->t(e); + } + output->r(idx) = val; } } else { - idx = indices->e(i); - maxT = listOfOutTensors.at(idx); - maxT->assign(listOfTensors.at(i)); - } + std::vector restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + auto listOfTensors = input->allTensorsAlongDimension(restDims); + auto listOfOutTensors = output->allTensorsAlongDimension(restDims); + auto numOfClasses = output->sizeAt(0); // number of classes + std::vector> outputs(numOfClasses); + auto maxT = listOfOutTensors.at(idx); + + //int pos = 0; + maxT->assign(listOfTensors.at(0)); + + for (Nd4jLong i = 1; i < indices->lengthOf(); i++) { + if (indices->e(i) == idx) { + + for (Nd4jLong e = 0; e < maxT->lengthOf(); e++) { + maxT->r(e) = sd::math::nd4j_max(maxT->t(e), listOfTensors.at(i)->t(e)); + } + } + else { + idx = indices->e(i); + maxT = 
listOfOutTensors.at(idx); + maxT->assign(listOfTensors.at(i)); + } + + } + } } - // segmen min - template - static void segmentMinFunctor_(NDArray* input, NDArray* indices, NDArray* output) { - //int numClasses = output->sizeAt(0); - // if input is a vector: (as if in doc sample) - Nd4jLong idx = indices->e(0); - if (input->isVector()) { - T val = input->e(0); + // segment min + template + static void segmentMinFunctor_(NDArray* input, NDArray* indices, NDArray* output) { + //int numClasses = output->sizeAt(0); + // if input is a vector: (as if in doc sample) + Nd4jLong idx = indices->e(0); + if (input->isVector() || input->isScalar()) { + T val = input->e(0); - for (Nd4jLong e = 1; e < indices->lengthOf(); e++) { - if (idx == indices->e(e)) { - // min - val = sd::math::nd4j_min(val, input->t(e)); - } - else { - idx = indices->e(e); - val = input->t(e); - } - output->r(idx) = val; - } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - - int numOfClasses = output->sizeAt(0); // number of classes - std::vector> outputs(numOfClasses); - auto minT = listOfOutTensors.at(idx); - - int pos = 0; - minT->assign(listOfTensors.at(0)); - - for (Nd4jLong i = 1; i < indices->lengthOf(); i++) { - if (indices->e(i) == idx) { - - for (Nd4jLong e = 0; e < minT->lengthOf(); e++) { - minT->p(e, sd::math::nd4j_min(minT->e(e), listOfTensors.at(i)->e(e))); + for (Nd4jLong e = 1; e < indices->lengthOf(); e++) { + if (idx == indices->e(e)) { + // min + val = sd::math::nd4j_min(val, input->t(e)); + } + else { + idx = indices->e(e); + val = input->t(e); + } + output->r(idx) = val; } } else { - idx = indices->e(i); - minT = listOfOutTensors.at(idx); - minT->assign(listOfTensors.at(i)); - } - } - } - } + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - // segmen mean - template - static void segmentMeanFunctor_(NDArray* input, NDArray* indices, NDArray* output) { - int numClasses = output->sizeAt(0); - // if input is a vector: (as if in doc sample) - int idx = indices->e(0); - if (input->isVector()) { - T val = T(0.f); - int count = 0; + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - for (Nd4jLong e = 0; e < indices->lengthOf(); e++) { - if (idx == indices->e(e)) { - // mean - val += input->e(e); - count++; - } - else { - output->p(idx, val / count); - idx = indices->e(e); - val = input->e(e); - count = 1; - } - output->p(idx, val / count); - } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + int numOfClasses = output->sizeAt(0); // number of classes + std::vector> outputs(numOfClasses); + auto minT = listOfOutTensors.at(idx); - auto listOfTensors = input->allTensorsAlongDimension(restDims); - auto listOfOutTensors = output->allTensorsAlongDimension(restDims); + int pos = 0; + minT->assign(listOfTensors.at(0)); - int numOfClasses = output->sizeAt(0); // number of classes - std::vector> outputs(numOfClasses); - auto meanT = listOfOutTensors.at(idx); - int count = 1; - auto meanV = meanT->dup(); - meanV.assign(listOfTensors.at(0)); + for (Nd4jLong i = 1; i < indices->lengthOf(); i++) { + if (indices->e(i) == idx) { - for (Nd4jLong i = 1; i < indices->lengthOf(); i++) { - if (indices->e(i) == idx) { - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - 
meanV.p(e, meanV.e(e) + listOfTensors.at(i)->e(e)); + for (Nd4jLong e = 0; e < minT->lengthOf(); e++) { + minT->p(e, sd::math::nd4j_min(minT->e(e), listOfTensors.at(i)->e(e))); + } } - }; - samediff::Threads::parallel_for(func, 0, meanT->lengthOf()); - - count++; - } - else { - //meanT->assign(meanV); - meanV.applyScalar(scalar::Divide, count, *meanT); - idx = indices->e(i); - meanT = listOfOutTensors.at(idx); - meanV.assign(listOfTensors.at(i)); - count = 1; - } - meanV.applyScalar(scalar::Divide, count, *meanT); - } - } - } - - template - static void segmentSumFunctor_(NDArray* input, NDArray* indices, NDArray* output) { - int numClasses = output->sizeAt(0); - // if input is a vector: (as if in doc sample) - int idx = indices->e(0); - if (input->isVector()) { - T val = T(0.f); - int count = 0; - for (Nd4jLong e = 0; e < indices->lengthOf(); e++) { - if (idx == indices->e(e)) { - // sum - val += input->t(e); - } - else { - idx = indices->e(e); - val = input->t(e); - } - output->p(idx, val); - } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - auto listOfTensors = input->allTensorsAlongDimension(restDims); - auto listOfOutTensors = output->allTensorsAlongDimension(restDims); - - int numOfClasses = output->sizeAt(0); // number of classes - std::vector> outputs(numOfClasses); - auto sumT = listOfOutTensors.at(idx); - - for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { - if (indices->e(i) == idx) { - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - sumT->p(e, sumT->e(e) + listOfTensors.at(i)->e(e)); + else { + idx = indices->e(i); + minT = listOfOutTensors.at(idx); + minT->assign(listOfTensors.at(i)); } - }; - samediff::Threads::parallel_for(func, 0, sumT->lengthOf()); - } - else { - idx = indices->e(i); - sumT = listOfOutTensors.at(idx); - sumT->assign(listOfTensors.at(i)); + } } } - } - } - template - static void segmentProdFunctor_(NDArray* input, NDArray* indices, NDArray* output) { - //int numClasses = output->sizeAt(0); - // if input is a vector: (as if in doc sample) - int idx = indices->e(0); - output->assign(1.f); - if (input->isVector()) { - T val = input->e(0); - int count = 0; + // segment mean + template + static void segmentMeanFunctor_(NDArray* input, NDArray* indices, NDArray* output) { + int numClasses = output->sizeAt(0); + // if input is a vector: (as if in doc sample) + int idx = indices->e(0); + if (input->isVector() || input->isScalar()) { + T val = T(0.f); + int count = 0; - for (Nd4jLong e = 1; e < indices->lengthOf(); e++) { - if (idx == indices->e(e)) { - // sum - val *= input->e(e); - } - else { - idx = indices->e(e); - val = input->e(e); - } - output->p(idx, val); - } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - auto listOfTensors = input->allTensorsAlongDimension(restDims); - auto listOfOutTensors = output->allTensorsAlongDimension(restDims); - - int numOfClasses = output->sizeAt(0); // number of classes - auto sumT = listOfOutTensors.at(idx); - sumT->assign(listOfTensors.at(0)); - for (Nd4jLong i = 1; i < indices->lengthOf(); i++) { - if (indices->e(i) == idx) { - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - sumT->p(e, sumT->e(e) * listOfTensors.at(i)->e(e)); + for (Nd4jLong e = 0; e < indices->lengthOf(); e++) { + if (idx == indices->e(e)) { + // mean + val += input->e(e); + count++; } + else { + output->p(idx, val / count); + idx = indices->e(e); + val = 
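// [editor's note, not part of this diff] This sorted (vector) branch keeps a running
// sum and count per contiguous run of equal indices and writes the mean on every
// step; e.g. indices [0, 0, 1] with input [1, 3, 5] leave output [2, 5].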
input->e(e); + count = 1; + } + output->p(idx, val / count); + } } else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + auto listOfTensors = input->allTensorsAlongDimension(restDims); + auto listOfOutTensors = output->allTensorsAlongDimension(restDims); + + int numOfClasses = output->sizeAt(0); // number of classes + std::vector> outputs(numOfClasses); + auto meanT = listOfOutTensors.at(idx); + int count = 1; + auto meanV = meanT->dup(); + meanV.assign(listOfTensors.at(0)); + + for (Nd4jLong i = 1; i < indices->lengthOf(); i++) { + if (indices->e(i) == idx) { + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + meanV.p(e, meanV.e(e) + listOfTensors.at(i)->e(e)); + } + }; + samediff::Threads::parallel_for(func, 0, meanT->lengthOf()); + + count++; + } + else { + //meanT->assign(meanV); + meanV.applyScalar(scalar::Divide, count, *meanT); + idx = indices->e(i); + meanT = listOfOutTensors.at(idx); + meanV.assign(listOfTensors.at(i)); + count = 1; + } + meanV.applyScalar(scalar::Divide, count, *meanT); + } + } + } + + template + static void segmentSumFunctor_(NDArray* input, NDArray* indices, NDArray* output) { + int numClasses = output->sizeAt(0); + // if input is a vector: (as if in doc sample) + int idx = indices->e(0); + if (input->isVector() || input->isScalar()) { + T val = T(0.f); + int count = 0; + for (Nd4jLong e = 0; e < indices->lengthOf(); e++) { + if (idx == indices->e(e)) { + // sum + val += input->t(e); + } + else { + idx = indices->e(e); + val = input->t(e); + } + output->p(idx, val); + } + } + else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + auto listOfTensors = input->allTensorsAlongDimension(restDims); + auto listOfOutTensors = output->allTensorsAlongDimension(restDims); + + int numOfClasses = output->sizeAt(0); // number of classes + std::vector> outputs(numOfClasses); + auto sumT = listOfOutTensors.at(idx); + + for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { + if (indices->e(i) == idx) { + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + sumT->p(e, sumT->e(e) + listOfTensors.at(i)->e(e)); + } + }; + samediff::Threads::parallel_for(func, 0, sumT->lengthOf()); + } + else { + idx = indices->e(i); + sumT = listOfOutTensors.at(idx); + sumT->assign(listOfTensors.at(i)); + } + } + } + } + + template + static void segmentProdFunctor_(NDArray* input, NDArray* indices, NDArray* output) { + //int numClasses = output->sizeAt(0); + // if input is a vector: (as if in doc sample) + int idx = indices->e(0); + output->assign(1.f); + if (input->isVector() || input->isScalar()) { + T val = input->e(0); + int count = 0; + + for (Nd4jLong e = 1; e < indices->lengthOf(); e++) { + if (idx == indices->e(e)) { + // prod + val *= input->e(e); + } + else { + idx = indices->e(e); + val = input->e(e); + } + output->p(idx, val); + } + } + else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + auto listOfTensors = input->allTensorsAlongDimension(restDims); + auto listOfOutTensors = output->allTensorsAlongDimension(restDims); + + int numOfClasses = output->sizeAt(0); // number of classes + auto sumT = listOfOutTensors.at(idx); + sumT->assign(listOfTensors.at(0)); + for (Nd4jLong i = 1; i < indices->lengthOf(); i++) { + if (indices->e(i) == idx) { + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + sumT->p(e, sumT->e(e) * listOfTensors.at(i)->e(e)); 
+ } + }; + samediff::Threads::parallel_for(func, 0, sumT->lengthOf()); + } + else { + idx = indices->e(i); + sumT = listOfOutTensors.at(idx); + sumT->assign(listOfTensors.at(i)); + } + } } } - } - } // template // static bool segmentIndicesValidate_(NDArray* indices, NDArray& aexpected, NDArray& anOutput) { // } - void segmentMaxFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), segmentMaxFunctor_, (input, indices, output), LIBND4J_TYPES); - } - - void segmentMinFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), segmentMinFunctor_, (input, indices, output), LIBND4J_TYPES); - } - - void segmentMeanFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), segmentMeanFunctor_, (input, indices, output), LIBND4J_TYPES); - } - - void segmentSumFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), segmentSumFunctor_, (input, indices, output), LIBND4J_TYPES); - } - - void segmentProdFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), segmentProdFunctor_, (input, indices, output), LIBND4J_TYPES); - } - - bool segmentIndicesValidate(sd::LaunchContext * context, NDArray* indices, NDArray& expected, NDArray& output) { - auto val = indices->e(0); - for (Nd4jLong e = 1; e < indices->lengthOf(); e++) { - output = indices->e(e); - if (val.e(0) > output.e(0)) - return false; - val = indices->e(e); - } - - return true; - } - - //BUILD_SINGLE_TEMPLATE(template bool segmentIndicesValidate_, (NDArray*, NDArray&, NDArray&), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void segmentProdFunctor_, (NDArray* input, NDArray* indices, NDArray* output), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void segmentSumFunctor_, (NDArray* input, NDArray* indices, NDArray* output), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void segmentMeanFunctor_, (NDArray* input, NDArray* indices, NDArray* output), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void segmentMinFunctor_, (NDArray* input, NDArray* indices, NDArray* output), LIBND4J_TYPES); - BUILD_SINGLE_TEMPLATE(template void segmentMaxFunctor_, (NDArray* input, NDArray* indices, NDArray* output), LIBND4J_TYPES); - // -------------------------------------------------------------------------------------------------------------- // - // Unsorted segment ops - // -------------------------------------------------------------------------------------------------------------- // - - bool unsortedSegmentIndicesValidate(sd::LaunchContext * context, NDArray* indices, Nd4jLong expected, Nd4jLong& output) { - Nd4jLong val = indices->e(0); - - Nd4jLong maxInd = indices->argMax(); - if (indices->e(maxInd) >= expected) { - output = val; - return false; - } - output = expected; - return true; - } - - template - static void unsortedSegmentMaxFunctor_(NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - - // if input is a vector: (as if in doc sample) - //int idx = static_cast((*indices)(0.)); - MAP_IMPL> idxs;//(indices->lengthOf()); - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) - idxs[indices->e(e)].push_back(e); - - //std::sort(idxs.begin(), idxs.end()); - - if (input->isVector()) { // 1D case - T maxVal = DataTypeUtils::max(); - 
output->assign(-maxVal); - - for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - T val = input->e(fi->second.at(0)); - for (Nd4jLong idx = 1; idx < static_cast(fi->second.size()); ++idx) { - val = sd::math::nd4j_max(val, input->e(fi->second.at(idx))); - } - output->p(fi->first, val); + void segmentMaxFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), segmentMaxFunctor_, (input, indices, output), LIBND4J_TYPES); } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + void segmentMinFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), segmentMinFunctor_, (input, indices, output), LIBND4J_TYPES); + } - T maxVal = DataTypeUtils::max(); - output->assign(-maxVal); + void segmentMeanFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), segmentMeanFunctor_, (input, indices, output), LIBND4J_TYPES); + } - for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - auto outputT = listOfOutTensors.at(fi->first); - outputT->assign(listOfTensors.at(fi->second.at(0))); - for (Nd4jLong idx = 1; idx < static_cast(fi->second.size()); ++idx) { - auto maxT = listOfTensors.at(fi->second.at(idx)); - for (Nd4jLong e = 0; e < outputT->lengthOf(); ++e) { - T val = sd::math::nd4j_max(maxT->e(e), outputT->e(e)); + void segmentSumFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), segmentSumFunctor_, (input, indices, output), LIBND4J_TYPES); + } - outputT->p(e, val); + void segmentProdFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), segmentProdFunctor_, (input, indices, output), LIBND4J_TYPES); + } + + bool segmentIndicesValidate(sd::LaunchContext * context, NDArray* indices, NDArray& expected, NDArray& output) { + auto val = indices->e(0); + for (Nd4jLong e = 1; e < indices->lengthOf(); e++) { + output = indices->e(e); + if (val.e(0) > output.e(0)) + return false; + val = indices->e(e); + } + + return true; + } + + //BUILD_SINGLE_TEMPLATE(template bool segmentIndicesValidate_, (NDArray*, NDArray&, NDArray&), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template void segmentProdFunctor_, (NDArray* input, NDArray* indices, NDArray* output), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template void segmentSumFunctor_, (NDArray* input, NDArray* indices, NDArray* output), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template void segmentMeanFunctor_, (NDArray* input, NDArray* indices, NDArray* output), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template void segmentMinFunctor_, (NDArray* input, NDArray* indices, NDArray* output), LIBND4J_TYPES); + BUILD_SINGLE_TEMPLATE(template void segmentMaxFunctor_, (NDArray* input, NDArray* indices, NDArray* output), LIBND4J_TYPES); + // -------------------------------------------------------------------------------------------------------------- // + // Unsorted segment ops + // -------------------------------------------------------------------------------------------------------------- // + + bool unsortedSegmentIndicesValidate(sd::LaunchContext * context, NDArray* indices, Nd4jLong expected, Nd4jLong& output) { + 
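// [editor's note, not part of this diff] This validator only checks the upper bound:
// it fails when the largest segment id is >= the expected class count. Note that it
// reports the first element of the index array through the output parameter rather
// than the offending maximum, so error messages built from it can name the wrong
// value.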
Nd4jLong val = indices->e(0); + + Nd4jLong maxInd = indices->argMax(); + if (indices->e(maxInd) >= expected) { + output = val; + return false; + } + output = expected; + return true; + } + + template + static void unsortedSegmentMaxFunctor_(NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { + + // if input is a vector: (as if in doc sample) + //int idx = static_cast((*indices)(0.)); + MAP_IMPL> idxs;//(indices->lengthOf()); + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) + idxs[indices->e(e)].push_back(e); + + + //std::sort(idxs.begin(), idxs.end()); + + if (input->isVector() || input->isScalar()) { // 1D case + T maxVal = DataTypeUtils::max(); + output->assign(-maxVal); + + for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { + T val = input->e(fi->second.at(0)); + for (Nd4jLong idx = 1; idx < static_cast(fi->second.size()); ++idx) { + val = sd::math::nd4j_max(val, input->e(fi->second.at(idx))); + } + output->p(fi->first, val); + } + } + else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + + T maxVal = DataTypeUtils::max(); + output->assign(-maxVal); + + for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { + auto outputT = listOfOutTensors.at(fi->first); + outputT->assign(listOfTensors.at(fi->second.at(0))); + for (Nd4jLong idx = 0; idx < listOfTensors.size(); ++idx) { + if(idx >= fi->second.size() || fi->second.size() < 2 || fi->second.at(idx) >= listOfTensors.size()) { + continue; + } + + auto maxT = listOfTensors.at(fi->second.at(idx)); + for (Nd4jLong e = 0; e < outputT->lengthOf(); ++e) { + T val = sd::math::nd4j_max(maxT->e(e), outputT->e(e)); + + outputT->p(e, val); + } + } + + + } + + + } + } + void unsortedSegmentMaxFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), unsortedSegmentMaxFunctor_, (input, indices, numOfClasses, output), NUMERIC_TYPES); + } + BUILD_SINGLE_TEMPLATE(template void unsortedSegmentMaxFunctor_, (NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES); + + template + static void unsortedSegmentMinFunctor_(NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { + // if input is a vector: (as if in doc sample) + //int idx = static_cast((*indices)(0.)); + MAP_IMPL> idxs;//(indices->lengthOf()); + + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) + idxs[indices->e(e)].push_back(e); + + //std::sort(idxs.begin(), idxs.end()); + + if (input->isVector() || input->isScalar()) { // 1D case + T maxVal = DataTypeUtils::max(); + output->assign(maxVal); + + for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { + T val = input->t(fi->second.at(0)); + + for (size_t idx = 1; idx < fi->second.size(); ++idx) { + val = sd::math::nd4j_min(val, input->t(fi->second.at(idx))); + } + output->r(fi->first) = val; + } + } + else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + + T maxVal = DataTypeUtils::max(); + output->assign(maxVal); + + for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { + auto outputT = listOfOutTensors.at(fi->first); + outputT->assign(listOfTensors.at(fi->second.at(0))); + for (size_t idx = 1; idx < fi->second.size(); ++idx) 
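// [editor's note, not part of this diff] idxs maps each segment id to the list of
// element positions that carry it, so each reduction below only touches its own
// class; min seeds the output with DataTypeUtils::max<T>() so classes that receive
// no entries keep that maximum value.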
{ + auto minT = listOfTensors.at(fi->second.at(idx)); + + for (Nd4jLong e = 0; e < outputT->lengthOf(); ++e) { + outputT->r(e) = sd::math::nd4j_min(minT->t(e), outputT->t(e)); + } + } + //outputT->assign(maxT); + } + } + + } + void unsortedSegmentMinFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), unsortedSegmentMinFunctor_, (input, indices, numOfClasses, output), + NUMERIC_TYPES); + } + + BUILD_SINGLE_TEMPLATE(template void unsortedSegmentMinFunctor_, (NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES); + + void unsortedSegmentMeanFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { + MAP_IMPL> idxs;//(indices->lengthOf()); + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) + idxs[indices->e(e)].push_back(e); + + //std::sort(idxs.begin(), idxs.end()); + + if (input->isVector() || input->isScalar()) { // 1D case + + for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { + double sumValue = input->e(fi->second.at(0)); + int loop_size = fi->second.size(); + + // FIXME: parallelism here? + for (size_t idx = 1; idx < loop_size; ++idx) { + sumValue += input->e(fi->second.at(idx)); + } + + output->p(fi->first, sumValue / fi->second.size()); + } + } + else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + + // FIXME: parallelism here? + for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { + auto outputT = listOfOutTensors.at(fi->first); + outputT->assign(listOfTensors.at(fi->second.at(0))); + Nd4jLong loopSize = fi->second.size(); + + for (Nd4jLong idx = 1; idx < loopSize; ++idx) { + auto current = listOfTensors.at(fi->second.at(idx)); + *outputT += *current; + } + (*outputT) /= double(fi->second.size()); } } } - } - } - void unsortedSegmentMaxFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), unsortedSegmentMaxFunctor_, (input, indices, numOfClasses, output), NUMERIC_TYPES); - } - BUILD_SINGLE_TEMPLATE(template void unsortedSegmentMaxFunctor_, (NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES); - template - static void unsortedSegmentMinFunctor_(NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - // if input is a vector: (as if in doc sample) - //int idx = static_cast((*indices)(0.)); - MAP_IMPL> idxs;//(indices->lengthOf()); + void unsortedSegmentSumFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { + MAP_IMPL> idxs;//(indices->lengthOf()); + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) + idxs[indices->e(e)].push_back(e); - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) - idxs[indices->e(e)].push_back(e); + if (input->isVector() || input->isScalar()) { // 1D case - //std::sort(idxs.begin(), idxs.end()); + for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { + double sumValue = input->e(fi->second.at(0)); + Nd4jLong loop_size = fi->second.size(); - if (input->isVector()) { // 1D case - T maxVal = DataTypeUtils::max(); - output->assign(maxVal); - - for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - T val = input->t(fi->second.at(0)); - - for (size_t idx = 1; idx < 
fi->second.size(); ++idx) { - val = sd::math::nd4j_min(val, input->t(fi->second.at(idx))); - } - output->r(fi->first) = val; - } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - - T maxVal = DataTypeUtils::max(); - output->assign(maxVal); - - for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - auto outputT = listOfOutTensors.at(fi->first); - outputT->assign(listOfTensors.at(fi->second.at(0))); - for (size_t idx = 1; idx < fi->second.size(); ++idx) { - auto minT = listOfTensors.at(fi->second.at(idx)); - - for (Nd4jLong e = 0; e < outputT->lengthOf(); ++e) { - outputT->r(e) = sd::math::nd4j_min(minT->t(e), outputT->t(e)); + // FIXME: parallelism here? + for (Nd4jLong idx = 1; idx < loop_size; ++idx) { + sumValue += input->e(fi->second.at(idx)); + } + output->p(fi->first, sumValue); } } - //outputT->assign(maxT); - } - } + else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - } - void unsortedSegmentMinFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), unsortedSegmentMinFunctor_, (input, indices, numOfClasses, output), - NUMERIC_TYPES); - } + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - BUILD_SINGLE_TEMPLATE(template void unsortedSegmentMinFunctor_, (NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES); + for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { + auto outputT = listOfOutTensors.at(fi->first); + outputT->assign(listOfTensors.at(fi->second.at(0))); + Nd4jLong loop_size = fi->second.size(); - void unsortedSegmentMeanFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - MAP_IMPL> idxs;//(indices->lengthOf()); - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) - idxs[indices->e(e)].push_back(e); - - //std::sort(idxs.begin(), idxs.end()); - - if (input->isVector()) { // 1D case - - for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - double sumValue = input->e(fi->second.at(0)); - int loop_size = fi->second.size(); - - // FIXME: parallelism here? - for (size_t idx = 1; idx < loop_size; ++idx) { - sumValue += input->e(fi->second.at(idx)); - } - - output->p(fi->first, sumValue / fi->second.size()); - } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - - // FIXME: parallelism here? 
- for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - auto outputT = listOfOutTensors.at(fi->first); - outputT->assign(listOfTensors.at(fi->second.at(0))); - Nd4jLong loopSize = fi->second.size(); - - for (Nd4jLong idx = 1; idx < loopSize; ++idx) { - auto current = listOfTensors.at(fi->second.at(idx)); - *outputT += *current; - } - (*outputT) /= double(fi->second.size()); - } - } - } - - void unsortedSegmentSumFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - MAP_IMPL> idxs;//(indices->lengthOf()); - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) - idxs[indices->e(e)].push_back(e); - - if (input->isVector()) { // 1D case - - for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - double sumValue = input->e(fi->second.at(0)); - Nd4jLong loop_size = fi->second.size(); - - // FIXME: parallelism here? - for (Nd4jLong idx = 1; idx < loop_size; ++idx) { - sumValue += input->e(fi->second.at(idx)); - } - output->p(fi->first, sumValue); - } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - - for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - auto outputT = listOfOutTensors.at(fi->first); - outputT->assign(listOfTensors.at(fi->second.at(0))); - Nd4jLong loop_size = fi->second.size(); - - // FIXME: parallelism here? - for (Nd4jLong idx = 1; idx < loop_size; ++idx) { - auto current = listOfTensors.at(fi->second.at(idx)); - *(outputT) += *current; - } - //outputT->assign(maxT); - } - } - } - - template - void unsortedSegmentProdFunctor_(NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - MAP_IMPL> idxs;//(indices->lengthOf()); - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) - idxs[indices->e(e)].push_back(e); - - //std::sort(idxs.begin(), idxs.end()); - - output->assign(1.f); - - if (input->isVector()) { // 1D case - for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - T prodValue = input->e(fi->second.at(0)); - for (size_t idx = 1; idx < fi->second.size(); ++idx) { - prodValue *= input->e(fi->second.at(idx)); - } - output->p(fi->first, prodValue); - } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - - for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - auto outputT = listOfOutTensors.at(fi->first); - outputT->assign(listOfTensors.at(fi->second.at(0))); - for (size_t idx = 1; idx < fi->second.size(); ++idx) { - auto current = listOfTensors.at(fi->second.at(idx)); - - *outputT *= *current; - } - } - } - } - - void unsortedSegmentProdFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - BUILD_SINGLE_SELECTOR(input->dataType(), unsortedSegmentProdFunctor_, (input, indices, numOfClasses, output), NUMERIC_TYPES); - } - BUILD_SINGLE_TEMPLATE(template void unsortedSegmentProdFunctor_, (NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES); - - void unsortedSegmentSqrtNFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { - MAP_IMPL> idxs;//(indices->lengthOf()); - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) - idxs[indices->e(e)].push_back(e); - - 
//std::sort(idxs.begin(), idxs.end()); - - if (input->isVector()) { // 1D case - for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - double sumValue = input->e(fi->second.at(0)); - for (size_t idx = 1; idx < fi->second.size(); ++idx) { - sumValue += input->e(fi->second.at(idx)); - } - output->p(fi->first, sumValue / sd::math::nd4j_sqrt(fi->second.size())); - } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - - for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { - auto outputT = listOfOutTensors.at(fi->first); - outputT->assign(listOfTensors.at(fi->second.at(0))); - for (size_t idx = 1; idx < fi->second.size(); ++idx) { - auto current = listOfTensors.at(fi->second.at(idx)); - *outputT += *current; - } - //outputT->assign(maxT); - (*outputT) /= sd::math::nd4j_sqrt(fi->second.size()); - } - } - } - - // -------------------------------------------------------------------------------------------------------------- // - // Backpropagate ops helpers - // -------------------------------------------------------------------------------------------------------------- // - // Sorted backpropagate ops - // - // segment max - template - int segmentMaxFunctorBP_(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { - //int numOfClasses = gradOut->sizeAt(0); - // if input is a vector: (as if in doc sample) - auto tempRes = gradOut->dup(); - segmentMaxFunctor_(input, indices, &tempRes); - if (input->isVector()) { - Nd4jLong loop_size = input->lengthOf(); - - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto classNum = indices->e(e); - if (sd::math::nd4j_abs(tempRes.e(classNum) - input->e(e)) <= T(1.e-6)) - output->p(e, gradOut->e(classNum)); - } - }; - samediff::Threads::parallel_for(func, 0, loop_size); - } - else { - std::vector restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); - ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - - //int numOfClasses = tempRes.sizeAt(0); // number of classes - //std::vector> outputs(numOfClasses); - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto classNum = indices->e(i); - auto current = listOfTensors.at(i); - auto currentOut = listOfOutTensors.at(i); - auto currentGradOut = listOfGradOuts.at(classNum); - - for (Nd4jLong e = 0; e < current->lengthOf(); e++) { - if (sd::math::nd4j_abs(listOfBPTensors.at(classNum)->e(e) - current->e(e)) <= T(1.e-6)) - currentOut->p(e, currentGradOut->e(e)); + // FIXME: parallelism here? 
+ for (Nd4jLong idx = 1; idx < loop_size; ++idx) { + auto current = listOfTensors.at(fi->second.at(idx)); + *(outputT) += *current; + } + //outputT->assign(maxT); } } - }; - - samediff::Threads::parallel_tad(func, 0, indices->lengthOf()); - } - - return ND4J_STATUS_OK; - } - - int segmentMaxFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { - BUILD_SINGLE_SELECTOR(output->dataType(), return segmentMaxFunctorBP_, (context, input, indices, gradOut, output), NUMERIC_TYPES); - } - BUILD_SINGLE_TEMPLATE(template int segmentMaxFunctorBP_, (sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output), NUMERIC_TYPES); - - // segmen min - int segmentMinFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { - NDArray tempRes = gradOut->dup(); - segmentMinFunctor(context, input, indices, &tempRes); - if (input->isVector()) { - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto classNum = indices->e(e); - if (sd::math::nd4j_abs(tempRes.e(classNum) - input->e(e)) < 1.e-5) - output->p(e, gradOut->e(classNum)); - } - }; - samediff::Threads::parallel_for(func, 0, input->lengthOf()); - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); - ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - - //int numOfClasses = tempRes.sizeAt(0); // number of classes - //std::vector> outputs(numOfClasses); - output->assign(0.); - int pos = 0; - - auto func = PRAGMA_THREADS_FOR { - for (auto i = start; i < stop; i++) { - auto classNum = indices->e(i); - auto current = listOfTensors.at(i); - auto currentOut = listOfOutTensors.at(i); - auto currentGradOut = listOfGradOuts.at(classNum); - - for (Nd4jLong e = 0; e < current->lengthOf(); e++) { - if (sd::math::nd4j_abs(listOfBPTensors.at(classNum)->e(e) - current->e(e)) < - 1.e-5) - currentOut->p(e, currentGradOut->e(e)); - } - } - }; - - samediff::Threads::parallel_tad(func, 0, indices->lengthOf()); - } - return ND4J_STATUS_OK; - } - - // segmen mean - int segmentMeanFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { - int numClasses = output->sizeAt(0); - MAP_IMPL classCount;//(numClasses); - - for (Nd4jLong count = 0; count < numClasses; ++count) { - classCount[count] = 0; - } - - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { - classCount[indices->e(e)] ++; - } - - // if input is a vector: (as if in doc sample) - if (input->isVector()) { - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { - Nd4jLong classNum = indices->e(e); - output->p(e, gradOut->e(classNum) / classCount[classNum]); } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); -; + template + void unsortedSegmentProdFunctor_(NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { + MAP_IMPL> idxs;//(indices->lengthOf()); + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) + idxs[indices->e(e)].push_back(e); - int pos = 0; - //auto 
func = [&](uint64_t thread_id, uint64_t start, uint64_t stop, uint64_t increment) -> void { - for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { - auto classNum = indices->e(i); - auto current = listOfTensors.at(i); - auto currentOut = listOfOutTensors.at(i); - auto currentGradOut = listOfGradOuts.at(classNum); + //std::sort(idxs.begin(), idxs.end()); - for (Nd4jLong e = 0; e < current->lengthOf(); e++) { - currentOut->p(e, currentGradOut->e(e) / classCount.at(classNum)); + output->assign(1.f); + + if (input->isVector() || input->isScalar()) { // 1D case + for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { + T prodValue = input->e(fi->second.at(0)); + for (size_t idx = 1; idx < fi->second.size(); ++idx) { + prodValue *= input->e(fi->second.at(idx)); + } + output->p(fi->first, prodValue); } } - //}; + else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - //samediff::Threads::parallel_for(func, 0, indices->lengthOf()); - } - return ND4J_STATUS_OK; - } + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - int segmentSumFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { + for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { + auto outputT = listOfOutTensors.at(fi->first); + outputT->assign(listOfTensors.at(fi->second.at(0))); + for (size_t idx = 1; idx < fi->second.size(); ++idx) { + auto current = listOfTensors.at(fi->second.at(idx)); + + *outputT *= *current; + } + } + } + } + + void unsortedSegmentProdFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { + BUILD_SINGLE_SELECTOR(input->dataType(), unsortedSegmentProdFunctor_, (input, indices, numOfClasses, output), NUMERIC_TYPES); + } + BUILD_SINGLE_TEMPLATE(template void unsortedSegmentProdFunctor_, (NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES); + + void unsortedSegmentSqrtNFunctor(sd::LaunchContext * context, NDArray* input, NDArray* indices, Nd4jLong numOfClasses, NDArray* output) { + MAP_IMPL> idxs;//(indices->lengthOf()); + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) + idxs[indices->e(e)].push_back(e); + + //std::sort(idxs.begin(), idxs.end()); + + if (input->isVector() || input->isScalar()) { // 1D case + for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { + double sumValue = input->e(fi->second.at(0)); + for (size_t idx = 1; idx < fi->second.size(); ++idx) { + sumValue += input->e(fi->second.at(idx)); + } + output->p(fi->first, sumValue / sd::math::nd4j_sqrt(fi->second.size())); + } + } + else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + + for (auto fi = idxs.begin(); fi != idxs.end(); ++fi) { + auto outputT = listOfOutTensors.at(fi->first); + outputT->assign(listOfTensors.at(fi->second.at(0))); + for (size_t idx = 1; idx < fi->second.size(); ++idx) { + auto current = listOfTensors.at(fi->second.at(idx)); + *outputT += *current; + } + //outputT->assign(maxT); + (*outputT) /= sd::math::nd4j_sqrt(fi->second.size()); + } + } + } + + // -------------------------------------------------------------------------------------------------------------- // + // Backpropagate ops helpers + // 
-------------------------------------------------------------------------------------------------------------- // + // Sorted backpropagate ops + // + // segment max + template + int segmentMaxFunctorBP_(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { + //int numOfClasses = gradOut->sizeAt(0); + // if input is a vector: (as if in doc sample) + auto tempRes = gradOut->dup(); + segmentMaxFunctor_(input, indices, &tempRes); + if (input->isVector() || input->isScalar()) { + Nd4jLong loop_size = input->lengthOf(); + + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto classNum = indices->e(e); + if (sd::math::nd4j_abs(tempRes.e(classNum) - input->e(e)) <= T(1.e-6)) + output->p(e, gradOut->e(classNum)); + } + }; + samediff::Threads::parallel_for(func, 0, loop_size); + } + else { + std::vector restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + + //int numOfClasses = tempRes.sizeAt(0); // number of classes + //std::vector> outputs(numOfClasses); + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto classNum = indices->e(i); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); + + for (Nd4jLong e = 0; e < current->lengthOf(); e++) { + if (sd::math::nd4j_abs(listOfBPTensors.at(classNum)->e(e) - current->e(e)) <= T(1.e-6)) + currentOut->p(e, currentGradOut->e(e)); + } + } + }; + + samediff::Threads::parallel_tad(func, 0, indices->lengthOf()); + } + + return ND4J_STATUS_OK; + } + + int segmentMaxFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { + BUILD_SINGLE_SELECTOR(output->dataType(), return segmentMaxFunctorBP_, (context, input, indices, gradOut, output), NUMERIC_TYPES); + } + BUILD_SINGLE_TEMPLATE(template int segmentMaxFunctorBP_, (sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output), NUMERIC_TYPES); + + // segmen min + int segmentMinFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { + NDArray tempRes = gradOut->dup(); + segmentMinFunctor(context, input, indices, &tempRes); + if (input->isVector() || input->isScalar()) { + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto classNum = indices->e(e); + if (sd::math::nd4j_abs(tempRes.e(classNum) - input->e(e)) < 1.e-5) + output->p(e, gradOut->e(classNum)); + } + }; + samediff::Threads::parallel_for(func, 0, input->lengthOf()); + } + else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + + //int numOfClasses = tempRes.sizeAt(0); // number of classes + //std::vector> outputs(numOfClasses); + output->assign(0.); + int pos = 0; + + auto func = PRAGMA_THREADS_FOR { + for (auto i = start; i < stop; i++) { + auto classNum = indices->e(i); + auto current = 
listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); + + for (Nd4jLong e = 0; e < current->lengthOf(); e++) { + if (sd::math::nd4j_abs(listOfBPTensors.at(classNum)->e(e) - current->e(e)) < + 1.e-5) + currentOut->p(e, currentGradOut->e(e)); + } + } + }; + + samediff::Threads::parallel_tad(func, 0, indices->lengthOf()); + } + return ND4J_STATUS_OK; + } + + // segmen mean + int segmentMeanFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { + int numClasses = output->sizeAt(0); + MAP_IMPL classCount;//(numClasses); + + for (Nd4jLong count = 0; count < numClasses; ++count) { + classCount[count] = 0; + } + + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { + classCount[indices->e(e)] ++; + } + + // if input is a vector: (as if in doc sample) + if (input->isVector() || input->isScalar()) { + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { + Nd4jLong classNum = indices->e(e); + output->p(e, gradOut->e(classNum) / classCount[classNum]); + } + } + else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + ; + + int pos = 0; + //auto func = [&](uint64_t thread_id, uint64_t start, uint64_t stop, uint64_t increment) -> void { + for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { + auto classNum = indices->e(i); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); + + for (Nd4jLong e = 0; e < current->lengthOf(); e++) { + currentOut->p(e, currentGradOut->e(e) / classCount.at(classNum)); + } + } + //}; + + //samediff::Threads::parallel_for(func, 0, indices->lengthOf()); + } + return ND4J_STATUS_OK; + } + + int segmentSumFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { // int numClasses = output->sizeAt(0); - // if input is a vector: (as if in doc sample) - Nd4jLong idx = indices->e(0); - if (input->isVector()) { - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { - Nd4jLong classNum = indices->e(e); - output->p(e, gradOut->e(classNum)); - } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - - //auto func = PRAGMA_THREADS_FOR { - for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { - auto classNum = indices->e(i); - auto current = listOfTensors.at(i); - auto currentOut = listOfOutTensors.at(i); - auto currentGradOut = listOfGradOuts.at(classNum); - - currentOut->assign(currentGradOut); - } - //}; - - //samediff::Threads::parallel_for(func, 0, indices->lengthOf()); - } - return Status::OK(); - } - - int segmentProdFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { - auto tempRes = gradOut->dup(); - segmentProdFunctor(context, input, indices, &tempRes); - if (input->isVector()) { - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { - Nd4jLong classNum = indices->e(e); - output->p(e, gradOut->e(classNum) * tempRes.e(classNum)/ input->e(e)); - } - } - else { - auto restDims = 
ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); - ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - - //int numOfClasses = tempRes.sizeAt(0); // number of classes - //std::vector> outputs(numOfClasses); - - //auto func = PRAGMA_THREADS_FOR { - for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { - auto classNum = indices->e(i); - auto current = listOfTensors.at(i); - auto currentOut = listOfOutTensors.at(i); - auto currentGradOut = listOfGradOuts.at(classNum); - auto currentFFOut = listOfBPTensors.at(classNum); - - currentOut->assign((*currentFFOut) * (*currentGradOut) / (*current)); - } - //}; - - //samediff::Threads::parallel_for(func, 0, indices->lengthOf()); - } - - return ND4J_STATUS_OK; - } - - // -------------------------------------------------------------------------------------------------------------- // - // Unsorted backpropagate segment ops - // -------------------------------------------------------------------------------------------------------------- // - - template - static int unsortedSegmentMaxFunctorBP_(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { -// int numOfClasses = gradOut->sizeAt(0); - // if input is a vector: (as if in doc sample) - auto tempRes = gradOut->dup(); - unsortedSegmentMaxFunctor(context, input, indices, numOfClasses, &tempRes); - if (input->isVector()) { - - for (Nd4jLong e = 0; e < input->lengthOf(); ++e) { - Nd4jLong classNum = indices->e(e); - if (sd::math::nd4j_abs(tempRes.e(classNum) - input->e(e)) < 1.e-5) - output->p(e, gradOut->e(classNum)); - } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); - ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - - for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { - Nd4jLong classNum = indices->e(i); - NDArray* current = listOfTensors.at(i); - NDArray* currentOut = listOfOutTensors.at(i); - NDArray* currentGradOut = listOfGradOuts.at(classNum); - for (int e = 0; e < current->lengthOf(); e++) { - if (sd::math::nd4j_abs(listOfBPTensors.at(classNum)->e(e) - current->e(e)) < 1.e-5) - currentOut->p(e, currentGradOut->e(e)); - } - } - } - - return ND4J_STATUS_OK; - } - - int unsortedSegmentMaxFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - BUILD_SINGLE_SELECTOR(output->dataType(), return unsortedSegmentMaxFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), NUMERIC_TYPES); - } - BUILD_SINGLE_TEMPLATE(template int unsortedSegmentMaxFunctorBP_, (sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES); - - template - static int unsortedSegmentMinFunctorBP_(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - auto tempRes = gradOut->dup(); - unsortedSegmentMinFunctor(context, input, indices, numOfClasses, &tempRes); - if (input->isVector()) { - - auto func = 
PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto classNum = indices->e(e); - if (sd::math::nd4j_abs(tempRes.t(classNum) - input->t(e)) < 1.e-6) - output->r(e) = gradOut->t(classNum); - } - }; - - samediff::Threads::parallel_for(func, 0, input->lengthOf()); - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); - ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - - //auto func = PRAGMA_THREADS_FOR { - for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { - auto classNum = indices->e(i); - auto current = listOfTensors.at(i); - auto currentOut = listOfOutTensors.at(i); - auto currentGradOut = listOfGradOuts.at(classNum); - - for (Nd4jLong e = 0; e < current->lengthOf(); e++) { - if (sd::math::nd4j_abs(listOfBPTensors.at(classNum)->t(e) - current->t(e)) < 1.e-6) - currentOut->r(e) = currentGradOut->t(e); + // if input is a vector: (as if in doc sample) + Nd4jLong idx = indices->e(0); + if (input->isVector() || input->isScalar()) { + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { + Nd4jLong classNum = indices->e(e); + output->p(e, gradOut->e(classNum)); } } - //}; + else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - //samediff::Threads::parallel_for(func, 0, indices->lengthOf()); - } + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - return ND4J_STATUS_OK; - } + //auto func = PRAGMA_THREADS_FOR { + for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { + auto classNum = indices->e(i); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); - int unsortedSegmentMinFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - BUILD_SINGLE_SELECTOR(output->dataType(), return unsortedSegmentMinFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), NUMERIC_TYPES); - } - BUILD_SINGLE_TEMPLATE(template int unsortedSegmentMinFunctorBP_, (sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES); + currentOut->assign(currentGradOut); + } + //}; - int unsortedSegmentMeanFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - - MAP_IMPL classCount;//(numClasses); - - for (Nd4jLong count = 0; count < numOfClasses; ++count) { - classCount[count] = 0; - } - - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { - classCount[indices->e(e)]++; - } - - // if input is a vector: (as if in doc sample) - if (input->isVector()) { - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { - Nd4jLong classNum = indices->e(e); - output->p(e, gradOut->e(classNum) / classCount[classNum]); - } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - - for (Nd4jLong i = 0; i 
< indices->lengthOf(); i++) { - Nd4jLong classNum = indices->e(i); - NDArray* current = listOfTensors.at(i); - NDArray* currentOut = listOfOutTensors.at(i); - NDArray* currentGradOut = listOfGradOuts.at(classNum); - currentOut->assign(*currentGradOut / double(classCount[classNum])); - } - } - return ND4J_STATUS_OK; - } - - int unsortedSegmentSumFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - - // if input is a vector: (as if in doc sample) - Nd4jLong idx = indices->e(0); - if (input->isVector()) { - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { - Nd4jLong classNum = indices->e(e); - output->p(e, gradOut->e(classNum)); - } - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - - ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - - //auto func = PRAGMA_THREADS_FOR { - for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { - auto classNum = indices->e(i); - auto currentOut = listOfOutTensors.at(i); - auto currentGradOut = listOfGradOuts.at(classNum); - - currentOut->assign(currentGradOut); + //samediff::Threads::parallel_for(func, 0, indices->lengthOf()); } - //}; + return Status::OK(); + } - //samediff::Threads::parallel_for(func, 0, indices->lengthOf()); - } - return Status::OK(); - } - - int unsortedSegmentProdFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - - auto tempRes = gradOut->dup(); - unsortedSegmentProdFunctor(context, input, indices, numOfClasses, &tempRes); - if (input->isVector()) { - auto func = PRAGMA_THREADS_FOR { - for (auto e = start; e < stop; e++) { - auto classNum = indices->e(e); - output->p(e, gradOut->e(classNum) * tempRes.e(classNum) / input->e(e)); + int segmentProdFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, NDArray* output) { + auto tempRes = gradOut->dup(); + segmentProdFunctor(context, input, indices, &tempRes); + if (input->isVector() || input->isScalar()) { + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { + Nd4jLong classNum = indices->e(e); + output->p(e, gradOut->e(classNum) * tempRes.e(classNum)/ input->e(e)); + } } - }; + else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); - samediff::Threads::parallel_for(func, 0, indices->lengthOf()); - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); - ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); - ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); - ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + //int numOfClasses = tempRes.sizeAt(0); // number of classes + //std::vector> outputs(numOfClasses); - //auto func = PRAGMA_THREADS_FOR { - for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { - auto classNum = indices->e(i); - auto current = listOfTensors.at(i); - auto currentOut = listOfOutTensors.at(i); - auto 
currentGradOut = listOfGradOuts.at(classNum); - auto currentFFOut = listOfBPTensors.at(classNum); + //auto func = PRAGMA_THREADS_FOR { + for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { + auto classNum = indices->e(i); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); + auto currentFFOut = listOfBPTensors.at(classNum); - currentOut->assign((*currentFFOut) * (*currentGradOut) / (*current)); + currentOut->assign((*currentFFOut) * (*currentGradOut) / (*current)); + } + //}; + + //samediff::Threads::parallel_for(func, 0, indices->lengthOf()); } - //}; - //samediff::Threads::parallel_for(func, 0, indices->lengthOf()); - } + return ND4J_STATUS_OK; + } - return Status::OK(); - } + // -------------------------------------------------------------------------------------------------------------- // + // Unsorted backpropagate segment ops + // -------------------------------------------------------------------------------------------------------------- // + + template + static int unsortedSegmentMaxFunctorBP_(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { +// int numOfClasses = gradOut->sizeAt(0); + // if input is a vector: (as if in doc sample) + auto tempRes = gradOut->dup(); + unsortedSegmentMaxFunctor(context, input, indices, numOfClasses, &tempRes); + if (input->isVector() || input->isScalar()) { + + for (Nd4jLong e = 0; e < input->lengthOf(); ++e) { + Nd4jLong classNum = indices->e(e); + if (sd::math::nd4j_abs(tempRes.e(classNum) - input->e(e)) < 1.e-5) + output->p(e, gradOut->e(classNum)); + } + } + else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + + for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { + Nd4jLong classNum = indices->e(i); + NDArray* current = listOfTensors.at(i); + NDArray* currentOut = listOfOutTensors.at(i); + NDArray* currentGradOut = listOfGradOuts.at(classNum); + for (int e = 0; e < current->lengthOf(); e++) { + if (sd::math::nd4j_abs(listOfBPTensors.at(classNum)->e(e) - current->e(e)) < 1.e-5) + currentOut->p(e, currentGradOut->e(e)); + } + } + } + + return ND4J_STATUS_OK; + } + + int unsortedSegmentMaxFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { + BUILD_SINGLE_SELECTOR(output->dataType(), return unsortedSegmentMaxFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), NUMERIC_TYPES); + } + BUILD_SINGLE_TEMPLATE(template int unsortedSegmentMaxFunctorBP_, (sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES); + + template + static int unsortedSegmentMinFunctorBP_(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { + auto tempRes = gradOut->dup(); + unsortedSegmentMinFunctor(context, input, indices, numOfClasses, &tempRes); + if (input->isVector() || input->isScalar()) { + + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto classNum = indices->e(e); + if (sd::math::nd4j_abs(tempRes.t(classNum) - input->t(e)) < 
1.e-6) + output->r(e) = gradOut->t(classNum); + } + }; + + samediff::Threads::parallel_for(func, 0, input->lengthOf()); + } + else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + + //auto func = PRAGMA_THREADS_FOR { + for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { + auto classNum = indices->e(i); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); + + for (Nd4jLong e = 0; e < current->lengthOf(); e++) { + if (sd::math::nd4j_abs(listOfBPTensors.at(classNum)->t(e) - current->t(e)) < 1.e-6) + currentOut->r(e) = currentGradOut->t(e); + } + } + //}; + + //samediff::Threads::parallel_for(func, 0, indices->lengthOf()); + } + + return ND4J_STATUS_OK; + } + + int unsortedSegmentMinFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { + BUILD_SINGLE_SELECTOR(output->dataType(), return unsortedSegmentMinFunctorBP_, (context, input, indices, gradOut, numOfClasses, output), NUMERIC_TYPES); + } + BUILD_SINGLE_TEMPLATE(template int unsortedSegmentMinFunctorBP_, (sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output), NUMERIC_TYPES); + + int unsortedSegmentMeanFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { + + MAP_IMPL classCount;//(numClasses); + + for (Nd4jLong count = 0; count < numOfClasses; ++count) { + classCount[count] = 0; + } + + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { + classCount[indices->e(e)]++; + } + + // if input is a vector: (as if in doc sample) + if (input->isVector() || input->isScalar()) { + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { + Nd4jLong classNum = indices->e(e); + output->p(e, gradOut->e(classNum) / classCount[classNum]); + } + } + else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + + for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { + Nd4jLong classNum = indices->e(i); + NDArray* current = listOfTensors.at(i); + NDArray* currentOut = listOfOutTensors.at(i); + NDArray* currentGradOut = listOfGradOuts.at(classNum); + currentOut->assign(*currentGradOut / double(classCount[classNum])); + } + } + return ND4J_STATUS_OK; + } + + int unsortedSegmentSumFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { + + // if input is a vector: (as if in doc sample) + Nd4jLong idx = indices->e(0); + if (input->isVector() || input->isScalar()) { + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { + Nd4jLong classNum = indices->e(e); + output->p(e, gradOut->e(classNum)); + } + } + else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet 
listOfOutTensors = output->allTensorsAlongDimension(restDims); + + //auto func = PRAGMA_THREADS_FOR { + for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { + auto classNum = indices->e(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); + + currentOut->assign(currentGradOut); + } + //}; + + //samediff::Threads::parallel_for(func, 0, indices->lengthOf()); + } + return Status::OK(); + } + + int unsortedSegmentProdFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { + + auto tempRes = gradOut->dup(); + unsortedSegmentProdFunctor(context, input, indices, numOfClasses, &tempRes); + if (input->isVector() || input->isScalar()) { + auto func = PRAGMA_THREADS_FOR { + for (auto e = start; e < stop; e++) { + auto classNum = indices->e(e); + output->p(e, gradOut->e(classNum) * tempRes.e(classNum) / input->e(e)); + } + }; + + samediff::Threads::parallel_for(func, 0, indices->lengthOf()); + } + else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + ResultSet listOfBPTensors = tempRes.allTensorsAlongDimension(restDims); + ResultSet listOfGradOuts = gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors = input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors = output->allTensorsAlongDimension(restDims); + + //auto func = PRAGMA_THREADS_FOR { + for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { + auto classNum = indices->e(i); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); + auto currentFFOut = listOfBPTensors.at(classNum); + + currentOut->assign((*currentFFOut) * (*currentGradOut) / (*current)); + } + //}; + + //samediff::Threads::parallel_for(func, 0, indices->lengthOf()); + } + + return Status::OK(); + } // template - int unsortedSegmentSqrtNFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { - MAP_IMPL classCount;//(numClasses); + int unsortedSegmentSqrtNFunctorBP(sd::LaunchContext * context, NDArray* input, NDArray* indices, NDArray* gradOut, Nd4jLong numOfClasses, NDArray* output) { + MAP_IMPL classCount;//(numClasses); - for (Nd4jLong count = 0; count < numOfClasses; ++count) { - classCount[count] = 0; - } - - for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { - classCount[indices->e(e)]++; - } - - // if input is a vector: (as if in doc sample) - if (input->isVector()) { - //auto func = PRAGMA_THREADS_FOR { - for (Nd4jLong e = 0; e < indices->lengthOf(); e++) { - auto classNum = indices->e(e); - output->p(e, gradOut->e(classNum) / sd::math::nd4j_sqrt(classCount[classNum])); + for (Nd4jLong count = 0; count < numOfClasses; ++count) { + classCount[count] = 0; } - //}; - //samediff::Threads::parallel_for(func, 0, indices->lengthOf()); - } - else { - auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + for (Nd4jLong e = 0; e < indices->lengthOf(); ++e) { + classCount[indices->e(e)]++; + } - ResultSet listOfGradOuts =gradOut->allTensorsAlongDimension(restDims); - ResultSet listOfTensors =input->allTensorsAlongDimension(restDims); - ResultSet listOfOutTensors =output->allTensorsAlongDimension(restDims); - - //auto func = PRAGMA_THREADS_FOR { - for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { - auto classNum = indices->e(i); - auto current = listOfTensors.at(i); - auto currentOut = listOfOutTensors.at(i); - auto currentGradOut 
= listOfGradOuts.at(classNum); - - for (int e = 0; e < current->lengthOf(); e++) { - currentOut->p(e, currentGradOut->e(e) / sd::math::nd4j_sqrt(classCount[classNum])); + // if input is a vector: (as if in doc sample) + if (input->isVector() || input->isScalar()) { + //auto func = PRAGMA_THREADS_FOR { + for (Nd4jLong e = 0; e < indices->lengthOf(); e++) { + auto classNum = indices->e(e); + output->p(e, gradOut->e(classNum) / sd::math::nd4j_sqrt(classCount[classNum])); } + //}; + + //samediff::Threads::parallel_for(func, 0, indices->lengthOf()); } - //}; + else { + auto restDims = ShapeUtils::evalDimsToExclude(input->rankOf(), {0}); + + ResultSet listOfGradOuts =gradOut->allTensorsAlongDimension(restDims); + ResultSet listOfTensors =input->allTensorsAlongDimension(restDims); + ResultSet listOfOutTensors =output->allTensorsAlongDimension(restDims); + + //auto func = PRAGMA_THREADS_FOR { + for (Nd4jLong i = 0; i < indices->lengthOf(); i++) { + auto classNum = indices->e(i); + auto current = listOfTensors.at(i); + auto currentOut = listOfOutTensors.at(i); + auto currentGradOut = listOfGradOuts.at(classNum); + + for (int e = 0; e < current->lengthOf(); e++) { + currentOut->p(e, currentGradOut->e(e) / sd::math::nd4j_sqrt(classCount[classNum])); + } + } + //}; + + //samediff::Threads::parallel_for(func, 0, indices->lengthOf()); + } + return Status::OK(); + } - //samediff::Threads::parallel_for(func, 0, indices->lengthOf()); } - return Status::OK(); } - -} -} } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java index 930947772..c3e1f8be3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/internal/InferenceSession.java @@ -672,6 +672,7 @@ public class InferenceSession extends AbstractSession l = tensorArrays.get(tArr); Preconditions.checkState(l != null, "Could not find TensorArray: %s", tArr); @@ -703,6 +704,14 @@ public class InferenceSession extends AbstractSession 0) { get = get.reshape(); } + + //reflect the expanded storage + if(outIdx >= l.size()) { + while(l.size() <= outIdx) { + l.add(null); + } + } + + l.set(outIdx, get); //Add dependency for values array until end of execution
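A note on the InferenceSession hunk above: a TensorArray write may target an index at or past the end of the backing list, so the list is padded with null placeholders before the set(...) call. Below is a minimal standalone sketch of that grow-on-write pattern; setGrowing is a hypothetical helper, not part of the patch. java.util.List.set requires index < size(), which is why the padding loop must run through the target index itself.

import java.util.ArrayList;
import java.util.List;

public class GrowOnWriteExample {
    // Pad with nulls until set(idx, ...) is in range; List.set throws
    // IndexOutOfBoundsException whenever idx >= size().
    static <T> void setGrowing(List<T> list, int idx, T value) {
        while (list.size() <= idx)
            list.add(null);
        list.set(idx, value);
    }

    public static void main(String[] args) {
        List<String> l = new ArrayList<>();
        setGrowing(l, 3, "x");
        System.out.println(l); // [null, null, null, x]
    }
}

A null here stands in for a TensorArray slot that has not been written yet, matching the placeholders the hunk above adds.

diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java index 89b3b505b..4b4b71d49 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/Concat.java @@ -146,6 +146,10 @@ public class Concat extends DynamicCustomOp { @Override public List calculateOutputDataTypes(List dataTypes){ + if(!dArguments.isEmpty()) { + return Collections.singletonList(dArguments.get(0)); + } + DataType first = dataTypes.get(0); for( int i = 1; i < dataTypes.size() - (isDynamicAxis ? 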
1 : 0); i++) { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/BinCount.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/BinCount.java index 1632c15fb..b48bf826e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/BinCount.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/BinCount.java @@ -89,7 +89,7 @@ public class BinCount extends DynamicCustomOp { inputTypes, getClass()); //If weights present, same type as weights. Otherwise specified dtype - if(inputTypes.size() == 2 || inputTypes.size() == 4) { + if(inputTypes.size() >= 2) { //weights available case or TF import case (args 2/3 are min/max) return Collections.singletonList(inputTypes.get(1)); } else { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Fill.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Fill.java index 224d9765b..8550fae38 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Fill.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Fill.java @@ -138,7 +138,10 @@ public class Fill extends DynamicCustomOp { } @Override - public List calculateOutputDataTypes(List dataTypes){ + public List calculateOutputDataTypes(List dataTypes) { + if(!dArguments.isEmpty()) { + return Collections.singletonList(dArguments.get(0)); + } //1 or 2 possible: 2 for TF import (fill with specified value Preconditions.checkState(dataTypes != null && (dataTypes.size() == 1 || dataTypes.size() == 2), "Expected 1 or 2 input datatypes for %s, got %s", getClass(), dataTypes); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Identity.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Identity.java index 1b3af85a4..f5249b4bf 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Identity.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/same/Identity.java @@ -27,6 +27,7 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; +import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -73,8 +74,10 @@ public class Identity extends BaseDynamicTransformOp { } @Override - public List calculateOutputDataTypes(List dataTypes){ + public List calculateOutputDataTypes(List dataTypes) { Preconditions.checkState(dataTypes != null && dataTypes.size() == 1, "Expected exactly 1 input datatype for %s, got input %s", getClass(), dataTypes); + if(!dArguments.isEmpty()) + return Arrays.asList(dArguments.get(0)); return dataTypes; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java index ef34fa1bc..777822981 100644 --- 
a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMax.java @@ -65,6 +65,9 @@ public class UnsortedSegmentMax extends DynamicCustomOp { @Override public List calculateOutputDataTypes(List inputDataTypes){ + if(!dArguments.isEmpty()) { + return Collections.singletonList(dArguments.get(0)); + } Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3), "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); return Collections.singletonList(inputDataTypes.get(0)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java index 5bdada4e5..1f3a206d9 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMean.java @@ -62,6 +62,9 @@ public class UnsortedSegmentMean extends DynamicCustomOp { @Override public List calculateOutputDataTypes(List inputDataTypes){ + if(!dArguments.isEmpty()) { + return Collections.singletonList(dArguments.get(0)); + } Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3), "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); return Collections.singletonList(inputDataTypes.get(0)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java index 418095a01..aeb74ae15 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentMin.java @@ -66,6 +66,9 @@ public class UnsortedSegmentMin extends DynamicCustomOp { @Override public List calculateOutputDataTypes(List inputDataTypes){ + if(!dArguments.isEmpty()) { + return Collections.singletonList(dArguments.get(0)); + } Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3), "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); return Collections.singletonList(inputDataTypes.get(0)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java index a47dd7726..297f61afb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentProd.java @@ -66,8 +66,11 @@ public class UnsortedSegmentProd extends DynamicCustomOp { 
@Override public List calculateOutputDataTypes(List inputDataTypes){ + if(!dArguments.isEmpty()) { + return Collections.singletonList(dArguments.get(0)); + } Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3), - "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); + "Expected at least 2 input data types for %s, got %s", getClass(), inputDataTypes); return Collections.singletonList(inputDataTypes.get(0)); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSqrtN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSqrtN.java index e9fc3c77d..b0738e1f4 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSqrtN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/segment/UnsortedSegmentSqrtN.java @@ -30,6 +30,7 @@ import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.impl.transforms.segment.bp.UnsortedSegmentSqrtNBp; import java.util.ArrayList; +import java.util.Collections; import java.util.List; @NoArgsConstructor @@ -61,10 +62,14 @@ public class UnsortedSegmentSqrtN extends DynamicCustomOp { @Override public List calculateOutputDataTypes(List inputDataTypes){ + if(!dArguments.isEmpty()) { + return Collections.singletonList(dArguments.get(0)); + } + Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3), "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); List out = new ArrayList<>(); - for( int i=0; i calculateOutputDataTypes(List inputDataTypes){ + public List calculateOutputDataTypes(List inputDataTypes) { + if(!dArguments.isEmpty()) { + return Collections.singletonList(dArguments.get(0)); + } Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 2 || inputDataTypes.size() == 3), "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); //TODO Allow customizing output type diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/RandomFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/RandomFactory.java index 74215b76b..93911850d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/RandomFactory.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/RandomFactory.java @@ -87,7 +87,8 @@ public class RandomFactory { } /** - * This method returns new onject implementing Random interface, initialized with seed value, with size of elements in buffer + * This method returns a new object implementing {@link Random} + * interface, initialized with seed value, with size of elements in buffer * * @param seed rng seed * @param size size of underlying buffer
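The datatype overrides added above to Concat, Fill, Identity, BinCount and the unsorted-segment ops all follow one pattern: when an explicit output datatype was supplied through the op's d-arguments (as the TF import path does), that type wins over inference from the input datatypes. A self-contained sketch of the pattern follows; the DataType enum and method here are simplified stand-ins mirroring the hunks, not the real nd4j classes.

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public class DArgOverrideExample {
    // Simplified stand-in for org.nd4j.linalg.api.buffer.DataType
    enum DataType { FLOAT, DOUBLE, INT }

    // Datatype d-arguments explicitly attached to the op (empty = none given)
    static List<DataType> dArguments = new ArrayList<>();

    // Mirrors the calculateOutputDataTypes overrides above: an explicit
    // d-argument datatype takes precedence; otherwise fall back to the
    // op's usual rule, e.g. same type as the first input.
    static List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
        if (!dArguments.isEmpty())
            return Collections.singletonList(dArguments.get(0));
        return Collections.singletonList(inputDataTypes.get(0));
    }

    public static void main(String[] args) {
        System.out.println(calculateOutputDataTypes(List.of(DataType.INT)));  // [INT]
        dArguments.add(DataType.DOUBLE);
        System.out.println(calculateOutputDataTypes(List.of(DataType.INT)));  // [DOUBLE]
    }
}

diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java index 3b512a4f0..06817a6c9 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java +++ 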
b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java @@ -69,28 +69,40 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a * the status of the test failing. No tests will run. */ public final static List EXECUTE_ONLY_MODELS = Arrays.asList( - "conv2d_transpose/channels_last_b1_k2_s2_SAME", - "conv2d_transpose/channels_last_b1_k2_s1_SAME", - "bincount/rank1", - "bincount/rank1_weights", - "bincount/rank1_max5", - "emptyArrayTests/zeros/ones_rank3", - "conv2d_transpose/channels_last_b2_k2_s1_SAME_nobias", - "emptyArrayTests/identity_n/rank3.", - "emptyReduceAxisTests/reduce_sum/rank1", - "emptyReduceAxisTests/reduce_sum/rank1_keep", - "emptyReduceAxisTests/reduce_sum/rank3", - "emptyReduceAxisTests/reduce_any/rank2", - "embedding_lookup/rank2_multiple_div_nomaxnorm", - "emptyReduceAxisTests/reduce_all/rank2_keep", - "conv2d_transpose/channels_first_b1_k2_s1_SAME_sigmoid", - "conv2d_transpose/channels_first_b1_k2_s1_SAME_elu", - "emptyReduceAxisTests/reduce_prod/rank1", - "conv2d_transpose/channels_first_b2_k2_s1_SAME_nobias", - "conv2d_transpose/channels_last_b2_k2_s1_SAME_regularizers", - "conv2d_transpose/channels_last_b1_k2_s1_SAME_elu", - "conv2d_transpose/channels_first_b1_k2_s1_SAME_selu_nobias", - "embedding_lookup/rank2_multiple_mod_maxnorm1" + /*"layers_dropout/rank2_d01_train", + "layers_dropout/rank4_d05_train", + "layers_dropout/rank3_d05_train_mask2", + "layers_dropout/rank4_d05_train_mask", + "layers_dropout/rank3_d05_train_mask1", + "layers_dropout/rank2_d09_train", + "layers_dropout/rank2_d05_train",*/ + /* "primitive_gru_dynamic", + "layers_dropout/rank4_d05_train", + "fused_batch_norm/float16_nhwc", + "rnn/lstmblockcell/dynamic_b1_n5-3_ts4_noPH_noClip_fB1_noIS_withTM", + "rnn/lstmcell/dynamic_b1_nIn5_nOut3_ts4_noPH_noClip_fB1_Tanh_noIS_float_withTM", + "rnn/grublockcellv2/dynamic_b1_n3-2_ts1_noIS_noTM"*/ + /* "unsorted_segment/unsorted_segment_mean_rank3", + "unsorted_segment/unsorted_segment_sqrt_n_rank2", + "unsorted_segment/unsorted_segment_mean_rank2", + "unsorted_segment/unsorted_segment_mean_rank3", + "unsorted_segment/unsorted_segment_sum_rank3", + "unsorted_segment/unsorted_segment_min_rank2", + "unsorted_segment/unsorted_segment_prod_rank2", + "unsorted_segment/unsorted_segment_max_rank2",*/ + "bincount/rank0_weights", + "bincount/rank2_weights" + /* "compare_and_bitpack/bool", + "compare_and_bitpack/float32", + "compare_and_bitpack/float64", + "compare_and_bitpack/half", + "compare_and_bitpack/int32", + "compare_and_bitpack/int8", + "compare_and_bitpack/int64", + "compare_and_bitpack/int16"*/ + + + ); public static final String[] IGNORE_REGEXES = new String[]{ @@ -98,7 +110,12 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a // Still failing 2020/04/27 java.lang.IllegalStateException: Requested output variable Bincount does not exist in SameDiff instance //Invalid test cases. Verified by running graph against actual TF. "slogdet/.*", - + //IGNORE THIS: the stored TF results here differ from those produced by an actual TF java run. + "fused_batch_norm/float16_nhwc", + //Don't bother to test RNG. We can test subsets of ops with dropout to make sure they are consistent + //These tests have random uniform and other RNG in them that don't need to be perfectly compatible to be acceptable. + //We need different test cases here. + "layers_dropout/.*", //TODO floormod and truncatemod behave differently - i.e., "c" vs. "python" semantics. 
Need to check implementations too // Still failing 2020/04/27 java.lang.IllegalStateException: Could not find class for TF Ops: TruncateMod "truncatemod/.*", @@ -109,15 +126,11 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a //2019/09/11 - Couple of tests failing (InferenceSession issues) // Still failing 2020/04/27 Requested output variable concat does not exist in SameDiff instance - "rnn/bstack/d_.*", - //2019/05/21 - Failing on AVX2/512 intermittently (Linux, OSX), passing elsewhere - //"unsorted_segment/.*", //2019/05/21 - Failing on windows-x86_64-cuda-9.2 only - "conv_4", "g_09", - //"unsorted_segment/unsorted_segment_mean_rank2", //2019/05/28 - JVM crash on ppc64le only - See issue 7657 "g_11", @@ -130,13 +143,10 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a // Still failing 2020/04/27 java.lang.IllegalStateException: Could not find descriptor for op: deconv3d_tf - class: org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv3DTF "conv3d_transpose.*", - //2019/11/15 - mapping is not present yet https://github.com/eclipse/deeplearning4j/issues/8397 + //2019/11/15 - mapping is not present yet https://github.com/eclipse/deeplearning4j/issues/8397 // Still failing 2020/04/27 java.lang.AssertionError: Predictions do not match on ragged/reduce_mean/2d_a1, node RaggedReduceMean/truediv "ragged/reduce_mean/.*", - // 01.05.2020 - https://github.com/eclipse/deeplearning4j/issues/8898 - "primitive_gru", - //08.05.2020 - https://github.com/eclipse/deeplearning4j/issues/8927 "random_gamma/.*", @@ -144,15 +154,14 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a //08.05.2020 - https://github.com/eclipse/deeplearning4j/issues/8928 "Conv3DBackpropInputV2/.*", - //12.05.2020 - https://github.com/eclipse/deeplearning4j/issues/8940 - "compare_and_bitpack/.*", //12.05.2020 - https://github.com/eclipse/deeplearning4j/issues/8946 "non_max_suppression_v4/.*","non_max_suppression_v5/.*", - // 18.05.2020 - https://github.com/eclipse/deeplearning4j/issues/8963 + // 18.05.2020 - https://github.com/eclipse/deeplearning4j/issues/8963 + + "random_uniform_int/.*", "random_uniform/.*", "random_poisson_v2/.*" @@ -163,10 +172,11 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a If a test name matches any regex here, an ExecPrintListener will be added to the listeners, and all output arrays will be printed during execution */ - private final List debugModeRegexes = null; //Arrays.asList("resize_nearest_neighbor/.*", "add_n.*"); + private final List debugModeRegexes = Arrays.asList("fused_batch_norm/float16_nhwc"); @BeforeClass public static void beforeClass() { + Nd4j.scalar(1.0); Nd4j.setDataType(DataType.FLOAT); Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.SCOPE_PANIC); } diff --git a/nd4j/nd4j-backends/nd4j-tests/variables-added-new.txt b/nd4j/nd4j-backends/nd4j-tests/variables-added-new.txt index 80e6a0487..6acd68d2c 100644 --- a/nd4j/nd4j-backends/nd4j-tests/variables-added-new.txt +++ b/nd4j/nd4j-backends/nd4j-tests/variables-added-new.txt @@ -1,3 +1,3 @@ in_0/read,in_0/read -MaxPoolWithArgmax,MaxPoolWithArgmax -MaxPoolWithArgmax:1,MaxPoolWithArgmax +in_1/read,in_1/read +UnsortedSegmentSum,UnsortedSegmentSum
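Context for the one-line IRProtobufExtensions.kt change below: raw TF tensor bytes are laid out row-major, while reshape without an explicit order follows the array's current ordering, which is not guaranteed to be 'c'. A small illustration of why the two interpretations differ, assuming the standard Nd4j factory API and the reshape(char, long...) overload:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class ReshapeOrderExample {
    public static void main(String[] args) {
        INDArray flat = Nd4j.createFromArray(1f, 2f, 3f, 4f, 5f, 6f);
        // 'c' reads the buffer row-major: [[1, 2, 3], [4, 5, 6]]
        INDArray c = flat.reshape('c', 2, 3);
        // 'f' reads the same buffer column-major: [[1, 3, 5], [2, 4, 6]]
        INDArray f = flat.reshape('f', 2, 3);
        System.out.println(c);
        System.out.println(f);
    }
}

diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/IRProtobufExtensions.kt b/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/IRProtobufExtensions.kt index dc2489d6b..e9a19199c 100644 --- 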
a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/IRProtobufExtensions.kt +++ b/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/IRProtobufExtensions.kt @@ -449,7 +449,7 @@ fun loadDataBufferFromRawData(inputTensor: TensorNamespace.TensorProto): INDArra val rawDataBuffer = Nd4j.createBuffer(byteBuffer, dtype, totalLen, 0) if(shape.isNotEmpty() && totalLen > 0) { if(rawDataBuffer.length() > 1) - return Nd4j.create(rawDataBuffer).reshape(*shape) + return Nd4j.create(rawDataBuffer).reshape('c',*shape) return Nd4j.empty(dtype) } return Nd4j.create(rawDataBuffer) diff --git a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ImportGraph.kt b/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ImportGraph.kt index 945c0aea8..8e7eb2e74 100644 --- a/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ImportGraph.kt +++ b/nd4j/samediff-import/samediff-import-api/src/main/kotlin/org/nd4j/samediff/frameworkimport/ImportGraph.kt @@ -443,6 +443,7 @@ open class ImportGraph input.argType == OpNamespace.ArgDescriptor.ArgType.INPUT_TENSOR} .sortedBy { argDescriptor -> argDescriptor.argIndex } + val numInputsToTake = resolvedArgInputs.size if(numInputsToTake != inNames.size) { @@ -496,17 +497,6 @@ open class ImportGraph input.nodeName() } val skipValidation = setOf("parallel_stack/ExpandDims/dim") //assertEquals(output.keys,output2.keys) - /* val notEquals = HashSet() + val notEquals = HashSet() + val notEqualsTf = HashSet() names.forEach { val value = output[it] val value2 = output2[it] + val tfValue = tfOutput[it] if(value!! != (value2!!)) { val oldOps = importedGraph.ops[it] val newOps = graph.ops[it] @@ -128,10 +138,19 @@ class TestTensorflowIR { val newVar = graph.variables[it] notEquals.add(it) } - }*/ - //println(notEquals) + if(tfValue!! 
!= (value!!)) { + val oldOps = importedGraph.ops[it] + val newOps = graph.ops[it] + val oldVar = importedGraph.variables[it] + val newVar = graph.variables[it] + notEqualsTf.add(it) + } + } + println(notEquals) + println(notEqualsTf) + println() // assertEquals(output,output2) //assertEquals(tfOutput,output) } diff --git a/nd4j/samediff-import/samediff-import-tensorflow/tensorflow-processes.pbtxt b/nd4j/samediff-import/samediff-import-tensorflow/tensorflow-processes.pbtxt index 05d47dcf0..e6cf75dee 100644 --- a/nd4j/samediff-import/samediff-import-tensorflow/tensorflow-processes.pbtxt +++ b/nd4j/samediff-import/samediff-import-tensorflow/tensorflow-processes.pbtxt @@ -1250,6 +1250,47 @@ mappings { ruleType: "tensor" inputFrameworkOpName: "StatelessRandomUniform" } + rule { + ruleName: "argdescriptorconstant" + functionName: "argdescriptorconstant" + inputFloatName: "max" + ruleType: "attribute" + transformerArgs { + key: "value" + transformerArgs { + name: "max" + doubleValue: 1.0 + argType: DOUBLE + argIndex: 1 + } + } + inputFrameworkOpName: "StatelessRandomUniform" + } + rule { + ruleName: "argdescriptorconstant" + functionName: "argdescriptorconstant" + inputFloatName: "min" + ruleType: "attribute" + transformerArgs { + key: "value" + transformerArgs { + name: "min" + argType: DOUBLE + } + } + inputFrameworkOpName: "StatelessRandomUniform" + } + rule { + ruleName: "ndarraytointattributevalue" + functionName: "ndarraytointattributevalue" + outputIntName: "seed" + inputToOutput { + key: "seed" + value: "seed" + } + ruleType: "attribute" + inputFrameworkOpName: "StatelessRandomUniform" + } rule { ruleName: "datatypetoint" functionName: "datatypetoint" @@ -1274,140 +1315,6 @@ mappings { ruleType: "attribute" inputFrameworkOpName: "StatelessRandomUniform" } - rule { - ruleName: "argdescriptorconstant" - functionName: "argdescriptorconstant" - inputFloatName: "min" - inputFloatName: "max" - inputTensorName: "min" - inputTensorName: "max" - ruleType: "attribute" - transformerArgs { - key: "value" - transformerArgs { - name: "min" - argType: DOUBLE - } - transformerArgs { - name: "max" - doubleValue: 1.0 - argType: DOUBLE - argIndex: 1 - } - transformerArgs { - name: "min" - inputValue { - data_type: 11 - double_data: 1.0 - } - argType: INPUT_TENSOR - argIndex: 1 - } - transformerArgs { - name: "max" - inputValue { - data_type: 11 - double_data: 1.0 - } - argType: INPUT_TENSOR - argIndex: 2 - } - } - transformerArgs { - key: "value" - transformerArgs { - name: "min" - argType: DOUBLE - } - transformerArgs { - name: "max" - doubleValue: 1.0 - argType: DOUBLE - argIndex: 1 - } - transformerArgs { - name: "min" - inputValue { - data_type: 11 - double_data: 1.0 - } - argType: INPUT_TENSOR - argIndex: 1 - } - transformerArgs { - name: "max" - inputValue { - data_type: 11 - double_data: 1.0 - } - argType: INPUT_TENSOR - argIndex: 2 - } - } - transformerArgs { - key: "value" - transformerArgs { - name: "min" - argType: DOUBLE - } - transformerArgs { - name: "max" - doubleValue: 1.0 - argType: DOUBLE - argIndex: 1 - } - transformerArgs { - name: "min" - inputValue { - data_type: 11 - double_data: 1.0 - } - argType: INPUT_TENSOR - argIndex: 1 - } - transformerArgs { - name: "max" - inputValue { - data_type: 11 - double_data: 1.0 - } - argType: INPUT_TENSOR - argIndex: 2 - } - } - transformerArgs { - key: "value" - transformerArgs { - name: "min" - argType: DOUBLE - } - transformerArgs { - name: "max" - doubleValue: 1.0 - argType: DOUBLE - argIndex: 1 - } - transformerArgs { - name: "min" - inputValue { 
- data_type: 11 - double_data: 1.0 - } - argType: INPUT_TENSOR - argIndex: 1 - } - transformerArgs { - name: "max" - inputValue { - data_type: 11 - double_data: 1.0 - } - argType: INPUT_TENSOR - argIndex: 2 - } - } - inputFrameworkOpName: "StatelessRandomUniform" - } } mappings { frameworkName: "tensorflow" @@ -1721,16 +1628,6 @@ mappings { ruleType: "tensor" inputFrameworkOpName: "Squeeze" } - rule { - ruleName: "listnumbertondarray" - functionName: "listnumbertondarray" - inputToOutput { - key: "a" - value: "squeeze_dims" - } - ruleType: "attribute" - inputFrameworkOpName: "Squeeze" - } rule { ruleName: "listnumbertolistnumber" functionName: "listnumbertolistnumber" @@ -1787,8 +1684,10 @@ mappings { functionName: "ndarraymapping" inputTensorName: "data" inputTensorName: "segment_ids" + inputTensorName: "num_segments" outputTensorName: "input" outputTensorName: "idxSegments" + outputTensorName: "numSegments" inputToOutput { key: "input" value: "data" @@ -1797,13 +1696,16 @@ mappings { key: "idxSegments" value: "segment_ids" } + inputToOutput { + key: "numSegments" + value: "num_segments" + } ruleType: "tensor" inputFrameworkOpName: "UnsortedSegmentProd" } rule { ruleName: "ndarrayinputtonumericalattribute" functionName: "ndarrayinputtonumericalattribute" - outputIntName: "numSegments" inputToOutput { key: "numSegments" value: "num_segments" @@ -6547,6 +6449,36 @@ mappings { ruleType: "tensor" inputFrameworkOpName: "RandomUniform" } + rule { + ruleName: "argdescriptorconstant" + functionName: "argdescriptorconstant" + inputFloatName: "max" + ruleType: "attribute" + transformerArgs { + key: "value" + transformerArgs { + name: "max" + doubleValue: 1.0 + argType: DOUBLE + argIndex: 1 + } + } + inputFrameworkOpName: "RandomUniform" + } + rule { + ruleName: "argdescriptorconstant" + functionName: "argdescriptorconstant" + inputFloatName: "min" + ruleType: "attribute" + transformerArgs { + key: "value" + transformerArgs { + name: "min" + argType: DOUBLE + } + } + inputFrameworkOpName: "RandomUniform" + } rule { ruleName: "datatypetoint" functionName: "datatypetoint" @@ -6562,149 +6494,21 @@ mappings { rule { ruleName: "valuemapping" functionName: "valuemapping" + inputIntName: "seed" + outputIntName: "seed" inputDataTypeName: "dtype" outputDataTypeName: "dataType" inputToOutput { key: "dataType" value: "dtype" } + inputToOutput { + key: "seed" + value: "seed" + } ruleType: "attribute" inputFrameworkOpName: "RandomUniform" } - rule { - ruleName: "argdescriptorconstant" - functionName: "argdescriptorconstant" - inputFloatName: "min" - inputFloatName: "max" - inputTensorName: "min" - inputTensorName: "max" - ruleType: "attribute" - transformerArgs { - key: "value" - transformerArgs { - name: "min" - argType: DOUBLE - } - transformerArgs { - name: "max" - doubleValue: 1.0 - argType: DOUBLE - argIndex: 1 - } - transformerArgs { - name: "min" - inputValue { - data_type: 11 - double_data: 1.0 - } - argType: INPUT_TENSOR - argIndex: 1 - } - transformerArgs { - name: "max" - inputValue { - data_type: 11 - double_data: 1.0 - } - argType: INPUT_TENSOR - argIndex: 2 - } - } - transformerArgs { - key: "value" - transformerArgs { - name: "min" - argType: DOUBLE - } - transformerArgs { - name: "max" - doubleValue: 1.0 - argType: DOUBLE - argIndex: 1 - } - transformerArgs { - name: "min" - inputValue { - data_type: 11 - double_data: 1.0 - } - argType: INPUT_TENSOR - argIndex: 1 - } - transformerArgs { - name: "max" - inputValue { - data_type: 11 - double_data: 1.0 - } - argType: INPUT_TENSOR - argIndex: 2 - } - 
} - transformerArgs { - key: "value" - transformerArgs { - name: "min" - argType: DOUBLE - } - transformerArgs { - name: "max" - doubleValue: 1.0 - argType: DOUBLE - argIndex: 1 - } - transformerArgs { - name: "min" - inputValue { - data_type: 11 - double_data: 1.0 - } - argType: INPUT_TENSOR - argIndex: 1 - } - transformerArgs { - name: "max" - inputValue { - data_type: 11 - double_data: 1.0 - } - argType: INPUT_TENSOR - argIndex: 2 - } - } - transformerArgs { - key: "value" - transformerArgs { - name: "min" - argType: DOUBLE - } - transformerArgs { - name: "max" - doubleValue: 1.0 - argType: DOUBLE - argIndex: 1 - } - transformerArgs { - name: "min" - inputValue { - data_type: 11 - double_data: 1.0 - } - argType: INPUT_TENSOR - argIndex: 1 - } - transformerArgs { - name: "max" - inputValue { - data_type: 11 - double_data: 1.0 - } - argType: INPUT_TENSOR - argIndex: 2 - } - } - inputFrameworkOpName: "RandomUniform" - } } mappings { frameworkName: "tensorflow" @@ -6847,8 +6651,10 @@ mappings { functionName: "ndarraymapping" inputTensorName: "input" inputTensorName: "axis" + inputTensorName: "shift" outputTensorName: "input" outputTensorName: "dimensions" + outputTensorName: "shiftsI" inputToOutput { key: "input" value: "input" @@ -6857,6 +6663,10 @@ mappings { key: "dimensions" value: "axis" } + inputToOutput { + key: "shiftsI" + value: "shift" + } ruleType: "tensor" inputFrameworkOpName: "Roll" } @@ -6972,8 +6782,10 @@ mappings { functionName: "ndarraymapping" inputTensorName: "data" inputTensorName: "segment_ids" + inputTensorName: "num_segments" outputTensorName: "input" outputTensorName: "idxSegments" + outputTensorName: "numSegments" inputToOutput { key: "input" value: "data" @@ -6982,13 +6794,16 @@ mappings { key: "idxSegments" value: "segment_ids" } + inputToOutput { + key: "numSegments" + value: "num_segments" + } ruleType: "tensor" inputFrameworkOpName: "UnsortedSegmentMin" } rule { ruleName: "ndarrayinputtonumericalattribute" functionName: "ndarrayinputtonumericalattribute" - outputIntName: "numSegments" inputToOutput { key: "numSegments" value: "num_segments" @@ -7239,7 +7054,6 @@ mappings { rule { ruleName: "ndarrayinputtonumericalattribute" functionName: "ndarrayinputtonumericalattribute" - outputDoubleName: "start" outputDoubleName: "stop" inputToOutput { key: "start" @@ -7255,8 +7069,8 @@ mappings { rule { ruleName: "valuemapping" functionName: "valuemapping" + outputIntName: "dataType" inputDataTypeName: "T" - outputDataTypeName: "dataType" inputToOutput { key: "dataType" value: "T" @@ -7380,8 +7194,10 @@ mappings { functionName: "ndarraymapping" inputTensorName: "data" inputTensorName: "segment_ids" + inputTensorName: "num_segments" outputTensorName: "input" outputTensorName: "idxSegments" + outputTensorName: "numSegments" inputToOutput { key: "input" value: "data" @@ -7390,13 +7206,16 @@ mappings { key: "idxSegments" value: "segment_ids" } + inputToOutput { + key: "numSegments" + value: "num_segments" + } ruleType: "tensor" inputFrameworkOpName: "UnsortedSegmentSum" } rule { ruleName: "ndarrayinputtonumericalattribute" functionName: "ndarrayinputtonumericalattribute" - outputIntName: "numSegments" inputToOutput { key: "numSegments" value: "num_segments" @@ -9065,7 +8884,7 @@ mappings { inputTensorName: "max_output_size" outputTensorName: "boxes" outputTensorName: "scales" - outputTensorName: "iouThreshold" + outputTensorName: "overlayThreshold" outputTensorName: "maxOutputSize" inputToOutput { key: "boxes" @@ -9076,7 +8895,7 @@ mappings { value: "scores" } inputToOutput { 
- key: "iouThreshold" + key: "overlayThreshold" value: "iou_threshold" } inputToOutput { @@ -9184,8 +9003,6 @@ mappings { rule { ruleName: "ndarrayinputtonumericalattribute" functionName: "ndarrayinputtonumericalattribute" - outputDoubleName: "on" - outputDoubleName: "off" inputToOutput { key: "on" value: "on_value" @@ -9298,6 +9115,41 @@ mappings { inputFrameworkOpName: "Square" } } +mappings { + frameworkName: "tensorflow" + opName: "compare_and_bitpack" + inputFrameworkOpName: "CompareAndBitpack" + rule { + ruleName: "ndarraymapping" + functionName: "ndarraymapping" + inputTensorName: "input" + inputTensorName: "threshold" + outputTensorName: "input" + outputTensorName: "y" + inputToOutput { + key: "input" + value: "input" + } + inputToOutput { + key: "y" + value: "threshold" + } + ruleType: "tensor" + inputFrameworkOpName: "CompareAndBitpack" + } + rule { + ruleName: "valuemapping" + functionName: "valuemapping" + inputDataTypeName: "T" + outputDataTypeName: "dtype" + inputToOutput { + key: "dtype" + value: "T" + } + ruleType: "attribute" + inputFrameworkOpName: "CompareAndBitpack" + } +} mappings { frameworkName: "tensorflow" opName: "segment_min" @@ -9353,8 +9205,10 @@ mappings { functionName: "ndarraymapping" inputTensorName: "data" inputTensorName: "segment_ids" + inputTensorName: "num_segments" outputTensorName: "input" outputTensorName: "idxSegments" + outputTensorName: "numSegments" inputToOutput { key: "input" value: "data" @@ -9363,13 +9217,16 @@ mappings { key: "idxSegments" value: "segment_ids" } + inputToOutput { + key: "numSegments" + value: "num_segments" + } ruleType: "tensor" inputFrameworkOpName: "UnsortedSegmentMax" } rule { ruleName: "ndarrayinputtonumericalattribute" functionName: "ndarrayinputtonumericalattribute" - outputIntName: "numSegments" inputToOutput { key: "numSegments" value: "num_segments" @@ -9429,13 +9286,13 @@ mappings { inputBooleanName: "align_corners" inputBooleanName: "half_pixel_centers" outputBooleanName: "alignCorners" - outputBooleanName: "halfPixelCenter" + outputBooleanName: "halfPixelCenters" inputToOutput { key: "alignCorners" value: "align_corners" } inputToOutput { - key: "halfPixelCenter" + key: "halfPixelCenters" value: "half_pixel_centers" } ruleType: "attribute" @@ -9833,7 +9690,7 @@ mappings { functionName: "valuemapping" inputFloatName: "iou_threshold" inputToOutput { - key: "iouThreshold" + key: "overlayThreshold" value: "iou_threshold" } ruleType: "attribute" @@ -10185,11 +10042,9 @@ mappings { inputTensorName: "weights" inputTensorName: "arr" inputTensorName: "size" - inputTensorName: "size" outputTensorName: "weights" outputTensorName: "values" outputTensorName: "min" - outputTensorName: "max" inputToOutput { key: "weights" value: "weights" @@ -10202,38 +10057,9 @@ mappings { key: "min" value: "size" } - inputToOutput { - key: "max" - value: "size" - } ruleType: "tensor" inputFrameworkOpName: "Bincount" } - rule { - ruleName: "argdescriptorconstant" - functionName: "argdescriptorconstant" - inputIntName: "minLength" - ruleType: "attribute" - transformerArgs { - key: "value" - transformerArgs { - name: "minLength" - argType: INT64 - } - } - inputFrameworkOpName: "Bincount" - } - rule { - ruleName: "ndarrayinputtonumericalattribute" - functionName: "ndarrayinputtonumericalattribute" - outputIntName: "maxLength" - inputToOutput { - key: "maxLength" - value: "size" - } - ruleType: "attribute" - inputFrameworkOpName: "Bincount" - } rule { ruleName: "valuemapping" functionName: "valuemapping" @@ -10246,14 +10072,6 @@ mappings { 
ruleType: "attribute" inputFrameworkOpName: "Bincount" } - indexOverrides { - key: 1 - value: 2 - } - indexOverrides { - key: 2 - value: 1 - } } mappings { frameworkName: "tensorflow" @@ -10483,31 +10301,29 @@ mappings { ruleName: "ndarraymapping" functionName: "ndarraymapping" inputTensorName: "shape" - inputTensorName: "minval" - inputTensorName: "maxval" outputTensorName: "shape" - outputTensorName: "min" - outputTensorName: "max" inputToOutput { key: "shape" value: "shape" } - inputToOutput { - key: "min" - value: "minval" - } - inputToOutput { - key: "max" - value: "maxval" - } ruleType: "tensor" inputFrameworkOpName: "RandomUniformInt" } + rule { + ruleName: "valuemapping" + functionName: "valuemapping" + inputIntName: "seed" + outputIntName: "seed" + inputToOutput { + key: "seed" + value: "seed" + } + ruleType: "attribute" + inputFrameworkOpName: "RandomUniformInt" + } rule { ruleName: "ndarrayinputtonumericalattribute" functionName: "ndarrayinputtonumericalattribute" - outputDoubleName: "min" - outputDoubleName: "max" inputToOutput { key: "min" value: "minval" @@ -10822,14 +10638,8 @@ mappings { opName: "shapes_of" inputFrameworkOpName: "ShapeN" rule { - ruleName: "ndarraymapping" - functionName: "ndarraymapping" - inputTensorName: "input" - outputTensorName: "input" - inputToOutput { - key: "input" - value: "input" - } + ruleName: "passthrough" + functionName: "passthrough" ruleType: "tensor" inputFrameworkOpName: "ShapeN" } @@ -10943,8 +10753,10 @@ mappings { functionName: "ndarraymapping" inputTensorName: "input_sizes" inputTensorName: "filter" + inputTensorName: "out_backprop" outputTensorName: "gradIShape" outputTensorName: "weights" + outputTensorName: "gradO" inputToOutput { key: "gradIShape" value: "input_sizes" @@ -10953,6 +10765,10 @@ mappings { key: "weights" value: "filter" } + inputToOutput { + key: "gradO" + value: "out_backprop" + } ruleType: "tensor" inputFrameworkOpName: "Conv2DBackpropInput" } @@ -11629,6 +11445,18 @@ mappings { } inputFrameworkOpName: "CopyHost" } + rule { + ruleName: "valuemapping" + functionName: "valuemapping" + inputDataTypeName: "T" + outputDataTypeName: "dataType" + inputToOutput { + key: "dataType" + value: "T" + } + ruleType: "attribute" + inputFrameworkOpName: "CopyHost" + } } mappings { frameworkName: "tensorflow" @@ -12011,11 +11839,17 @@ mappings { ruleName: "ndarraymapping" functionName: "ndarraymapping" inputTensorName: "dims" - outputTensorName: "shapeArray" + inputTensorName: "value" + outputTensorName: "shape" + outputTensorName: "outputs" inputToOutput { - key: "shapeArray" + key: "shape" value: "dims" } + inputToOutput { + key: "outputs" + value: "value" + } ruleType: "tensor" inputFrameworkOpName: "Fill" } @@ -12030,18 +11864,6 @@ mappings { ruleType: "attribute" inputFrameworkOpName: "Fill" } - rule { - ruleName: "datatypetoint" - functionName: "datatypetoint" - outputIntName: "dtype" - inputDataTypeName: "T" - inputToOutput { - key: "dtype" - value: "T" - } - ruleType: "attribute" - inputFrameworkOpName: "Fill" - } rule { ruleName: "valuemapping" functionName: "valuemapping" @@ -12306,11 +12128,11 @@ mappings { rule { ruleName: "valuemapping" functionName: "valuemapping" - inputDataTypeName: "T" + inputDataTypeName: "Targmax" outputDataTypeName: "dtype" inputToOutput { key: "dtype" - value: "T" + value: "Targmax" } ruleType: "attribute" inputFrameworkOpName: "MaxPoolWithArgmax" @@ -13288,14 +13110,8 @@ mappings { opName: "identity_n" inputFrameworkOpName: "IdentityN" rule { - ruleName: "ndarraymapping" - functionName: 
"ndarraymapping" - inputTensorName: "input" - outputTensorName: "input" - inputToOutput { - key: "input" - value: "input" - } + ruleName: "passthrough" + functionName: "passthrough" ruleType: "tensor" inputFrameworkOpName: "IdentityN" } @@ -13379,9 +13195,6 @@ mappings { rule { ruleName: "ndarrayinputtonumericalattribute" functionName: "ndarrayinputtonumericalattribute" - outputDoubleName: "from" - outputDoubleName: "to" - outputDoubleName: "step" inputToOutput { key: "from" value: "start" @@ -14760,20 +14573,8 @@ mappings { opName: "concat" inputFrameworkOpName: "ConcatV2" rule { - ruleName: "multiinputindex" - functionName: "multiinputindex" - inputTensorName: "values" - inputTensorName: "axis" - outputTensorName: "input" - outputTensorName: "concatDimension" - inputToOutput { - key: "input" - value: "values" - } - inputToOutput { - key: "concatDimension" - value: "axis" - } + ruleName: "passthrough" + functionName: "passthrough" ruleType: "tensor" inputFrameworkOpName: "ConcatV2" } @@ -15641,6 +15442,18 @@ mappings { } inputFrameworkOpName: "DeepCopy" } + rule { + ruleName: "valuemapping" + functionName: "valuemapping" + inputDataTypeName: "T" + outputDataTypeName: "dataType" + inputToOutput { + key: "dataType" + value: "T" + } + ruleType: "attribute" + inputFrameworkOpName: "DeepCopy" + } } mappings { frameworkName: "tensorflow"